summaryrefslogtreecommitdiffstats
path: root/src/shader_recompiler/backend/spirv/emit_context.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/shader_recompiler/backend/spirv/emit_context.cpp')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp115
1 files changed, 114 insertions, 1 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index a8ca33c1d..96d0e9b4d 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -9,6 +9,7 @@
#include <fmt/format.h>
#include "common/common_types.h"
+#include "common/div_ceil.h"
#include "shader_recompiler/backend/spirv/emit_context.h"
namespace Shader::Backend::SPIRV {
@@ -96,11 +97,13 @@ void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_vie
}
EmitContext::EmitContext(const Profile& profile_, IR::Program& program, u32& binding)
- : Sirit::Module(0x00010000), profile{profile_}, stage{program.stage} {
+ : Sirit::Module(profile_.supported_spirv), profile{profile_}, stage{program.stage} {
AddCapability(spv::Capability::Shader);
DefineCommonTypes(program.info);
DefineCommonConstants();
DefineInterfaces(program.info);
+ DefineLocalMemory(program);
+ DefineSharedMemory(program);
DefineConstantBuffers(program.info, binding);
DefineStorageBuffers(program.info, binding);
DefineTextures(program.info, binding);
@@ -143,6 +146,8 @@ void EmitContext::DefineCommonTypes(const Info& info) {
F32.Define(*this, TypeFloat(32), "f32");
U32.Define(*this, TypeInt(32, false), "u32");
+ private_u32 = Name(TypePointer(spv::StorageClass::Private, U32[1]), "private_u32");
+
input_f32 = Name(TypePointer(spv::StorageClass::Input, F32[1]), "input_f32");
input_u32 = Name(TypePointer(spv::StorageClass::Input, U32[1]), "input_u32");
input_s32 = Name(TypePointer(spv::StorageClass::Input, TypeInt(32, true)), "input_s32");
@@ -184,6 +189,105 @@ void EmitContext::DefineInterfaces(const Info& info) {
DefineOutputs(info);
}
+void EmitContext::DefineLocalMemory(const IR::Program& program) {
+ if (program.local_memory_size == 0) {
+ return;
+ }
+ const u32 num_elements{Common::DivCeil(program.local_memory_size, 4U)};
+ const Id type{TypeArray(U32[1], Constant(U32[1], num_elements))};
+ const Id pointer{TypePointer(spv::StorageClass::Private, type)};
+ local_memory = AddGlobalVariable(pointer, spv::StorageClass::Private);
+ if (profile.supported_spirv >= 0x00010400) {
+ interfaces.push_back(local_memory);
+ }
+}
+
+void EmitContext::DefineSharedMemory(const IR::Program& program) {
+ if (program.shared_memory_size == 0) {
+ return;
+ }
+ const auto make{[&](Id element_type, u32 element_size) {
+ const u32 num_elements{Common::DivCeil(program.shared_memory_size, element_size)};
+ const Id array_type{TypeArray(element_type, Constant(U32[1], num_elements))};
+ Decorate(array_type, spv::Decoration::ArrayStride, element_size);
+
+ const Id struct_type{TypeStruct(array_type)};
+ MemberDecorate(struct_type, 0U, spv::Decoration::Offset, 0U);
+ Decorate(struct_type, spv::Decoration::Block);
+
+ const Id pointer{TypePointer(spv::StorageClass::Workgroup, struct_type)};
+ const Id element_pointer{TypePointer(spv::StorageClass::Workgroup, element_type)};
+ const Id variable{AddGlobalVariable(pointer, spv::StorageClass::Workgroup)};
+ Decorate(variable, spv::Decoration::Aliased);
+ interfaces.push_back(variable);
+
+ return std::make_pair(variable, element_pointer);
+ }};
+ if (profile.support_explicit_workgroup_layout) {
+ AddExtension("SPV_KHR_workgroup_memory_explicit_layout");
+ AddCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR);
+ if (program.info.uses_int8) {
+ AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
+ std::tie(shared_memory_u8, shared_u8) = make(U8, 1);
+ }
+ if (program.info.uses_int16) {
+ AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
+ std::tie(shared_memory_u16, shared_u16) = make(U16, 2);
+ }
+ std::tie(shared_memory_u32, shared_u32) = make(U32[1], 4);
+ std::tie(shared_memory_u32x2, shared_u32x2) = make(U32[2], 8);
+ std::tie(shared_memory_u32x4, shared_u32x4) = make(U32[4], 16);
+ }
+ const u32 num_elements{Common::DivCeil(program.shared_memory_size, 4U)};
+ const Id type{TypeArray(U32[1], Constant(U32[1], num_elements))};
+ const Id pointer_type{TypePointer(spv::StorageClass::Workgroup, type)};
+ shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]);
+ shared_memory_u32 = AddGlobalVariable(pointer_type, spv::StorageClass::Workgroup);
+ interfaces.push_back(shared_memory_u32);
+
+ const Id func_type{TypeFunction(void_id, U32[1], U32[1])};
+ const auto make_function{[&](u32 mask, u32 size) {
+ const Id loop_header{OpLabel()};
+ const Id continue_block{OpLabel()};
+ const Id merge_block{OpLabel()};
+
+ const Id func{OpFunction(void_id, spv::FunctionControlMask::MaskNone, func_type)};
+ const Id offset{OpFunctionParameter(U32[1])};
+ const Id insert_value{OpFunctionParameter(U32[1])};
+ AddLabel();
+ OpBranch(loop_header);
+
+ AddLabel(loop_header);
+ const Id word_offset{OpShiftRightArithmetic(U32[1], offset, Constant(U32[1], 2U))};
+ const Id shift_offset{OpShiftLeftLogical(U32[1], offset, Constant(U32[1], 3U))};
+ const Id bit_offset{OpBitwiseAnd(U32[1], shift_offset, Constant(U32[1], mask))};
+ const Id count{Constant(U32[1], size)};
+ OpLoopMerge(merge_block, continue_block, spv::LoopControlMask::MaskNone);
+ OpBranch(continue_block);
+
+ AddLabel(continue_block);
+ const Id word_pointer{OpAccessChain(shared_u32, shared_memory_u32, word_offset)};
+ const Id old_value{OpLoad(U32[1], word_pointer)};
+ const Id new_value{OpBitFieldInsert(U32[1], old_value, insert_value, bit_offset, count)};
+ const Id atomic_res{OpAtomicCompareExchange(U32[1], word_pointer, Constant(U32[1], 1U),
+ u32_zero_value, u32_zero_value, new_value,
+ old_value)};
+ const Id success{OpIEqual(U1, atomic_res, old_value)};
+ OpBranchConditional(success, merge_block, loop_header);
+
+ AddLabel(merge_block);
+ OpReturn();
+ OpFunctionEnd();
+ return func;
+ }};
+ if (program.info.uses_int8) {
+ shared_store_u8_func = make_function(24, 8);
+ }
+ if (program.info.uses_int16) {
+ shared_store_u16_func = make_function(16, 16);
+ }
+}
+
void EmitContext::DefineConstantBuffers(const Info& info, u32& binding) {
if (info.constant_buffer_descriptors.empty()) {
return;
@@ -234,6 +338,9 @@ void EmitContext::DefineStorageBuffers(const Info& info, u32& binding) {
Decorate(id, spv::Decoration::Binding, binding);
Decorate(id, spv::Decoration::DescriptorSet, 0U);
Name(id, fmt::format("ssbo{}", index));
+ if (profile.supported_spirv >= 0x00010400) {
+ interfaces.push_back(id);
+ }
std::fill_n(ssbos.data() + index, desc.count, id);
index += desc.count;
binding += desc.count;
@@ -261,6 +368,9 @@ void EmitContext::DefineTextures(const Info& info, u32& binding) {
.image_type{image_type},
});
}
+ if (profile.supported_spirv >= 0x00010400) {
+ interfaces.push_back(id);
+ }
binding += desc.count;
}
}
@@ -363,6 +473,9 @@ void EmitContext::DefineConstantBuffers(const Info& info, Id UniformDefinitions:
for (size_t i = 0; i < desc.count; ++i) {
cbufs[desc.index + i].*member_type = id;
}
+ if (profile.supported_spirv >= 0x00010400) {
+ interfaces.push_back(id);
+ }
binding += desc.count;
}
}