From ca9901867e91cd0be0cc75094ee8ea2fb2767c47 Mon Sep 17 00:00:00 2001 From: Fernando Sahmkow Date: Sun, 25 Aug 2019 15:32:00 -0400 Subject: vk_shader_compiler: Implement the decompiler in SPIR-V --- .../renderer_vulkan/vk_shader_decompiler.cpp | 298 +++++++++++++++++++-- 1 file changed, 276 insertions(+), 22 deletions(-) (limited to 'src/video_core/renderer_vulkan') diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp index 77fc58f25..505e49570 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp @@ -88,6 +88,9 @@ bool IsPrecise(Operation operand) { } // namespace +class ASTDecompiler; +class ExprDecompiler; + class SPIRVDecompiler : public Sirit::Module { public: explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage) @@ -97,27 +100,7 @@ public: AddExtension("SPV_KHR_variable_pointers"); } - void Decompile() { - AllocateBindings(); - AllocateLabels(); - - DeclareVertex(); - DeclareGeometry(); - DeclareFragment(); - DeclareRegisters(); - DeclarePredicates(); - DeclareLocalMemory(); - DeclareInternalFlags(); - DeclareInputAttributes(); - DeclareOutputAttributes(); - DeclareConstantBuffers(); - DeclareGlobalBuffers(); - DeclareSamplers(); - - execute_function = - Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void))); - Emit(OpLabel()); - + void DecompileBranchMode() { const u32 first_address = ir.GetBasicBlocks().begin()->first; const Id loop_label = OpLabel("loop"); const Id merge_label = OpLabel("merge"); @@ -174,6 +157,43 @@ public: Emit(continue_label); Emit(OpBranch(loop_label)); Emit(merge_label); + } + + void DecompileAST(); + + void Decompile() { + const bool is_fully_decompiled = ir.IsDecompiled(); + AllocateBindings(); + if (!is_fully_decompiled) { + AllocateLabels(); + } + + DeclareVertex(); + DeclareGeometry(); + DeclareFragment(); + DeclareRegisters(); + DeclarePredicates(); + if (is_fully_decompiled) { + DeclareFlowVariables(); + } + DeclareLocalMemory(); + DeclareInternalFlags(); + DeclareInputAttributes(); + DeclareOutputAttributes(); + DeclareConstantBuffers(); + DeclareGlobalBuffers(); + DeclareSamplers(); + + execute_function = + Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void))); + Emit(OpLabel()); + + if (is_fully_decompiled) { + DecompileAST(); + } else { + DecompileBranchMode(); + } + Emit(OpReturn()); Emit(OpFunctionEnd()); } @@ -206,6 +226,9 @@ public: } private: + friend class ASTDecompiler; + friend class ExprDecompiler; + static constexpr auto INTERNAL_FLAGS_COUNT = static_cast(InternalFlag::Amount); void AllocateBindings() { @@ -294,6 +317,14 @@ private: } } + void DeclareFlowVariables() { + for (u32 i = 0; i < ir.GetASTNumVariables(); i++) { + const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false); + Name(id, fmt::format("flow_var_{}", static_cast(i))); + flow_variables.emplace(i, AddGlobalVariable(id)); + } + } + void DeclareLocalMemory() { if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) { const auto element_count = static_cast(Common::AlignUp(local_memory_size, 4) / 4); @@ -1019,7 +1050,7 @@ private: return {}; } - Id Exit(Operation operation) { + Id PreExit() { switch (stage) { case ShaderStage::Vertex: { // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't @@ -1067,6 +1098,11 @@ private: } } + return {}; + } + + Id Exit(Operation operation) { + PreExit(); BranchingOp([&]() { Emit(OpReturn()); }); return {}; } @@ -1545,6 +1581,7 @@ private: Id per_vertex{}; std::map registers; std::map predicates; + std::map flow_variables; Id local_memory{}; std::array internal_flags{}; std::map input_attributes; @@ -1580,6 +1617,223 @@ private: std::map labels; }; +class ExprDecompiler { +public: + ExprDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {} + + void operator()(VideoCommon::Shader::ExprAnd& expr) { + const Id type_def = decomp.GetTypeDefinition(Type::Bool); + const Id op1 = Visit(expr.operand1); + const Id op2 = Visit(expr.operand2); + current_id = decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2)); + } + + void operator()(VideoCommon::Shader::ExprOr& expr) { + const Id type_def = decomp.GetTypeDefinition(Type::Bool); + const Id op1 = Visit(expr.operand1); + const Id op2 = Visit(expr.operand2); + current_id = decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2)); + } + + void operator()(VideoCommon::Shader::ExprNot& expr) { + const Id type_def = decomp.GetTypeDefinition(Type::Bool); + const Id op1 = Visit(expr.operand1); + current_id = decomp.Emit(decomp.OpLogicalNot(type_def, op1)); + } + + void operator()(VideoCommon::Shader::ExprPredicate& expr) { + auto pred = static_cast(expr.predicate); + current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred))); + } + + void operator()(VideoCommon::Shader::ExprCondCode& expr) { + Node cc = decomp.ir.GetConditionCode(expr.cc); + Id target; + + if (const auto pred = std::get_if(&*cc)) { + const auto index = pred->GetIndex(); + switch (index) { + case Tegra::Shader::Pred::NeverExecute: + target = decomp.v_false; + case Tegra::Shader::Pred::UnusedIndex: + target = decomp.v_true; + default: + target = decomp.predicates.at(index); + } + } else if (const auto flag = std::get_if(&*cc)) { + target = decomp.internal_flags.at(static_cast(flag->GetFlag())); + } + current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, target)); + } + + void operator()(VideoCommon::Shader::ExprVar& expr) { + current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index))); + } + + void operator()(VideoCommon::Shader::ExprBoolean& expr) { + current_id = expr.value ? decomp.v_true : decomp.v_false; + } + + Id GetResult() { + return current_id; + } + + Id Visit(VideoCommon::Shader::Expr& node) { + std::visit(*this, *node); + return current_id; + } + +private: + Id current_id; + SPIRVDecompiler& decomp; +}; + +class ASTDecompiler { +public: + ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {} + + void operator()(VideoCommon::Shader::ASTProgram& ast) { + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + } + + void operator()(VideoCommon::Shader::ASTIfThen& ast) { + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + const Id then_label = decomp.OpLabel(); + const Id endif_label = decomp.OpLabel(); + decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); + decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); + decomp.Emit(then_label); + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + decomp.Emit(endif_label); + } + + void operator()(VideoCommon::Shader::ASTIfElse& ast) { + UNREACHABLE(); + } + + void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) { + UNREACHABLE(); + } + + void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) { + decomp.VisitBasicBlock(ast.nodes); + } + + void operator()(VideoCommon::Shader::ASTVarSet& ast) { + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition)); + } + + void operator()(VideoCommon::Shader::ASTLabel& ast) { + // Do nothing + } + + void operator()(VideoCommon::Shader::ASTGoto& ast) { + UNREACHABLE(); + } + + void operator()(VideoCommon::Shader::ASTDoWhile& ast) { + const Id loop_label = decomp.OpLabel(); + const Id endloop_label = decomp.OpLabel(); + const Id loop_start_block = decomp.OpLabel(); + const Id loop_end_block = decomp.OpLabel(); + current_loop_exit = endloop_label; + decomp.Emit(loop_label); + decomp.Emit(decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone)); + decomp.Emit(decomp.OpBranch(loop_start_block)); + decomp.Emit(loop_start_block); + ASTNode current = ast.nodes.GetFirst(); + while (current) { + Visit(current); + current = current->GetNext(); + } + decomp.Emit(decomp.OpBranch(loop_end_block)); + decomp.Emit(loop_end_block); + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label)); + decomp.Emit(endloop_label); + } + + void operator()(VideoCommon::Shader::ASTReturn& ast) { + bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition); + if (!is_true) { + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + const Id then_label = decomp.OpLabel(); + const Id endif_label = decomp.OpLabel(); + decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); + decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); + decomp.Emit(then_label); + if (ast.kills) { + decomp.Emit(decomp.OpKill()); + } else { + decomp.PreExit(); + decomp.Emit(decomp.OpReturn()); + } + decomp.Emit(endif_label); + } else { + decomp.Emit(decomp.OpLabel()); + if (ast.kills) { + decomp.Emit(decomp.OpKill()); + } else { + decomp.PreExit(); + decomp.Emit(decomp.OpReturn()); + } + decomp.Emit(decomp.OpLabel()); + } + } + + void operator()(VideoCommon::Shader::ASTBreak& ast) { + bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition); + if (!is_true) { + ExprDecompiler expr_parser{decomp}; + const Id condition = expr_parser.Visit(ast.condition); + const Id then_label = decomp.OpLabel(); + const Id endif_label = decomp.OpLabel(); + decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); + decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); + decomp.Emit(then_label); + decomp.Emit(decomp.OpBranch(current_loop_exit)); + decomp.Emit(endif_label); + } else { + decomp.Emit(decomp.OpLabel()); + decomp.Emit(decomp.OpBranch(current_loop_exit)); + decomp.Emit(decomp.OpLabel()); + } + } + + void Visit(VideoCommon::Shader::ASTNode& node) { + std::visit(*this, *node->GetInnerData()); + } + +private: + SPIRVDecompiler& decomp; + Id current_loop_exit; +}; + +void SPIRVDecompiler::DecompileAST() { + u32 num_flow_variables = ir.GetASTNumVariables(); + for (u32 i = 0; i < num_flow_variables; i++) { + const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false); + Name(id, fmt::format("flow_var_{}", i)); + flow_variables.emplace(i, AddGlobalVariable(id)); + } + ASTDecompiler decompiler{*this}; + VideoCommon::Shader::ASTNode program = ir.GetASTProgram(); + decompiler.Visit(program); +} + DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, Maxwell::ShaderStage stage) { auto decompiler = std::make_unique(device, ir, stage); -- cgit v1.2.3