Merge pull request #111013 from stuartcarnie/shader_container_ext

Renderer: Move `reflect_spirv` to `RenderingShaderContainer`
This commit is contained in:
Thaddeus Crews
2025-10-01 17:54:09 -05:00
12 changed files with 425 additions and 391 deletions

View File

@@ -268,7 +268,7 @@ uint32_t RenderingShaderContainerD3D12::_to_bytes_footer_extra_data(uint8_t *p_b
}
#if NIR_ENABLED
bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(Span<ReflectedShaderStage> p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
r_stages_processed.clear();
dxil_spirv_runtime_conf dxil_runtime_conf = {};
@@ -287,7 +287,7 @@ bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<Rendering
dxil_runtime_conf.inferred_read_only_images_as_srvs = false;
// Translate SPIR-V to NIR.
for (int64_t i = 0; i < p_spirv.size(); i++) {
for (uint64_t i = 0; i < p_spirv.size(); i++) {
RenderingDeviceCommons::ShaderStage stage = p_spirv[i].shader_stage;
RenderingDeviceCommons::ShaderStage stage_flag = (RenderingDeviceCommons::ShaderStage)(1 << stage);
r_stages.push_back(stage);
@@ -302,9 +302,10 @@ bool RenderingShaderContainerD3D12::_convert_spirv_to_nir(const Vector<Rendering
MESA_SHADER_COMPUTE, // SHADER_STAGE_COMPUTE
};
Span<uint32_t> code = p_spirv[i].spirv();
nir_shader *shader = spirv_to_nir(
(const uint32_t *)(p_spirv[i].spirv.ptr()),
p_spirv[i].spirv.size() / sizeof(uint32_t),
code.ptr(),
code.size(),
nullptr,
0,
SPIRV_TO_MESA_STAGES[stage],
@@ -429,7 +430,7 @@ bool RenderingShaderContainerD3D12::_convert_nir_to_dxil(const HashMap<int, nir_
return true;
}
bool RenderingShaderContainerD3D12::_convert_spirv_to_dxil(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
bool RenderingShaderContainerD3D12::_convert_spirv_to_dxil(Span<ReflectedShaderStage> p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed) {
r_dxil_blobs.clear();
HashMap<int, nir_shader *> stages_nir_shaders;
@@ -764,7 +765,7 @@ void RenderingShaderContainerD3D12::_nir_report_bitcode_bit_offset(uint64_t p_bi
}
#endif
void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const String &p_shader_name, const RenderingDeviceCommons::ShaderReflection &p_reflection) {
void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const RenderingDeviceCommons::ShaderReflection &p_reflection) {
reflection_binding_set_uniforms_data_d3d12.resize(reflection_binding_set_uniforms_data.size());
reflection_specialization_data_d3d12.resize(reflection_specialization_data.size());
@@ -780,7 +781,7 @@ void RenderingShaderContainerD3D12::_set_from_shader_reflection_post(const Strin
}
}
bool RenderingShaderContainerD3D12::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
bool RenderingShaderContainerD3D12::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
#if NIR_ENABLED
reflection_data_d3d12.nir_runtime_data_root_param_idx = UINT32_MAX;

View File

@@ -122,9 +122,9 @@ protected:
uint32_t root_signature_crc = 0;
#if NIR_ENABLED
bool _convert_spirv_to_nir(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
bool _convert_spirv_to_nir(Span<ReflectedShaderStage> p_spirv, const nir_shader_compiler_options *p_compiler_options, HashMap<int, nir_shader *> &r_stages_nir_shaders, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
bool _convert_nir_to_dxil(const HashMap<int, nir_shader *> &p_stages_nir_shaders, BitField<RenderingDeviceCommons::ShaderStage> p_stages_processed, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs);
bool _convert_spirv_to_dxil(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
bool _convert_spirv_to_dxil(Span<ReflectedShaderStage> p_spirv, HashMap<RenderingDeviceCommons::ShaderStage, Vector<uint8_t>> &r_dxil_blobs, Vector<RenderingDeviceCommons::ShaderStage> &r_stages, BitField<RenderingDeviceCommons::ShaderStage> &r_stages_processed);
bool _generate_root_signature(BitField<RenderingDeviceCommons::ShaderStage> p_stages_processed);
// GodotNirCallbacks.
@@ -146,8 +146,8 @@ protected:
virtual uint32_t _to_bytes_reflection_binding_uniform_extra_data(uint8_t *p_bytes, uint32_t p_index) const override;
virtual uint32_t _to_bytes_reflection_specialization_extra_data(uint8_t *p_bytes, uint32_t p_index) const override;
virtual uint32_t _to_bytes_footer_extra_data(uint8_t *p_bytes) const override;
virtual void _set_from_shader_reflection_post(const String &p_shader_name, const RenderingDeviceCommons::ShaderReflection &p_reflection) override;
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
virtual void _set_from_shader_reflection_post(const RenderingDeviceCommons::ShaderReflection &p_reflection) override;
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
public:
struct ShaderReflectionD3D12 {

View File

@@ -292,7 +292,7 @@ protected:
virtual uint32_t _format() const override;
virtual uint32_t _format_version() const override;
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
};
class RenderingShaderContainerFormatMetal : public RenderingShaderContainerFormat {

View File

@@ -252,7 +252,7 @@ Error RenderingShaderContainerMetal::compile_metal_source(const char *p_source,
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunguarded-availability"
bool RenderingShaderContainerMetal::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
bool RenderingShaderContainerMetal::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
using namespace spirv_cross;
using spirv_cross::CompilerMSL;
using spirv_cross::Resource;
@@ -353,12 +353,11 @@ bool RenderingShaderContainerMetal::_set_code_from_spirv(const Vector<RenderingD
for (uint32_t i = 0; i < p_spirv.size(); i++) {
StageData &stage_data = mtl_shaders.write[i];
RD::ShaderStageSPIRVData const &v = p_spirv[i];
const ReflectedShaderStage &v = p_spirv[i];
RD::ShaderStage stage = v.shader_stage;
char const *stage_name = RD::SHADER_STAGE_NAMES[stage];
uint32_t const *const ir = reinterpret_cast<uint32_t const *const>(v.spirv.ptr());
size_t word_count = v.spirv.size() / sizeof(uint32_t);
Parser parser(ir, word_count);
Span<uint32_t> spirv = v.spirv();
Parser parser(spirv.ptr(), spirv.size());
try {
parser.parse();
} catch (CompilerError &e) {

View File

@@ -44,21 +44,21 @@ uint32_t RenderingShaderContainerVulkan::_format_version() const {
return FORMAT_VERSION;
}
bool RenderingShaderContainerVulkan::_set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) {
bool RenderingShaderContainerVulkan::_set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) {
PackedByteArray code_bytes;
shaders.resize(p_spirv.size());
for (int64_t i = 0; i < p_spirv.size(); i++) {
for (uint64_t i = 0; i < p_spirv.size(); i++) {
RenderingShaderContainer::Shader &shader = shaders.ptrw()[i];
if (debug_info_enabled) {
// Store SPIR-V as is when debug info is required.
shader.code_compressed_bytes = p_spirv[i].spirv;
shader.code_compressed_bytes = p_spirv[i].spirv_data();
shader.code_compression_flags = 0;
shader.code_decompressed_size = 0;
} else {
// Encode into smolv.
Span<uint8_t> spirv = p_spirv[i].spirv().reinterpret<uint8_t>();
smolv::ByteArray smolv_bytes;
bool smolv_encoded = smolv::Encode(p_spirv[i].spirv.ptr(), p_spirv[i].spirv.size(), smolv_bytes, smolv::kEncodeFlagStripDebugInfo);
bool smolv_encoded = smolv::Encode(spirv.ptr(), spirv.size(), smolv_bytes, smolv::kEncodeFlagStripDebugInfo);
ERR_FAIL_COND_V_MSG(!smolv_encoded, false, "Failed to compress SPIR-V into smolv.");
code_bytes.resize(smolv_bytes.size());

View File

@@ -47,7 +47,7 @@ public:
protected:
virtual uint32_t _format() const override;
virtual uint32_t _format_version() const override;
virtual bool _set_code_from_spirv(const Vector<RenderingDeviceCommons::ShaderStageSPIRVData> &p_spirv) override;
virtual bool _set_code_from_spirv(Span<ReflectedShaderStage> p_spirv) override;
public:
RenderingShaderContainerVulkan(bool p_debug_info_enabled);