summaryrefslogtreecommitdiffstats
path: root/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.cpp359
1 files changed, 320 insertions, 39 deletions
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 77fc58f25..8bcd04221 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -88,6 +88,9 @@ bool IsPrecise(Operation operand) {
} // namespace
+class ASTDecompiler;
+class ExprDecompiler;
+
class SPIRVDecompiler : public Sirit::Module {
public:
explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage)
@@ -97,27 +100,7 @@ public:
AddExtension("SPV_KHR_variable_pointers");
}
- void Decompile() {
- AllocateBindings();
- AllocateLabels();
-
- DeclareVertex();
- DeclareGeometry();
- DeclareFragment();
- DeclareRegisters();
- DeclarePredicates();
- DeclareLocalMemory();
- DeclareInternalFlags();
- DeclareInputAttributes();
- DeclareOutputAttributes();
- DeclareConstantBuffers();
- DeclareGlobalBuffers();
- DeclareSamplers();
-
- execute_function =
- Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
- Emit(OpLabel());
-
+ void DecompileBranchMode() {
const u32 first_address = ir.GetBasicBlocks().begin()->first;
const Id loop_label = OpLabel("loop");
const Id merge_label = OpLabel("merge");
@@ -174,6 +157,43 @@ public:
Emit(continue_label);
Emit(OpBranch(loop_label));
Emit(merge_label);
+ }
+
+ void DecompileAST();
+
+ void Decompile() {
+ const bool is_fully_decompiled = ir.IsDecompiled();
+ AllocateBindings();
+ if (!is_fully_decompiled) {
+ AllocateLabels();
+ }
+
+ DeclareVertex();
+ DeclareGeometry();
+ DeclareFragment();
+ DeclareRegisters();
+ DeclarePredicates();
+ if (is_fully_decompiled) {
+ DeclareFlowVariables();
+ }
+ DeclareLocalMemory();
+ DeclareInternalFlags();
+ DeclareInputAttributes();
+ DeclareOutputAttributes();
+ DeclareConstantBuffers();
+ DeclareGlobalBuffers();
+ DeclareSamplers();
+
+ execute_function =
+ Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
+ Emit(OpLabel());
+
+ if (is_fully_decompiled) {
+ DecompileAST();
+ } else {
+ DecompileBranchMode();
+ }
+
Emit(OpReturn());
Emit(OpFunctionEnd());
}
@@ -206,6 +226,9 @@ public:
}
private:
+ friend class ASTDecompiler;
+ friend class ExprDecompiler;
+
static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
void AllocateBindings() {
@@ -294,6 +317,14 @@ private:
}
}
+ void DeclareFlowVariables() {
+ for (u32 i = 0; i < ir.GetASTNumVariables(); i++) {
+ const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
+ Name(id, fmt::format("flow_var_{}", static_cast<u32>(i)));
+ flow_variables.emplace(i, AddGlobalVariable(id));
+ }
+ }
+
void DeclareLocalMemory() {
if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
@@ -615,9 +646,15 @@ private:
Emit(OpBranchConditional(condition, true_label, skip_label));
Emit(true_label);
+ ++conditional_nest_count;
VisitBasicBlock(conditional->GetCode());
+ --conditional_nest_count;
- Emit(OpBranch(skip_label));
+ if (inside_branch == 0) {
+ Emit(OpBranch(skip_label));
+ } else {
+ inside_branch--;
+ }
Emit(skip_label);
return {};
@@ -980,7 +1017,11 @@ private:
UNIMPLEMENTED_IF(!target);
Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue())));
- BranchingOp([&]() { Emit(OpBranch(continue_label)); });
+ Emit(OpBranch(continue_label));
+ inside_branch = conditional_nest_count;
+ if (conditional_nest_count == 0) {
+ Emit(OpLabel());
+ }
return {};
}
@@ -988,7 +1029,11 @@ private:
const Id op_a = VisitOperand<Type::Uint>(operation, 0);
Emit(OpStore(jmp_to, op_a));
- BranchingOp([&]() { Emit(OpBranch(continue_label)); });
+ Emit(OpBranch(continue_label));
+ inside_branch = conditional_nest_count;
+ if (conditional_nest_count == 0) {
+ Emit(OpLabel());
+ }
return {};
}
@@ -1015,11 +1060,15 @@ private:
Emit(OpStore(flow_stack_top, previous));
Emit(OpStore(jmp_to, target));
- BranchingOp([&]() { Emit(OpBranch(continue_label)); });
+ Emit(OpBranch(continue_label));
+ inside_branch = conditional_nest_count;
+ if (conditional_nest_count == 0) {
+ Emit(OpLabel());
+ }
return {};
}
- Id Exit(Operation operation) {
+ Id PreExit() {
switch (stage) {
case ShaderStage::Vertex: {
// TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
@@ -1067,12 +1116,35 @@ private:
}
}
- BranchingOp([&]() { Emit(OpReturn()); });
+ return {};
+ }
+
+ Id Exit(Operation operation) {
+ PreExit();
+ inside_branch = conditional_nest_count;
+ if (conditional_nest_count > 0) {
+ Emit(OpReturn());
+ } else {
+ const Id dummy = OpLabel();
+ Emit(OpBranch(dummy));
+ Emit(dummy);
+ Emit(OpReturn());
+ Emit(OpLabel());
+ }
return {};
}
Id Discard(Operation operation) {
- BranchingOp([&]() { Emit(OpKill()); });
+ inside_branch = conditional_nest_count;
+ if (conditional_nest_count > 0) {
+ Emit(OpKill());
+ } else {
+ const Id dummy = OpLabel();
+ Emit(OpBranch(dummy));
+ Emit(dummy);
+ Emit(OpKill());
+ Emit(OpLabel());
+ }
return {};
}
@@ -1267,17 +1339,6 @@ private:
return {};
}
- void BranchingOp(std::function<void()> call) {
- const Id true_label = OpLabel();
- const Id skip_label = OpLabel();
- Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::Flatten));
- Emit(OpBranchConditional(v_true, true_label, skip_label, 1, 0));
- Emit(true_label);
- call();
-
- Emit(skip_label);
- }
-
std::tuple<Id, Id> CreateFlowStack() {
// TODO(Rodrigo): Figure out the actual depth of the flow stack, for now it seems unlikely
// that shaders will use 20 nested SSYs and PBKs.
@@ -1483,6 +1544,8 @@ private:
const ShaderIR& ir;
const ShaderStage stage;
const Tegra::Shader::Header header;
+ u64 conditional_nest_count{};
+ u64 inside_branch{};
const Id t_void = Name(TypeVoid(), "void");
@@ -1545,6 +1608,7 @@ private:
Id per_vertex{};
std::map<u32, Id> registers;
std::map<Tegra::Shader::Pred, Id> predicates;
+ std::map<u32, Id> flow_variables;
Id local_memory{};
std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
std::map<Attribute::Index, Id> input_attributes;
@@ -1580,6 +1644,223 @@ private:
std::map<u32, Id> labels;
};
+class ExprDecompiler {
+public:
+ explicit ExprDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
+
+ Id operator()(VideoCommon::Shader::ExprAnd& expr) {
+ const Id type_def = decomp.GetTypeDefinition(Type::Bool);
+ const Id op1 = Visit(expr.operand1);
+ const Id op2 = Visit(expr.operand2);
+ return decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2));
+ }
+
+ Id operator()(VideoCommon::Shader::ExprOr& expr) {
+ const Id type_def = decomp.GetTypeDefinition(Type::Bool);
+ const Id op1 = Visit(expr.operand1);
+ const Id op2 = Visit(expr.operand2);
+ return decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2));
+ }
+
+ Id operator()(VideoCommon::Shader::ExprNot& expr) {
+ const Id type_def = decomp.GetTypeDefinition(Type::Bool);
+ const Id op1 = Visit(expr.operand1);
+ return decomp.Emit(decomp.OpLogicalNot(type_def, op1));
+ }
+
+ Id operator()(VideoCommon::Shader::ExprPredicate& expr) {
+ const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
+ return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)));
+ }
+
+ Id operator()(VideoCommon::Shader::ExprCondCode& expr) {
+ const Node cc = decomp.ir.GetConditionCode(expr.cc);
+ Id target;
+
+ if (const auto pred = std::get_if<PredicateNode>(&*cc)) {
+ const auto index = pred->GetIndex();
+ switch (index) {
+ case Tegra::Shader::Pred::NeverExecute:
+ target = decomp.v_false;
+ case Tegra::Shader::Pred::UnusedIndex:
+ target = decomp.v_true;
+ default:
+ target = decomp.predicates.at(index);
+ }
+ } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
+ target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
+ }
+ return decomp.Emit(decomp.OpLoad(decomp.t_bool, target));
+ }
+
+ Id operator()(VideoCommon::Shader::ExprVar& expr) {
+ return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)));
+ }
+
+ Id operator()(VideoCommon::Shader::ExprBoolean& expr) {
+ return expr.value ? decomp.v_true : decomp.v_false;
+ }
+
+ Id Visit(VideoCommon::Shader::Expr& node) {
+ return std::visit(*this, *node);
+ }
+
+private:
+ SPIRVDecompiler& decomp;
+};
+
+class ASTDecompiler {
+public:
+ explicit ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
+
+ void operator()(VideoCommon::Shader::ASTProgram& ast) {
+ ASTNode current = ast.nodes.GetFirst();
+ while (current) {
+ Visit(current);
+ current = current->GetNext();
+ }
+ }
+
+ void operator()(VideoCommon::Shader::ASTIfThen& ast) {
+ ExprDecompiler expr_parser{decomp};
+ const Id condition = expr_parser.Visit(ast.condition);
+ const Id then_label = decomp.OpLabel();
+ const Id endif_label = decomp.OpLabel();
+ decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
+ decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
+ decomp.Emit(then_label);
+ ASTNode current = ast.nodes.GetFirst();
+ while (current) {
+ Visit(current);
+ current = current->GetNext();
+ }
+ decomp.Emit(decomp.OpBranch(endif_label));
+ decomp.Emit(endif_label);
+ }
+
+ void operator()(VideoCommon::Shader::ASTIfElse& ast) {
+ UNREACHABLE();
+ }
+
+ void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) {
+ UNREACHABLE();
+ }
+
+ void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) {
+ decomp.VisitBasicBlock(ast.nodes);
+ }
+
+ void operator()(VideoCommon::Shader::ASTVarSet& ast) {
+ ExprDecompiler expr_parser{decomp};
+ const Id condition = expr_parser.Visit(ast.condition);
+ decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition));
+ }
+
+ void operator()(VideoCommon::Shader::ASTLabel& ast) {
+ // Do nothing
+ }
+
+ void operator()(VideoCommon::Shader::ASTGoto& ast) {
+ UNREACHABLE();
+ }
+
+ void operator()(VideoCommon::Shader::ASTDoWhile& ast) {
+ const Id loop_label = decomp.OpLabel();
+ const Id endloop_label = decomp.OpLabel();
+ const Id loop_start_block = decomp.OpLabel();
+ const Id loop_end_block = decomp.OpLabel();
+ current_loop_exit = endloop_label;
+ decomp.Emit(decomp.OpBranch(loop_label));
+ decomp.Emit(loop_label);
+ decomp.Emit(
+ decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone));
+ decomp.Emit(decomp.OpBranch(loop_start_block));
+ decomp.Emit(loop_start_block);
+ ASTNode current = ast.nodes.GetFirst();
+ while (current) {
+ Visit(current);
+ current = current->GetNext();
+ }
+ ExprDecompiler expr_parser{decomp};
+ const Id condition = expr_parser.Visit(ast.condition);
+ decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label));
+ decomp.Emit(endloop_label);
+ }
+
+ void operator()(VideoCommon::Shader::ASTReturn& ast) {
+ if (!VideoCommon::Shader::ExprIsTrue(ast.condition)) {
+ ExprDecompiler expr_parser{decomp};
+ const Id condition = expr_parser.Visit(ast.condition);
+ const Id then_label = decomp.OpLabel();
+ const Id endif_label = decomp.OpLabel();
+ decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
+ decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
+ decomp.Emit(then_label);
+ if (ast.kills) {
+ decomp.Emit(decomp.OpKill());
+ } else {
+ decomp.PreExit();
+ decomp.Emit(decomp.OpReturn());
+ }
+ decomp.Emit(endif_label);
+ } else {
+ const Id next_block = decomp.OpLabel();
+ decomp.Emit(decomp.OpBranch(next_block));
+ decomp.Emit(next_block);
+ if (ast.kills) {
+ decomp.Emit(decomp.OpKill());
+ } else {
+ decomp.PreExit();
+ decomp.Emit(decomp.OpReturn());
+ }
+ decomp.Emit(decomp.OpLabel());
+ }
+ }
+
+ void operator()(VideoCommon::Shader::ASTBreak& ast) {
+ if (!VideoCommon::Shader::ExprIsTrue(ast.condition)) {
+ ExprDecompiler expr_parser{decomp};
+ const Id condition = expr_parser.Visit(ast.condition);
+ const Id then_label = decomp.OpLabel();
+ const Id endif_label = decomp.OpLabel();
+ decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
+ decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
+ decomp.Emit(then_label);
+ decomp.Emit(decomp.OpBranch(current_loop_exit));
+ decomp.Emit(endif_label);
+ } else {
+ const Id next_block = decomp.OpLabel();
+ decomp.Emit(decomp.OpBranch(next_block));
+ decomp.Emit(next_block);
+ decomp.Emit(decomp.OpBranch(current_loop_exit));
+ decomp.Emit(decomp.OpLabel());
+ }
+ }
+
+ void Visit(VideoCommon::Shader::ASTNode& node) {
+ std::visit(*this, *node->GetInnerData());
+ }
+
+private:
+ SPIRVDecompiler& decomp;
+ Id current_loop_exit{};
+};
+
+void SPIRVDecompiler::DecompileAST() {
+ const u32 num_flow_variables = ir.GetASTNumVariables();
+ for (u32 i = 0; i < num_flow_variables; i++) {
+ const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
+ Name(id, fmt::format("flow_var_{}", i));
+ flow_variables.emplace(i, AddGlobalVariable(id));
+ }
+ ASTDecompiler decompiler{*this};
+ VideoCommon::Shader::ASTNode program = ir.GetASTProgram();
+ decompiler.Visit(program);
+ const Id next_block = OpLabel();
+ Emit(OpBranch(next_block));
+ Emit(next_block);
+}
+
DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
Maxwell::ShaderStage stage) {
auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);