From 58914796c06662f4f901a4f195057ee1327cf055 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp
Date: Tue, 16 Feb 2021 19:50:23 -0300
Subject: shader: Add XMAD multiplication folding optimization

---
 .../ir_opt/constant_propagation_pass.cpp | 82 ++++++++++++++++++++--
 1 file changed, 77 insertions(+), 5 deletions(-)

(limited to 'src')

diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index f1ad16d60..9eb61b54c 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -9,6 +9,7 @@
 #include "common/bit_cast.h"
 #include "common/bit_util.h"
 #include "shader_recompiler/exception.h"
+#include "shader_recompiler/frontend/ir/ir_emitter.h"
 #include "shader_recompiler/frontend/ir/microinstruction.h"
 #include "shader_recompiler/ir_opt/passes.h"
 
@@ -99,8 +100,71 @@ void FoldGetPred(IR::Inst& inst) {
     }
 }
 
+/// Replaces the pattern generated by two XMAD multiplications
+bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
+    /*
+     * We are looking for this pattern:
+     *   %rhs_bfe = BitFieldUExtract %factor_a, #0, #16 (uses: 1)
+     *   %rhs_mul = IMul32 %rhs_bfe, %factor_b (uses: 1)
+     *   %lhs_bfe = BitFieldUExtract %factor_a, #16, #16 (uses: 1)
+     *   %lhs_mul = IMul32 %lhs_bfe, %factor_b (uses: 1)
+     *   %lhs_shl = ShiftLeftLogical32 %lhs_mul, #16 (uses: 1)
+     *   %result  = IAdd32 %lhs_shl, %rhs_mul (uses: 10)
+     *
+     * And replacing it with
+     *   %result  = IMul32 %factor_a, %factor_b
+     *
+     * This optimization has been proven safe by LLVM and MSVC.
+     */
+    const IR::Value lhs_arg{inst.Arg(0)};
+    const IR::Value rhs_arg{inst.Arg(1)};
+    if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
+    if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) {
+        return false;
+    }
+    if (lhs_shl->Arg(0).IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
+    IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
+    if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
+        return false;
+    }
+    const IR::U32 factor_b{lhs_mul->Arg(1)};
+    if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
+    IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
+    if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) {
+        return false;
+    }
+    if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
+        return false;
+    }
+    if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
+        return false;
+    }
+    const IR::U32 factor_a{lhs_bfe->Arg(0)};
+    IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
+    inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
+    return true;
+}
+
 template <typename T>
-void FoldAdd(IR::Inst& inst) {
+void FoldAdd(IR::Block& block, IR::Inst& inst) {
     if (inst.HasAssociatedPseudoOperation()) {
         return;
     }
@@ -110,6 +174,12 @@ void FoldAdd(IR::Inst& inst) {
     const IR::Value rhs{inst.Arg(1)};
     if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
         inst.ReplaceUsesWith(inst.Arg(0));
+        return;
+    }
+    if constexpr (std::is_same_v<T, u32>) {
+        if (FoldXmadMultiply(block, inst)) {
+            return;
+        }
     }
 }
 
@@ -244,14 +314,14 @@ void FoldBranchConditional(IR::Inst& inst) {
     }
 }
 
-void ConstantPropagation(IR::Inst& inst) {
+void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
     switch (inst.Opcode()) {
     case IR::Opcode::GetRegister:
         return FoldGetRegister(inst);
     case IR::Opcode::GetPred:
         return FoldGetPred(inst);
     case IR::Opcode::IAdd32:
-        return FoldAdd<u32>(inst);
+        return FoldAdd<u32>(block, inst);
     case IR::Opcode::ISub32:
         return FoldISub32(inst);
     case IR::Opcode::BitCastF32U32:
@@ -259,7 +329,7 @@ void ConstantPropagation(IR::Inst& inst) {
     case IR::Opcode::BitCastU32F32:
         return FoldBitCast<u32>(inst, IR::Opcode::BitCastF32U32);
     case IR::Opcode::IAdd64:
-        return FoldAdd<u64>(inst);
+        return FoldAdd<u64>(block, inst);
     case IR::Opcode::Select32:
         return FoldSelect(inst);
     case IR::Opcode::LogicalAnd:
@@ -292,7 +362,9 @@ void ConstantPropagation(IR::Inst& inst) {
 } // Anonymous namespace
 
 void ConstantPropagationPass(IR::Block& block) {
-    std::ranges::for_each(block, ConstantPropagation);
+    for (IR::Inst& inst : block) {
+        ConstantPropagation(block, inst);
+    }
 }
 
 } // namespace Shader::Optimization
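Editor's note (not part of the patch): a quick way to sanity-check the fold is to mirror the matched IR pattern in plain C++ and compare it against a single 32-bit multiply. The sketch below is illustrative only, and XmadStyleMultiply is a hypothetical helper rather than code from the repository, but the identity it asserts is exact: unsigned arithmetic wraps modulo 2^32, so the shifted high partial product plus the low partial product always equals a plain IMul32 of the original factors.

    #include <cassert>
    #include <cstdint>

    // Mirrors the matched pattern: BitFieldUExtract, IMul32,
    // ShiftLeftLogical32 and IAdd32 over the 16-bit halves of factor_a.
    std::uint32_t XmadStyleMultiply(std::uint32_t factor_a, std::uint32_t factor_b) {
        const std::uint32_t rhs_bfe = factor_a & 0xffffu; // BitFieldUExtract %factor_a, #0, #16
        const std::uint32_t rhs_mul = rhs_bfe * factor_b; // IMul32
        const std::uint32_t lhs_bfe = factor_a >> 16;     // BitFieldUExtract %factor_a, #16, #16
        const std::uint32_t lhs_mul = lhs_bfe * factor_b; // IMul32
        const std::uint32_t lhs_shl = lhs_mul << 16;      // ShiftLeftLogical32 %lhs_mul, #16
        return lhs_shl + rhs_mul;                         // IAdd32 %lhs_shl, %rhs_mul
    }

    int main() {
        // Unsigned overflow wraps, so the decomposition agrees with a
        // plain 32-bit multiply for every pair of factors.
        for (const std::uint32_t a : {0u, 1u, 0xffffu, 0x12345u, 0xdeadbeefu, 0xffffffffu}) {
            for (const std::uint32_t b : {0u, 3u, 0x8000u, 0x10001u, 0xcafebabeu}) {
                assert(XmadStyleMultiply(a, b) == a * b);
            }
        }
        return 0;
    }

This is the same argument the pattern comment appeals to when it notes that LLVM and MSVC perform the equivalent fold: the guards in FoldXmadMultiply (matching extract offsets #0 and #16, width #16, a shift by 16, and identical factors on both multiplies) ensure the IR being replaced computes exactly the left-hand side of this identity.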