summaryrefslogtreecommitdiffstats
path: root/src/shader_recompiler/backend/spirv/emit_context.cpp
diff options
context:
space:
mode:
authorReinUsesLisp <reinuseslisp@airmail.cc>2021-04-16 03:46:11 +0200
committerameerj <52414509+ameerj@users.noreply.github.com>2021-07-23 03:51:27 +0200
commit183855e396cc6918d36fbf3e38ea426e934b4e3e (patch)
treea665794753520c09a1d34d8a086352894ec1cb72 /src/shader_recompiler/backend/spirv/emit_context.cpp
parentshader: Mark atomic instructions as writes (diff)
downloadyuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.tar
yuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.tar.gz
yuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.tar.bz2
yuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.tar.lz
yuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.tar.xz
yuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.tar.zst
yuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.zip
Diffstat (limited to '')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp147
1 files changed, 105 insertions, 42 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 032cf5e03..067f61613 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -125,19 +125,36 @@ u32 NumVertices(InputTopology input_topology) {
throw InvalidArgument("Invalid input topology {}", input_topology);
}
-Id DefineInput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) {
- if (ctx.stage == Stage::Geometry) {
- const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
- type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
+Id DefineInput(EmitContext& ctx, Id type, bool per_invocation,
+ std::optional<spv::BuiltIn> builtin = std::nullopt) {
+ switch (ctx.stage) {
+ case Stage::TessellationControl:
+ case Stage::TessellationEval:
+ if (per_invocation) {
+ type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], 32u));
+ }
+ break;
+ case Stage::Geometry:
+ if (per_invocation) {
+ const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
+ type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
+ }
+ break;
+ default:
+ break;
}
return DefineVariable(ctx, type, builtin, spv::StorageClass::Input);
}
-Id DefineOutput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) {
+Id DefineOutput(EmitContext& ctx, Id type, std::optional<u32> invocations,
+ std::optional<spv::BuiltIn> builtin = std::nullopt) {
+ if (invocations && ctx.stage == Stage::TessellationControl) {
+ type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], *invocations));
+ }
return DefineVariable(ctx, type, builtin, spv::StorageClass::Output);
}
-void DefineGenericOutput(EmitContext& ctx, size_t index) {
+void DefineGenericOutput(EmitContext& ctx, size_t index, std::optional<u32> invocations) {
static constexpr std::string_view swizzle{"xyzw"};
const size_t base_attr_index{static_cast<size_t>(IR::Attribute::Generic0X) + index * 4};
u32 element{0};
@@ -150,7 +167,7 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
}
const u32 num_components{xfb_varying ? xfb_varying->components : remainder};
- const Id id{DefineOutput(ctx, ctx.F32[num_components])};
+ const Id id{DefineOutput(ctx, ctx.F32[num_components], invocations)};
ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
if (element > 0) {
ctx.Decorate(id, spv::Decoration::Component, element);
@@ -161,10 +178,10 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
ctx.Decorate(id, spv::Decoration::Offset, xfb_varying->offset);
}
if (num_components < 4 || element > 0) {
- ctx.Name(id, fmt::format("out_attr{}", index));
- } else {
const std::string_view subswizzle{swizzle.substr(element, num_components)};
ctx.Name(id, fmt::format("out_attr{}_{}", index, subswizzle));
+ } else {
+ ctx.Name(id, fmt::format("out_attr{}", index));
}
const GenericElementInfo info{
.id = id,
@@ -383,7 +400,7 @@ EmitContext::EmitContext(const Profile& profile_, IR::Program& program, u32& bin
AddCapability(spv::Capability::Shader);
DefineCommonTypes(program.info);
DefineCommonConstants();
- DefineInterfaces(program.info);
+ DefineInterfaces(program);
DefineLocalMemory(program);
DefineSharedMemory(program);
DefineSharedMemoryFunctions(program);
@@ -472,9 +489,9 @@ void EmitContext::DefineCommonConstants() {
f32_zero_value = Constant(F32[1], 0.0f);
}
-void EmitContext::DefineInterfaces(const Info& info) {
- DefineInputs(info);
- DefineOutputs(info);
+void EmitContext::DefineInterfaces(const IR::Program& program) {
+ DefineInputs(program.info);
+ DefineOutputs(program);
}
void EmitContext::DefineLocalMemory(const IR::Program& program) {
@@ -972,26 +989,29 @@ void EmitContext::DefineLabels(IR::Program& program) {
void EmitContext::DefineInputs(const Info& info) {
if (info.uses_workgroup_id) {
- workgroup_id = DefineInput(*this, U32[3], spv::BuiltIn::WorkgroupId);
+ workgroup_id = DefineInput(*this, U32[3], false, spv::BuiltIn::WorkgroupId);
}
if (info.uses_local_invocation_id) {
- local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId);
+ local_invocation_id = DefineInput(*this, U32[3], false, spv::BuiltIn::LocalInvocationId);
+ }
+ if (info.uses_invocation_id) {
+ invocation_id = DefineInput(*this, U32[1], false, spv::BuiltIn::InvocationId);
}
if (info.uses_is_helper_invocation) {
- is_helper_invocation = DefineInput(*this, U1, spv::BuiltIn::HelperInvocation);
+ is_helper_invocation = DefineInput(*this, U1, false, spv::BuiltIn::HelperInvocation);
}
if (info.uses_subgroup_mask) {
- subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR);
- subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR);
- subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR);
- subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR);
- subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR);
+ subgroup_mask_eq = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupEqMaskKHR);
+ subgroup_mask_lt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLtMaskKHR);
+ subgroup_mask_le = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLeMaskKHR);
+ subgroup_mask_gt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGtMaskKHR);
+ subgroup_mask_ge = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGeMaskKHR);
}
if (info.uses_subgroup_invocation_id ||
(profile.warp_size_potentially_larger_than_guest &&
(info.uses_subgroup_vote || info.uses_subgroup_mask))) {
subgroup_local_invocation_id =
- DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId);
+ DefineInput(*this, U32[1], false, spv::BuiltIn::SubgroupLocalInvocationId);
}
if (info.uses_fswzadd) {
const Id f32_one{Constant(F32[1], 1.0f)};
@@ -1004,29 +1024,32 @@ void EmitContext::DefineInputs(const Info& info) {
if (info.loads_position) {
const bool is_fragment{stage != Stage::Fragment};
const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord};
- input_position = DefineInput(*this, F32[4], built_in);
+ input_position = DefineInput(*this, F32[4], true, built_in);
}
if (info.loads_instance_id) {
if (profile.support_vertex_instance_id) {
- instance_id = DefineInput(*this, U32[1], spv::BuiltIn::InstanceId);
+ instance_id = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceId);
} else {
- instance_index = DefineInput(*this, U32[1], spv::BuiltIn::InstanceIndex);
- base_instance = DefineInput(*this, U32[1], spv::BuiltIn::BaseInstance);
+ instance_index = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceIndex);
+ base_instance = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseInstance);
}
}
if (info.loads_vertex_id) {
if (profile.support_vertex_instance_id) {
- vertex_id = DefineInput(*this, U32[1], spv::BuiltIn::VertexId);
+ vertex_id = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexId);
} else {
- vertex_index = DefineInput(*this, U32[1], spv::BuiltIn::VertexIndex);
- base_vertex = DefineInput(*this, U32[1], spv::BuiltIn::BaseVertex);
+ vertex_index = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexIndex);
+ base_vertex = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseVertex);
}
}
if (info.loads_front_face) {
- front_face = DefineInput(*this, U1, spv::BuiltIn::FrontFacing);
+ front_face = DefineInput(*this, U1, true, spv::BuiltIn::FrontFacing);
}
if (info.loads_point_coord) {
- point_coord = DefineInput(*this, F32[2], spv::BuiltIn::PointCoord);
+ point_coord = DefineInput(*this, F32[2], true, spv::BuiltIn::PointCoord);
+ }
+ if (info.loads_tess_coord) {
+ tess_coord = DefineInput(*this, F32[3], false, spv::BuiltIn::TessCoord);
}
for (size_t index = 0; index < info.input_generics.size(); ++index) {
const InputVarying generic{info.input_generics[index]};
@@ -1038,7 +1061,7 @@ void EmitContext::DefineInputs(const Info& info) {
continue;
}
const Id type{GetAttributeType(*this, input_type)};
- const Id id{DefineInput(*this, type)};
+ const Id id{DefineInput(*this, type, true)};
Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
Name(id, fmt::format("in_attr{}", index));
input_generics[index] = id;
@@ -1059,58 +1082,98 @@ void EmitContext::DefineInputs(const Info& info) {
break;
}
}
+ if (stage == Stage::TessellationEval) {
+ for (size_t index = 0; index < info.uses_patches.size(); ++index) {
+ if (!info.uses_patches[index]) {
+ continue;
+ }
+ const Id id{DefineInput(*this, F32[4], false)};
+ Decorate(id, spv::Decoration::Patch);
+ Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
+ patches[index] = id;
+ }
+ }
}
-void EmitContext::DefineOutputs(const Info& info) {
+void EmitContext::DefineOutputs(const IR::Program& program) {
+ const Info& info{program.info};
+ const std::optional<u32> invocations{program.invocations};
if (info.stores_position || stage == Stage::VertexB) {
- output_position = DefineOutput(*this, F32[4], spv::BuiltIn::Position);
+ output_position = DefineOutput(*this, F32[4], invocations, spv::BuiltIn::Position);
}
if (info.stores_point_size || profile.fixed_state_point_size) {
if (stage == Stage::Fragment) {
throw NotImplementedException("Storing PointSize in fragment stage");
}
- output_point_size = DefineOutput(*this, F32[1], spv::BuiltIn::PointSize);
+ output_point_size = DefineOutput(*this, F32[1], invocations, spv::BuiltIn::PointSize);
}
if (info.stores_clip_distance) {
if (stage == Stage::Fragment) {
throw NotImplementedException("Storing ClipDistance in fragment stage");
}
const Id type{TypeArray(F32[1], Constant(U32[1], 8U))};
- clip_distances = DefineOutput(*this, type, spv::BuiltIn::ClipDistance);
+ clip_distances = DefineOutput(*this, type, invocations, spv::BuiltIn::ClipDistance);
}
if (info.stores_layer &&
(profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
if (stage == Stage::Fragment) {
throw NotImplementedException("Storing Layer in fragment stage");
}
- layer = DefineOutput(*this, U32[1], spv::BuiltIn::Layer);
+ layer = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::Layer);
}
if (info.stores_viewport_index &&
(profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
if (stage == Stage::Fragment) {
throw NotImplementedException("Storing ViewportIndex in fragment stage");
}
- viewport_index = DefineOutput(*this, U32[1], spv::BuiltIn::ViewportIndex);
+ viewport_index = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::ViewportIndex);
}
for (size_t index = 0; index < info.stores_generics.size(); ++index) {
if (info.stores_generics[index]) {
- DefineGenericOutput(*this, index);
+ DefineGenericOutput(*this, index, invocations);
}
}
- if (stage == Stage::Fragment) {
+ switch (stage) {
+ case Stage::TessellationControl:
+ if (info.stores_tess_level_outer) {
+ const Id type{TypeArray(F32[1], Constant(U32[1], 4))};
+ output_tess_level_outer =
+ DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelOuter);
+ Decorate(output_tess_level_outer, spv::Decoration::Patch);
+ }
+ if (info.stores_tess_level_inner) {
+ const Id type{TypeArray(F32[1], Constant(U32[1], 2))};
+ output_tess_level_inner =
+ DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelInner);
+ Decorate(output_tess_level_inner, spv::Decoration::Patch);
+ }
+ for (size_t index = 0; index < info.uses_patches.size(); ++index) {
+ if (!info.uses_patches[index]) {
+ continue;
+ }
+ const Id id{DefineOutput(*this, F32[4], std::nullopt)};
+ Decorate(id, spv::Decoration::Patch);
+ Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
+ patches[index] = id;
+ }
+ break;
+ case Stage::Fragment:
for (u32 index = 0; index < 8; ++index) {
if (!info.stores_frag_color[index]) {
continue;
}
- frag_color[index] = DefineOutput(*this, F32[4]);
+ frag_color[index] = DefineOutput(*this, F32[4], std::nullopt);
Decorate(frag_color[index], spv::Decoration::Location, index);
Name(frag_color[index], fmt::format("frag_color{}", index));
}
if (info.stores_frag_depth) {
- frag_depth = DefineOutput(*this, F32[1]);
+ frag_depth = DefineOutput(*this, F32[1], std::nullopt);
Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth);
Name(frag_depth, "frag_depth");
}
+ break;
+ default:
+ break;
}
}