summaryrefslogblamecommitdiffstats
path: root/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
blob: d67f08cf9c3ba7c6c47cc2c6c2bf53d1632ac0d3 (plain) (tree)
1
2
3
4
5
6
7
8
9



                                            
                     
                 
              
                   
                      
                        
                  


                       

                        
                             
                          

                                
                                          
                                               
                                             
                                           
                                                 
                                                            
                                   
                                        
                                                 
 


                  
 
                
                                 


                                  

                                    
                                                
                                       
 


                     
                                        












                                                                          
 



                                                     
 


                    

  

                     


                              














                                
                             



                                      
 





                                  
                                                
                                














                                                                                                



































































































































                                                                                            






                                                                          
                                                             



                                                                                         

                                                          
                                        


                                                       


                 
                                                    
       
                                                                                          
                                                                                            
                                                                                           




                                                                                    
                                               






                                                                           
                                                       



                                                          

                                                             
                                                       
 








                                                                                                 




                                                                                                   




                                                                            
                                                      

                                                                          


















                                                                                              

                                                          

                                                                                                 


                                                                                  


                                                                             


                                                          

                                                                                   
                                                                              





                                                                                     
         













                                                                                    
     
 









                             
                                 



































                                                                                        
                                                 










                                                                                                   
                                








                                                                     
                                                     



                                           



                                                                                          

                                                                         

                               



                                                       
 





                                                                                 
 
                              
                                            


                                                                          
 

                                 
 

                                                               




                                                                                             
                                 
         
 




                                 

     


                                
 
                                                                                                
 



                                                                                   
         
     
 



                                                                                                 
 


                                          
         




                                                                                        
 
                                   
                                                                                                
                        


                                                                                                   

     


                                                      
         

                                                                             
 




                                                                                             
 




                                                                                             
 
                                                                                                   

     

                                                   
                   


                                                 
 
                                                                                             


                            
                                            
                   
         

                                                                                       

                                           


                            
                                            
                   
         
 

                                                                               

                         
















                                                                                                 












                                                                                                    









                                                                                            
                                   

                                                                    





                                                                                            







                                                                                      







                                                                                      
                               




















                                                                                                 
 



                                                                
         






                                                                                       


                                 
                                                                           
                                                                         
                                                                   




                                                                                      




























                                                                                               
                                   
                                                          



                                             
                                                                    










                                                                                




                                                              
                                                


















                                                                  




                                                                                                



                                                           


















                                                                                     
                                                   










                                                                                                 




                                                                                           
                                                                                       









                                                                                             


                                                              








                                                                                        
                                                                                          


         
                                                                                                
                                                                                   






                                                          

                                                                   

                                                                                                   
                                                                       





                                                                         
                       

     

                                                                

                                                                                    
                                                                                        


                                                                         
                                             
         


























                                                                                                

     
                                      
                                                      


                                     


                                                          



                                                                                               






                                                                                           

                                                                                       

                                                                         
 

                                                                                               














                                                                                                  


                                                                         






                                                                                             
         


                       








                                                                         


                                                                                                 

     


                                                       
 


                                     
                                                                                                  





                                            
         
     
 






















                                                                                      

          


                                                                                    









                                                                                                    



                                                                                                    

                                                                                                   































                                                                                                   
 
                                            

     
                                               
                                     

                                                                


         
                                        
                                                                        

                                                                                             

                                           





                                                                                        
         
 
                                                            

                                               
                                                   
             

                                                                       
 




                                                                              


                                                                         
 
                                                                        






                                                                          
                                                                


                                         
                                                                 
             

                                       
 
                                                              
                                                    














                                                                                                 

                                

                                                    
                                       
                                                                      
                     


                                                                                            
                                                                                    







                                                                                                  
                 


                                                                                   



                                                                                                    
                                  



                                                                                            
                       


                                                                                                   
                       

                                                                                                   

                                                                                               
                                                          

                                                                                   
                                                      
                                   



                                                                                                  

                                                                               
                                                   

                                                    



                                                                                

                                                                                                 
                                                                            



                                                                                            

                                               
 
                                                              
                                                   

                                                                       
                         
                                                                     



                                                                                                
                    

                                    
                                                                                  






                                                                            





                                                                                                    


                                                                  

                                                                                                   
             

                                                           
 
                                                              
                                                                               
         
 







                                                                                 
                                                                               
         
 





                                                                                                  

                                                                                             

                                           


                                                                                         



                                                                              
 

                                          
                                                    


                                           
                    
                                      
             
                                 
                      
         
 

                                                                    






                      
                                                                                      
                                           
                                                           
                                                        
 
                                                       


                                                            
                                    



                                                                                          
                                            
                                                           

                                                        
 
                                                             


                                                            
                                    



                                                                                              
                                             
                                                           


                                                        
 
                                                                   


                                                            
                                    



                                                                                                  
                                                
                                                           



                                                        
 
                                                                         


                                                            
                                    

     
                                            

                                        
 
                            
                                                            



                                                            
                                                                  
 
                                                                     















                                                                                                 
                                                                             



                                                                                               
                                                              
                                      






                                                                                        























                                                                                                 

                                                        







                                                                                                    






                                                                       
















                                                                                                   
                                                                     




                                                                                      
                                                                 

                                                                     
                                                                 
 

                                                                        
 

                            

         




                                                                           
                                                        


                  




                                                                                         

     





                                                                                
 




                                                                              
 





                                                                                  

     








                                                                                             

     













                                                                                       

     


                                                                                                  

     



















                                                                                  

     


                                                                                                  

     





                                                                                    

     



                                                                             

     
                                                   

                                        

                    
                                                                   










                                                                          
                                                                             


                                                                          
                                            


                  




                                                                       





                                                                                

     









                                                                           

     



                                                                    

     


                                                                                          

     
                                                       

                                                                        
                                                            
         








                                                                               
         








                                                                      
         

                                                 
 













                                                                                  

     



                                                                                   

     

                                                                      




                                                                 
















                                                                




                                                               


                                                                                                 
                    


                                                                                                 


             



                                                                                          
                                                                                          

                                                                  

     







                                                                      

                                      






                                                                

                                                                                                 
         
                                                                                               
                                                                  

     


                                                                      
 




                                                                                       

                                    

                                                                                    


                                                                             

                                                                                   
         
                                                                  

     



                                                                      
 


                                                                    

         
                                                   
                                                
                                                                     












                                                                                   

                                               


                                                                


                                                                                        

     


















                                                                                   

     

























                                                                                      
                                                                                           



                                                                  
                                                       


                                               





                                                                             

     










                                                                  


                  
                                                    



                        
                                                    



                        
                                                    



                        
                                                    



                        
                                                   



                        
                                                    



                        
                                                         



                        




                                                                       






                                                                             
                                                        
         
 

                                                             
                                                                                
                                                   
 
                                                                                        

     






                                                                    
         


                  

                                                    
 




                                      
         


                  

                                                                    
                                                                          


                                                                          
 

                                                             


                  
                                                  
                                                                          










                                                                           
         


                  
                    
                                                                                 






                                                                                            
         
                                            


                                                                                 
                                                       
                 
                                    












                                                                                               

                                                                                  
                     


                                                                                              




                                                                                          
                                                                      
             
         

     
                                          
                  


                                     

                                       



                            
         


                  



                                             

                                       



                            
         


                  

                                      


                  

                                                  


                  

                                                         

     
                                   
                                                
                                                      


                          


                                                                     

     



                                                                     

     


                                                                  
 










                                                                                                   

     




                                                                  

     

                                                       

     



                                                                                 

     











                                                                                                  
                                                                                                   

                                                                          
                                                     



                                 



                                                                                        







                                                                        
                                                               

     

                                              
                        
                                
                         
                                 
                         
                                 
                       
                               
                        
                                
                             



                                     
         

     











                                         
                         
                           

                        
                                               
                             






                                                                       
         

     












                                                  
                          







                                
                         
                       
                                              
                             





























                                                          
         


                     













                                                              

                                                            





                                                          






                                                          











                             









                                                                   
                                                           
                       
                                                   
                        
                                                       
                
                            
                      
         

     






                                                                                                  




                                                                                               














                                                                               















                                                                                 
                                                         











                                                                                          

                                       













                                                                                 
                                      




















                                                                                                    
                                                                









                                                                                 
                                                                           






                                                                            
                                                                 





                                                                   
                                 
                                     
                                  
                                    

                                      






                                                                         


                                                                                          






                                                                                           
                                                                           














                                                                                       





                                                                                                
                                                                  





                                                                                                






                                                 
                                          
 
                                    
                                     
                                         



                                              
 














                                                                        
 
                                 
                                         







                                        
                                       
                                  





                                               

                                       


                                                               
 
                                   
                                         

                                          
      
                                                                                                   
 
                           
                       
                           
                                       
                             
                                         
                                                          























                                                                                               

                                                                                               
                                                                                        

                                                                                        
                                                                                        

                                                                                           
                                                                                           

                                                                                              

                                                                                              

                                                                                        



                                                                                                 
                                      

                                                                                               
                                      

                                                                                            





                                                                                                    
 

                     
                                                                                 
                           
                                                                                                




                                                                                                   

                                                   



                                                                                           


                                             




                       
                                
                                       
                                                 
                                     
                      
                       

                                                          
                                                                        

                                                  


                                               


                        

                       



                                                            










                              


                               
                



                            

                             


                                  

  

                      
                                                                        
 
                                        


                                                                 
                                                       

     
                                       


                                                                 
                                                      

     
                                        

                                                                 
                                                  

     
                                              
                                                                           
                                                                        

     
                                             
                                                                                

     
                                        
                                                                                      

     
                                            
                                                           

     

                                                                     

                                                                              
                                                           

     
                                
                                        


        




                            
                                                                       
 
                                            






                                               
                                           



                                                              


                                                                                  




                                               

                                     

     
                                                            


                      
                                                                  


                      
                                                 


                                          
                                           

                                                              
                                                                       

     
                                                           


                     
                                                          


                      
                                            


                                                     
                                                        
                                          

                                    
                                                                                               

                                          




                                               

                                             

                                                              

                                                                         

     
                                           
                                                              



                                                                  


                                                                                      
                            
                                

                                 
                                  
             
                                         
                
                                                   

                                        
                            
                                

                                 
                                  
             
                                              


         
                                          
                                                              



                                                                  




                                                                                      
                
                                                   



                                               


         
                                     




                                                 
                           


                                      
                                                           




                                                                                  
 

                     
                                               
                                    
                              
 
                                    
































                                                                                                 

 
                                                                                           


                                                                                           

 
                     
// Copyright 2019 yuzu Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.

#include <functional>
#include <limits>
#include <map>
#include <optional>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include <fmt/format.h>

#include <sirit/sirit.h>

#include "common/alignment.h"
#include "common/assert.h"
#include "common/common_types.h"
#include "common/logging/log.h"
#include "video_core/engines/maxwell_3d.h"
#include "video_core/engines/shader_bytecode.h"
#include "video_core/engines/shader_header.h"
#include "video_core/engines/shader_type.h"
#include "video_core/renderer_vulkan/vk_device.h"
#include "video_core/renderer_vulkan/vk_shader_decompiler.h"
#include "video_core/shader/node.h"
#include "video_core/shader/shader_ir.h"
#include "video_core/shader/transform_feedback.h"

namespace Vulkan {

namespace {

using Sirit::Id;
using Tegra::Engines::ShaderType;
using Tegra::Shader::Attribute;
using Tegra::Shader::AttributeUse;
using Tegra::Shader::Register;
using namespace VideoCommon::Shader;

using Maxwell = Tegra::Engines::Maxwell3D::Regs;
using Operation = const OperationNode&;

class ASTDecompiler;
class ExprDecompiler;

// TODO(Rodrigo): Use rasterizer's value
/// Upper bound on a constant buffer's size, counted in individual floats.
constexpr u32 MaxConstBufferFloats{0x4000};
/// The same upper bound counted in elements of four floats each.
constexpr u32 MaxConstBufferElements{MaxConstBufferFloats / 4};

/// Fixed input patch count used by the decompiler.
constexpr u32 NumInputPatches{32}; // This value seems to be the standard

/// Value categories tracked for decompiled expressions.
/// Void marks the absence of a value; it is the first enumerator, so a value-initialized Type is
/// Void. The meaning of Bool2 and HalfFloat is assumed from their names (a two-component bool and
/// a half-precision float value) — confirm at the use sites.
enum class Type {
    Void,
    Bool,
    Bool2,
    Float,
    Int,
    Uint,
    HalfFloat,
};

class Expression final {
public:
    Expression(Id id, Type type) : id{id}, type{type} {
        ASSERT(type != Type::Void);
    }
    Expression() : type{Type::Void} {}

    Id id{};
    Type type{};
};
static_assert(std::is_standard_layout_v<Expression>);

/// Pair of ids describing a texel buffer binding.
struct TexelBuffer {
    Id image_type{}; // Id of the image type used for the buffer
    Id image{};      // Id of the image itself (presumably the declared variable — confirm at use)
};

/// Ids describing a sampled image binding.
struct SampledImage {
    Id image_type{};           // Id of the underlying image type
    Id sampler_type{};         // Id of the sampler type (assumed to be built from image_type — confirm)
    Id sampler_pointer_type{}; // Id of a pointer type to the sampler type
    Id variable{};             // Id of the declared variable for this binding
};

/// Pair of ids describing a storage image binding.
struct StorageImage {
    Id image_type{}; // Id of the image type
    Id image{};      // Id of the image itself (presumably the declared variable — confirm at use)
};

/// Type information for an attribute: its decompiler-level Type plus the Id used for the scalar
/// form and the Id used for the vector form.
/// All members are value-initialized (Type::Void and empty ids), so a default-constructed
/// instance never holds indeterminate values. This matches the other id-holding structs in this
/// file. The struct remains an aggregate, so brace-initialization by callers is unaffected.
struct AttributeType {
    Type type{};
    Id scalar{};
    Id vector{};
};

/// Optional indices for built-in vertex outputs. An empty optional means the corresponding output
/// has no index assigned. (Assumed to be member indices inside a per-vertex output block —
/// confirm where this struct is filled in.)
struct VertexIndices {
    std::optional<u32> position;
    std::optional<u32> layer;
    std::optional<u32> viewport;
    std::optional<u32> point_size;
    std::optional<u32> clip_distances;
};

/// Describes how a generic varying maps onto a declared variable.
struct GenericVaryingDescription {
    Id id = nullptr;        // Id of the backing variable; nullptr until assigned
    u32 first_element = 0;  // First element of the variable this varying uses (assumed — confirm)
    bool is_scalar = false; // Whether the varying is treated as a scalar rather than a vector
};

/// Maps a (non-buffer) sampler's texture type onto a SPIR-V image dimensionality.
/// Unknown texture types are reported as unimplemented and fall back to Dim2D.
spv::Dim GetSamplerDim(const Sampler& sampler) {
    ASSERT(!sampler.IsBuffer());
    using TextureType = Tegra::Shader::TextureType;
    const TextureType texture_type = sampler.GetType();
    switch (texture_type) {
    case TextureType::Texture1D:
        return spv::Dim::Dim1D;
    case TextureType::Texture2D:
        return spv::Dim::Dim2D;
    case TextureType::Texture3D:
        return spv::Dim::Dim3D;
    case TextureType::TextureCube:
        return spv::Dim::Cube;
    default:
        break;
    }
    UNIMPLEMENTED_MSG("Unimplemented sampler type={}", static_cast<u32>(texture_type));
    return spv::Dim::Dim2D;
}

std::pair<spv::Dim, bool> GetImageDim(const Image& image) {
    switch (image.GetType()) {
    case Tegra::Shader::ImageType::Texture1D:
        return {spv::Dim::Dim1D, false};
    case Tegra::Shader::ImageType::TextureBuffer:
        return {spv::Dim::Buffer, false};
    case Tegra::Shader::ImageType::Texture1DArray:
        return {spv::Dim::Dim1D, true};
    case Tegra::Shader::ImageType::Texture2D:
        return {spv::Dim::Dim2D, false};
    case Tegra::Shader::ImageType::Texture2DArray:
        return {spv::Dim::Dim2D, true};
    case Tegra::Shader::ImageType::Texture3D:
        return {spv::Dim::Dim3D, false};
    default:
        UNIMPLEMENTED_MSG("Unimplemented image type={}", static_cast<u32>(image.GetType()));
        return {spv::Dim::Dim2D, false};
    }
}

/// Returns how many vertices form one primitive of the given topology.
/// Quads, QuadStrip, Polygon and Patches are reported as unimplemented, and any unexpected value
/// hits UNREACHABLE(); all of those cases fall back to 3 vertices.
u32 GetNumPrimitiveTopologyVertices(Maxwell::PrimitiveTopology primitive_topology) {
    using Topology = Maxwell::PrimitiveTopology;
    switch (primitive_topology) {
    case Topology::Points:
        return 1;
    case Topology::Lines:
    case Topology::LineLoop:
    case Topology::LineStrip:
        return 2;
    case Topology::Triangles:
    case Topology::TriangleStrip:
    case Topology::TriangleFan:
        return 3;
    case Topology::LinesAdjacency:
    case Topology::LineStripAdjacency:
        return 4;
    case Topology::TrianglesAdjacency:
    case Topology::TriangleStripAdjacency:
        return 6;
    case Topology::Quads:
        UNIMPLEMENTED_MSG("Quads");
        break;
    case Topology::QuadStrip:
        UNIMPLEMENTED_MSG("QuadStrip");
        break;
    case Topology::Polygon:
        UNIMPLEMENTED_MSG("Polygon");
        break;
    case Topology::Patches:
        UNIMPLEMENTED_MSG("Patches");
        break;
    default:
        UNREACHABLE();
        break;
    }
    // Shared fallback for unimplemented and unexpected topologies.
    return 3;
}

/// Translates a tessellation primitive into the matching SPIR-V execution mode.
/// Values outside the enumeration hit UNREACHABLE() and yield Triangles.
spv::ExecutionMode GetExecutionMode(Maxwell::TessellationPrimitive primitive) {
    using Primitive = Maxwell::TessellationPrimitive;
    switch (primitive) {
    case Primitive::Isolines:
        return spv::ExecutionMode::Isolines;
    case Primitive::Triangles:
        return spv::ExecutionMode::Triangles;
    case Primitive::Quads:
        return spv::ExecutionMode::Quads;
    }
    UNREACHABLE();
    return spv::ExecutionMode::Triangles;
}

/// Translates a tessellation spacing into the matching SPIR-V execution mode.
/// Values outside the enumeration hit UNREACHABLE() and yield SpacingEqual.
spv::ExecutionMode GetExecutionMode(Maxwell::TessellationSpacing spacing) {
    using Spacing = Maxwell::TessellationSpacing;
    switch (spacing) {
    case Spacing::Equal:
        return spv::ExecutionMode::SpacingEqual;
    case Spacing::FractionalOdd:
        return spv::ExecutionMode::SpacingFractionalOdd;
    case Spacing::FractionalEven:
        return spv::ExecutionMode::SpacingFractionalEven;
    }
    UNREACHABLE();
    return spv::ExecutionMode::SpacingEqual;
}

/// Translates an input primitive topology into the SPIR-V input execution mode.
/// Quads, QuadStrip, Polygon and Patches are reported as unimplemented and mapped to Triangles.
/// Values outside the enumeration hit UNREACHABLE() and also yield Triangles.
spv::ExecutionMode GetExecutionMode(Maxwell::PrimitiveTopology input_topology) {
    using Topology = Maxwell::PrimitiveTopology;
    switch (input_topology) {
    case Topology::Points:
        return spv::ExecutionMode::InputPoints;
    case Topology::Lines:
    case Topology::LineLoop:
    case Topology::LineStrip:
        return spv::ExecutionMode::InputLines;
    case Topology::Triangles:
    case Topology::TriangleStrip:
    case Topology::TriangleFan:
        return spv::ExecutionMode::Triangles;
    case Topology::LinesAdjacency:
    case Topology::LineStripAdjacency:
        return spv::ExecutionMode::InputLinesAdjacency;
    case Topology::TrianglesAdjacency:
    case Topology::TriangleStripAdjacency:
        return spv::ExecutionMode::InputTrianglesAdjacency;
    // Unimplemented topologies return directly so they do not reach UNREACHABLE() below.
    case Topology::Quads:
        UNIMPLEMENTED_MSG("Quads");
        return spv::ExecutionMode::Triangles;
    case Topology::QuadStrip:
        UNIMPLEMENTED_MSG("QuadStrip");
        return spv::ExecutionMode::Triangles;
    case Topology::Polygon:
        UNIMPLEMENTED_MSG("Polygon");
        return spv::ExecutionMode::Triangles;
    case Topology::Patches:
        UNIMPLEMENTED_MSG("Patches");
        return spv::ExecutionMode::Triangles;
    }
    UNREACHABLE();
    return spv::ExecutionMode::Triangles;
}

/// Translates a shader output topology into the SPIR-V output execution mode.
/// Any other value hits UNREACHABLE() and yields OutputPoints.
spv::ExecutionMode GetExecutionMode(Tegra::Shader::OutputTopology output_topology) {
    using Topology = Tegra::Shader::OutputTopology;
    switch (output_topology) {
    case Topology::PointList:
        return spv::ExecutionMode::OutputPoints;
    case Topology::LineStrip:
        return spv::ExecutionMode::OutputLineStrip;
    case Topology::TriangleStrip:
        return spv::ExecutionMode::OutputTriangleStrip;
    default:
        break;
    }
    UNREACHABLE();
    return spv::ExecutionMode::OutputPoints;
}

/// Checks whether an attribute index falls within the generic range Attribute_0..Attribute_31
/// (inclusive on both ends).
constexpr bool IsGenericAttribute(Attribute::Index attribute) {
    constexpr auto first_generic = Attribute::Index::Attribute_0;
    constexpr auto last_generic = Attribute::Index::Attribute_31;
    return first_generic <= attribute && attribute <= last_generic;
}

/// Converts a generic attribute index into its zero-based location (Attribute_0 -> 0).
/// Asserts that the index is actually a generic attribute.
u32 GetGenericAttributeLocation(Attribute::Index attribute) {
    ASSERT(IsGenericAttribute(attribute));
    const auto raw_index = static_cast<u32>(attribute);
    const auto generic_base = static_cast<u32>(Attribute::Index::Attribute_0);
    return raw_index - generic_base;
}

/// Reports whether an operation must be emitted as precise. Only arithmetic metadata carries
/// this flag; operations holding any other metadata kind are never precise.
bool IsPrecise(Operation operand) {
    const auto& meta = operand.GetMeta();
    if (const auto* const arithmetic = std::get_if<MetaArithmetic>(&meta)) {
        return arithmetic->precise;
    }
    return false;
}

class SPIRVDecompiler final : public Sirit::Module {
public:
    /// Generates a SPIR-V module for one shader stage.
    ///
    /// The constructor does the following:
    /// - Constructs the Module base with version word 0x00010300 (SPIR-V 1.3).
    /// - Declares the capabilities and extensions the emitted code may use.
    /// - Emits the main function through Decompile().
    /// - Registers the entry point and the execution modes required by `stage`.
    ///
    /// @param device          Queried for optional features (transform feedback, viewport
    ///                        index/layer export, formatless image loads, float16).
    /// @param ir              Decoded shader program to translate.
    /// @param stage           Pipeline stage the module is generated for.
    /// @param registry        Source of graphics pipeline state. Transform feedback,
    ///                        tessellation and primitive topology info are read from it for
    ///                        graphics stages.
    /// @param specialization  Pipeline-time values (workgroup size, base binding, point size,
    ///                        shared memory size).
    explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage,
                             const Registry& registry, const Specialization& specialization)
        : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()},
          registry{registry}, specialization{specialization} {
        // Transform feedback varyings come from the graphics pipeline state, so compute
        // shaders skip this.
        if (stage != ShaderType::Compute) {
            transform_feedback = BuildTransformFeedback(registry.GetGraphicsInfo());
        }

        // Capabilities and extensions requested unconditionally for every stage.
        AddCapability(spv::Capability::Shader);
        AddCapability(spv::Capability::UniformAndStorageBuffer16BitAccess);
        AddCapability(spv::Capability::ImageQuery);
        AddCapability(spv::Capability::Image1D);
        AddCapability(spv::Capability::ImageBuffer);
        AddCapability(spv::Capability::ImageGatherExtended);
        AddCapability(spv::Capability::SampledBuffer);
        AddCapability(spv::Capability::StorageImageWriteWithoutFormat);
        AddCapability(spv::Capability::DrawParameters);
        AddCapability(spv::Capability::SubgroupBallotKHR);
        AddCapability(spv::Capability::SubgroupVoteKHR);
        AddExtension("SPV_KHR_shader_ballot");
        AddExtension("SPV_KHR_subgroup_vote");
        AddExtension("SPV_KHR_storage_buffer_storage_class");
        AddExtension("SPV_KHR_variable_pointers");
        AddExtension("SPV_KHR_shader_draw_parameters");

        // The TransformFeedback capability is only declared when the device reports support.
        // Otherwise the shader is still generated and an error is logged.
        if (!transform_feedback.empty()) {
            if (device.IsExtTransformFeedbackSupported()) {
                AddCapability(spv::Capability::TransformFeedback);
            } else {
                LOG_ERROR(Render_Vulkan, "Shader requires transform feedbacks but these are not "
                                         "supported on this device");
            }
        }

        // Outside the geometry stage, exporting Layer/ViewportIndex relies on
        // SPV_EXT_shader_viewport_index_layer. It is only requested when the device supports it.
        if (ir.UsesLayer() || ir.UsesViewportIndex()) {
            if (ir.UsesViewportIndex()) {
                AddCapability(spv::Capability::MultiViewport);
            }
            if (stage != ShaderType::Geometry && device.IsExtShaderViewportIndexLayerSupported()) {
                AddExtension("SPV_EXT_shader_viewport_index_layer");
                AddCapability(spv::Capability::ShaderViewportIndexLayerEXT);
            }
        }

        if (device.IsFormatlessImageLoadSupported()) {
            AddCapability(spv::Capability::StorageImageReadWithoutFormat);
        }

        // Half-precision values use real 16-bit floats when available; otherwise the "half"
        // types fall back to 32-bit floats.
        if (device.IsFloat16Supported()) {
            AddCapability(spv::Capability::Float16);
        }
        t_scalar_half = Name(TypeFloat(device.IsFloat16Supported() ? 16 : 32), "scalar_half");
        t_half = Name(TypeVector(t_scalar_half, 2), "half");

        // Emits all declarations and the body of main. Runs after t_scalar_half/t_half above
        // have been defined.
        const Id main = Decompile();

        // Entry point and stage-specific execution modes. `interfaces` was filled by the
        // declarations made inside Decompile().
        switch (stage) {
        case ShaderType::Vertex:
            AddEntryPoint(spv::ExecutionModel::Vertex, main, "main", interfaces);
            break;
        case ShaderType::TesselationControl:
            AddCapability(spv::Capability::Tessellation);
            AddEntryPoint(spv::ExecutionModel::TessellationControl, main, "main", interfaces);
            // Output patch size comes from the shader header.
            AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
                             header.common2.threads_per_input_primitive);
            break;
        case ShaderType::TesselationEval: {
            // Primitive mode, spacing and winding come from the pipeline state.
            const auto& info = registry.GetGraphicsInfo();
            AddCapability(spv::Capability::Tessellation);
            AddEntryPoint(spv::ExecutionModel::TessellationEvaluation, main, "main", interfaces);
            AddExecutionMode(main, GetExecutionMode(info.tessellation_primitive));
            AddExecutionMode(main, GetExecutionMode(info.tessellation_spacing));
            AddExecutionMode(main, info.tessellation_clockwise
                                       ? spv::ExecutionMode::VertexOrderCw
                                       : spv::ExecutionMode::VertexOrderCcw);
            break;
        }
        case ShaderType::Geometry: {
            // The input primitive comes from the pipeline topology. The output topology and the
            // vertex limit come from the shader header.
            const auto& info = registry.GetGraphicsInfo();
            AddCapability(spv::Capability::Geometry);
            AddEntryPoint(spv::ExecutionModel::Geometry, main, "main", interfaces);
            AddExecutionMode(main, GetExecutionMode(info.primitive_topology));
            AddExecutionMode(main, GetExecutionMode(header.common3.output_topology));
            AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
                             header.common4.max_output_vertices);
            // TODO(Rodrigo): Where can we get this info from?
            AddExecutionMode(main, spv::ExecutionMode::Invocations, 1U);
            break;
        }
        case ShaderType::Fragment:
            AddEntryPoint(spv::ExecutionModel::Fragment, main, "main", interfaces);
            AddExecutionMode(main, spv::ExecutionMode::OriginUpperLeft);
            // DepthReplacing pairs with the FragDepth output declared when omap.depth is set.
            if (header.ps.omap.depth) {
                AddExecutionMode(main, spv::ExecutionMode::DepthReplacing);
            }
            break;
        case ShaderType::Compute:
            // Workgroup dimensions are taken from the pipeline specialization.
            const auto workgroup_size = specialization.workgroup_size;
            AddExecutionMode(main, spv::ExecutionMode::LocalSize, workgroup_size[0],
                             workgroup_size[1], workgroup_size[2]);
            AddEntryPoint(spv::ExecutionModel::GLCompute, main, "main", interfaces);
            break;
        }
    }

private:
    /// Emits every module-level declaration followed by the `main` function and returns the
    /// id of `main`.
    ///
    /// Descriptors receive consecutive binding numbers starting at
    /// specialization.base_binding, in this fixed order: constant buffers, global buffers,
    /// texel buffers, samplers, images. The body is emitted from the structured AST when the
    /// IR was decompiled. Otherwise it uses the label/switch "branch mode".
    Id Decompile() {
        // Stage-specific Declare* helpers return early for other stages.
        DeclareCommon();
        DeclareVertex();
        DeclareTessControl();
        DeclareTessEval();
        DeclareGeometry();
        DeclareFragment();
        DeclareCompute();
        DeclareRegisters();
        DeclareCustomVariables();
        DeclarePredicates();
        DeclareLocalMemory();
        DeclareSharedMemory();
        DeclareInternalFlags();
        DeclareInputAttributes();
        DeclareOutputAttributes();

        // Each Declare* call returns the next free binding. The final value is not used here.
        u32 binding = specialization.base_binding;
        binding = DeclareConstantBuffers(binding);
        binding = DeclareGlobalBuffers(binding);
        binding = DeclareTexelBuffers(binding);
        binding = DeclareSamplers(binding);
        binding = DeclareImages(binding);

        const Id main = OpFunction(t_void, {}, TypeFunction(t_void));
        AddLabel();

        if (ir.IsDecompiled()) {
            DeclareFlowVariables();
            DecompileAST();
        } else {
            AllocateLabels();
            DecompileBranchMode();
        }

        OpReturn();
        OpFunctionEnd();

        return main;
    }

    void DefinePrologue() {
        if (stage == ShaderType::Vertex) {
            // Clear Position to avoid reading trash on the Z conversion.
            const auto position_index = out_indices.position.value();
            const Id position = AccessElement(t_out_float4, out_vertex, position_index);
            OpStore(position, v_varying_default);

            if (specialization.point_size) {
                const u32 point_size_index = out_indices.point_size.value();
                const Id out_point_size = AccessElement(t_out_float, out_vertex, point_size_index);
                OpStore(out_point_size, Constant(t_float, *specialization.point_size));
            }
        }
    }

    /// Emits the body of main from the structured AST. Decompile() uses it when
    /// ir.IsDecompiled() is true. The definition lives outside the class body.
    void DecompileAST();

    /// Emits the program body when no structured AST is available.
    ///
    /// Unstructured control flow is lowered into a loop wrapped around an OpSwitch:
    /// - The function-local `jmp_to` holds the address of the basic block to run. It starts
    ///   at the first block's address.
    /// - On every iteration the switch jumps to that block's label.
    /// - Blocks are emitted in address order, and each one branches to the label of the next
    ///   block.
    /// - After the last block, or for an address with no label, control reaches
    ///   `default_branch`, which returns from main.
    void DecompileBranchMode() {
        const u32 first_address = ir.GetBasicBlocks().begin()->first;
        const Id loop_label = OpLabel("loop");
        const Id merge_label = OpLabel("merge");
        const Id dummy_label = OpLabel();
        const Id jump_label = OpLabel();
        // Stored as a member. Presumably branch operations jump here after updating
        // jmp_to -- TODO confirm against the branch emitters.
        continue_label = OpLabel("continue");

        // Switch case literals (block addresses) and their target labels, in map order.
        std::vector<Sirit::Literal> literals;
        std::vector<Id> branch_labels;
        for (const auto& [literal, label] : labels) {
            literals.push_back(literal);
            branch_labels.push_back(label);
        }

        jmp_to = OpVariable(TypePointer(spv::StorageClass::Function, t_uint),
                            spv::StorageClass::Function, Constant(t_uint, first_address));
        AddLocalVariable(jmp_to);

        // Presumably the SSY/PBK reconvergence and break stacks used by flow instructions;
        // their layout is defined by CreateFlowStack.
        std::tie(ssy_flow_stack, ssy_flow_stack_top) = CreateFlowStack();
        std::tie(pbk_flow_stack, pbk_flow_stack_top) = CreateFlowStack();

        Name(jmp_to, "jmp_to");
        Name(ssy_flow_stack, "ssy_flow_stack");
        Name(ssy_flow_stack_top, "ssy_flow_stack_top");
        Name(pbk_flow_stack, "pbk_flow_stack");
        Name(pbk_flow_stack_top, "pbk_flow_stack_top");

        DefinePrologue();

        // A block can carry only one merge instruction. The loop header therefore holds
        // OpLoopMerge, and the switch lives in a separate block (dummy_label).
        OpBranch(loop_label);
        AddLabel(loop_label);
        OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::MaskNone);
        OpBranch(dummy_label);

        AddLabel(dummy_label);
        const Id default_branch = OpLabel();
        const Id jmp_to_load = OpLoad(t_uint, jmp_to);
        OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone);
        OpSwitch(jmp_to_load, default_branch, literals, branch_labels);

        // Unknown address: leave main.
        AddLabel(default_branch);
        OpReturn();

        for (const auto& [address, bb] : ir.GetBasicBlocks()) {
            AddLabel(labels.at(address));

            VisitBasicBlock(bb);

            // Fall through to the next block by address, or return after the last one.
            const auto next_it = labels.lower_bound(address + 1);
            const Id next_label = next_it != labels.end() ? next_it->second : default_branch;
            OpBranch(next_label);
        }

        // The switch merge leads to the continue block, which loops back to the header.
        AddLabel(jump_label);
        OpBranch(continue_label);
        AddLabel(continue_label);
        OpBranch(loop_label);
        AddLabel(merge_label);
    }

private:
    // Granted access to the private emitters. Presumably these are the AST/expression
    // visitors used by DecompileAST() -- their definitions are not visible here.
    friend class ASTDecompiler;
    friend class ExprDecompiler;

    /// Number of InternalFlag values (InternalFlag::Amount). DeclareInternalFlags() declares
    /// one boolean variable per flag.
    static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);

    void AllocateLabels() {
        for (const auto& pair : ir.GetBasicBlocks()) {
            const u32 address = pair.first;
            labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address)));
        }
    }

    /// Declarations shared by every stage: the SubgroupLocalInvocationId input (thread_id).
    void DeclareCommon() {
        constexpr auto builtin = spv::BuiltIn::SubgroupLocalInvocationId;
        thread_id = DeclareInputBuiltIn(builtin, t_in_uint, "thread_id");
    }

    void DeclareVertex() {
        if (stage != ShaderType::Vertex) {
            return;
        }
        Id out_vertex_struct;
        std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
        const Id vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
        out_vertex = OpVariable(vertex_ptr, spv::StorageClass::Output);
        interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));

        // Declare input attributes
        vertex_index = DeclareInputBuiltIn(spv::BuiltIn::VertexIndex, t_in_int, "vertex_index");
        instance_index =
            DeclareInputBuiltIn(spv::BuiltIn::InstanceIndex, t_in_int, "instance_index");
        base_vertex = DeclareInputBuiltIn(spv::BuiltIn::BaseVertex, t_in_int, "base_vertex");
        base_instance = DeclareInputBuiltIn(spv::BuiltIn::BaseInstance, t_in_int, "base_instance");
    }

    void DeclareTessControl() {
        if (stage != ShaderType::TesselationControl) {
            return;
        }
        DeclareInputVertexArray(NumInputPatches);
        DeclareOutputVertexArray(header.common2.threads_per_input_primitive);

        tess_level_outer = DeclareBuiltIn(
            spv::BuiltIn::TessLevelOuter, spv::StorageClass::Output,
            TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 4U))),
            "tess_level_outer");
        Decorate(tess_level_outer, spv::Decoration::Patch);

        tess_level_inner = DeclareBuiltIn(
            spv::BuiltIn::TessLevelInner, spv::StorageClass::Output,
            TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 2U))),
            "tess_level_inner");
        Decorate(tess_level_inner, spv::Decoration::Patch);

        invocation_id = DeclareInputBuiltIn(spv::BuiltIn::InvocationId, t_in_int, "invocation_id");
    }

    /// Tessellation-evaluation declarations: the input patch, the per-vertex output block and
    /// the TessCoord input. Does nothing for other stages.
    void DeclareTessEval() {
        if (stage == ShaderType::TesselationEval) {
            DeclareInputVertexArray(NumInputPatches);
            DeclareOutputVertex();
            tess_coord = DeclareInputBuiltIn(spv::BuiltIn::TessCoord, t_in_float3, "tess_coord");
        }
    }

    void DeclareGeometry() {
        if (stage != ShaderType::Geometry) {
            return;
        }
        const auto& info = registry.GetGraphicsInfo();
        const u32 num_input = GetNumPrimitiveTopologyVertices(info.primitive_topology);
        DeclareInputVertexArray(num_input);
        DeclareOutputVertex();
    }

    void DeclareFragment() {
        if (stage != ShaderType::Fragment) {
            return;
        }

        for (u32 rt = 0; rt < static_cast<u32>(std::size(frag_colors)); ++rt) {
            if (!IsRenderTargetEnabled(rt)) {
                continue;
            }
            const Id id = AddGlobalVariable(OpVariable(t_out_float4, spv::StorageClass::Output));
            Name(id, fmt::format("frag_color{}", rt));
            Decorate(id, spv::Decoration::Location, rt);

            frag_colors[rt] = id;
            interfaces.push_back(id);
        }

        if (header.ps.omap.depth) {
            frag_depth = AddGlobalVariable(OpVariable(t_out_float, spv::StorageClass::Output));
            Name(frag_depth, "frag_depth");
            Decorate(frag_depth, spv::Decoration::BuiltIn,
                     static_cast<u32>(spv::BuiltIn::FragDepth));

            interfaces.push_back(frag_depth);
        }

        frag_coord = DeclareInputBuiltIn(spv::BuiltIn::FragCoord, t_in_float4, "frag_coord");
        front_facing = DeclareInputBuiltIn(spv::BuiltIn::FrontFacing, t_in_bool, "front_facing");
        point_coord = DeclareInputBuiltIn(spv::BuiltIn::PointCoord, t_in_float2, "point_coord");
    }

    /// Compute-only inputs: WorkgroupId and LocalInvocationId, both declared with t_in_uint3.
    void DeclareCompute() {
        if (stage == ShaderType::Compute) {
            workgroup_id =
                DeclareInputBuiltIn(spv::BuiltIn::WorkgroupId, t_in_uint3, "workgroup_id");
            local_invocation_id = DeclareInputBuiltIn(spv::BuiltIn::LocalInvocationId,
                                                      t_in_uint3, "local_invocation_id");
        }
    }

    /// Declares one private float variable, zero-initialized and named "gpr_<n>", for every
    /// register the IR uses.
    void DeclareRegisters() {
        for (const u32 reg : ir.GetRegisters()) {
            const Id variable = OpVariable(t_prv_float, spv::StorageClass::Private, v_float_zero);
            Name(variable, fmt::format("gpr_{}", reg));
            const Id global = AddGlobalVariable(variable);
            registers.emplace(reg, global);
        }
    }

    void DeclareCustomVariables() {
        const u32 num_custom_variables = ir.GetNumCustomVariables();
        for (u32 i = 0; i < num_custom_variables; ++i) {
            const Id id = OpVariable(t_prv_float, spv::StorageClass::Private, v_float_zero);
            Name(id, fmt::format("custom_var_{}", i));
            custom_variables.emplace(i, AddGlobalVariable(id));
        }
    }

    void DeclarePredicates() {
        for (const auto pred : ir.GetPredicates()) {
            const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
            Name(id, fmt::format("pred_{}", static_cast<u32>(pred)));
            predicates.emplace(pred, AddGlobalVariable(id));
        }
    }

    void DeclareFlowVariables() {
        for (u32 i = 0; i < ir.GetASTNumVariables(); i++) {
            const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
            Name(id, fmt::format("flow_var_{}", static_cast<u32>(i)));
            flow_variables.emplace(i, AddGlobalVariable(id));
        }
    }

    void DeclareLocalMemory() {
        // TODO(Rodrigo): Unstub kernel local memory size and pass it from a register at
        // specialization time.
        const u64 lmem_size = stage == ShaderType::Compute ? 0x400 : header.GetLocalMemorySize();
        if (lmem_size == 0) {
            return;
        }
        const auto element_count = static_cast<u32>(Common::AlignUp(lmem_size, 4) / 4);
        const Id type_array = TypeArray(t_float, Constant(t_uint, element_count));
        const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array);
        Name(type_pointer, "LocalMemory");

        local_memory =
            OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array));
        AddGlobalVariable(Name(local_memory, "local_memory"));
    }

    void DeclareSharedMemory() {
        if (stage != ShaderType::Compute) {
            return;
        }
        t_smem_uint = TypePointer(spv::StorageClass::Workgroup, t_uint);

        const u32 smem_size = specialization.shared_memory_size;
        if (smem_size == 0) {
            // Avoid declaring an empty array.
            return;
        }
        const auto element_count = static_cast<u32>(Common::AlignUp(smem_size, 4) / 4);
        const Id type_array = TypeArray(t_uint, Constant(t_uint, element_count));
        const Id type_pointer = TypePointer(spv::StorageClass::Workgroup, type_array);
        Name(type_pointer, "SharedMemory");

        shared_memory = OpVariable(type_pointer, spv::StorageClass::Workgroup);
        AddGlobalVariable(Name(shared_memory, "shared_memory"));
    }

    /// Declares one private boolean per InternalFlag, initialized to false and named after the
    /// flag ("zero", "sign", "carry", "overflow").
    void DeclareInternalFlags() {
        static constexpr std::array names = {"zero", "sign", "carry", "overflow"};
        // names[flag] below must never run past the table.
        static_assert(names.size() == INTERNAL_FLAGS_COUNT,
                      "Every internal flag needs a debug name");
        for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) {
            const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
            internal_flags[flag] = AddGlobalVariable(Name(id, names[flag]));
        }
    }

    /// Declares the arrayed per-vertex input block (`length` vertices). Stores the variable in
    /// in_vertex and its member indices in in_indices.
    void DeclareInputVertexArray(u32 length) {
        constexpr auto storage = spv::StorageClass::Input;
        // Debug-name the variable after what it holds (the vertex array), like "out_vertex".
        std::tie(in_indices, in_vertex) = DeclareVertexArray(storage, "in_vertex", length);
    }

    /// Declares the arrayed per-vertex output block (`length` vertices). Stores the variable in
    /// out_vertex and its member indices in out_indices.
    void DeclareOutputVertexArray(u32 length) {
        constexpr auto storage = spv::StorageClass::Output;
        // Debug-name the variable after what it holds (the vertex array), like the
        // non-arrayed "out_vertex" block.
        std::tie(out_indices, out_vertex) = DeclareVertexArray(storage, "out_vertex", length);
    }

    /// Declares a global array of `length` PerVertex blocks in `storage_class`, debug-named
    /// `name`, and adds it to the entry point interface.
    /// Returns the block's member indices and the variable id.
    std::tuple<VertexIndices, Id> DeclareVertexArray(spv::StorageClass storage_class,
                                                     std::string name, u32 length) {
        const auto [element_type, member_indices] = DeclareVertexStruct();
        const Id array_type = TypeArray(element_type, Constant(t_uint, length));
        const Id variable = OpVariable(TypePointer(storage_class, array_type), storage_class);
        AddGlobalVariable(Name(variable, std::move(name)));
        interfaces.push_back(variable);
        return {member_indices, variable};
    }

    void DeclareOutputVertex() {
        Id out_vertex_struct;
        std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
        const Id out_vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
        out_vertex = OpVariable(out_vertex_ptr, spv::StorageClass::Output);
        interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
    }

    void DeclareInputAttributes() {
        for (const auto index : ir.GetInputAttributes()) {
            if (!IsGenericAttribute(index)) {
                continue;
            }

            const u32 location = GetGenericAttributeLocation(index);
            const auto type_descriptor = GetAttributeType(location);
            Id type;
            if (IsInputAttributeArray()) {
                type = GetTypeVectorDefinitionLut(type_descriptor.type).at(3);
                type = TypeArray(type, Constant(t_uint, GetNumInputVertices()));
                type = TypePointer(spv::StorageClass::Input, type);
            } else {
                type = type_descriptor.vector;
            }
            const Id id = OpVariable(type, spv::StorageClass::Input);
            AddGlobalVariable(Name(id, fmt::format("in_attr{}", location)));
            input_attributes.emplace(index, id);
            interfaces.push_back(id);

            Decorate(id, spv::Decoration::Location, location);

            if (stage != ShaderType::Fragment) {
                continue;
            }
            switch (header.ps.GetAttributeUse(location)) {
            case AttributeUse::Constant:
                Decorate(id, spv::Decoration::Flat);
                break;
            case AttributeUse::ScreenLinear:
                Decorate(id, spv::Decoration::NoPerspective);
                break;
            case AttributeUse::Perspective:
                // Default
                break;
            default:
                UNREACHABLE_MSG("Unused attribute being fetched");
            }
        }
    }

    /// Declares variables for every generic output attribute the IR writes. Compute and
    /// fragment stages have no generic varyings and are skipped.
    void DeclareOutputAttributes() {
        const bool writes_varyings =
            stage != ShaderType::Compute && stage != ShaderType::Fragment;
        if (!writes_varyings) {
            return;
        }

        UNIMPLEMENTED_IF(registry.GetGraphicsInfo().tfb_enabled && stage != ShaderType::Vertex);
        for (const auto index : ir.GetOutputAttributes()) {
            if (IsGenericAttribute(index)) {
                DeclareOutputAttribute(index);
            }
        }
    }

    /// Declares the output variable(s) backing one generic output attribute.
    ///
    /// Without transform feedback, the attribute becomes one four-component variable. When
    /// transform feedback captures part of it, the attribute is split into several variables:
    /// - All of them share the same Location; later pieces are told apart with the Component
    ///   decoration.
    /// - Each piece is sized by the transform feedback entry found at its first element.
    ///
    /// Every element covered is registered in output_attributes under the key
    /// `index * 4 + element`.
    void DeclareOutputAttribute(Attribute::Index index) {
        static constexpr std::string_view swizzle = "xyzw";

        const u32 location = GetGenericAttributeLocation(index);
        u8 element = 0;
        while (element < 4) {
            const std::size_t remainder = 4 - element;

            std::size_t num_components = remainder;
            const std::optional tfb = GetTransformFeedbackInfo(index, element);
            if (tfb) {
                num_components = tfb->components;
            }

            Id type = GetTypeVectorDefinitionLut(Type::Float).at(num_components - 1);
            // An OpVariable initializer must have exactly the variable's pointee type.
            // v_varying_default is a four-component vector, so partial (transform feedback)
            // pieces get a null constant of their own type instead.
            Id varying_default = num_components == 4 ? v_varying_default : ConstantNull(type);
            bool use_initializer = true;
            if (IsOutputAttributeArray()) {
                const u32 num = GetNumOutputVertices();
                type = TypeArray(type, Constant(t_uint, num));
                if (device.GetDriverID() != vk::DriverIdKHR::eIntelProprietaryWindows) {
                    varying_default = ConstantComposite(type, std::vector(num, varying_default));
                } else {
                    // Intel's proprietary driver fails to setup defaults for arrayed output
                    // attributes. Leave the variable uninitialized instead of passing the
                    // per-vertex default, whose type would not match the array.
                    use_initializer = false;
                }
            }
            type = TypePointer(spv::StorageClass::Output, type);

            std::string name = fmt::format("out_attr{}", location);
            if (num_components < 4 || element > 0) {
                name = fmt::format("{}_{}", name, swizzle.substr(element, num_components));
            }

            const Id id = use_initializer
                              ? OpVariable(type, spv::StorageClass::Output, varying_default)
                              : OpVariable(type, spv::StorageClass::Output);
            Name(AddGlobalVariable(id), name);

            GenericVaryingDescription description;
            description.id = id;
            description.first_element = element;
            description.is_scalar = num_components == 1;
            for (u32 i = 0; i < num_components; ++i) {
                const u8 offset = static_cast<u8>(static_cast<u32>(index) * 4 + element + i);
                output_attributes.emplace(offset, description);
            }
            interfaces.push_back(id);

            Decorate(id, spv::Decoration::Location, location);
            if (element > 0) {
                Decorate(id, spv::Decoration::Component, static_cast<u32>(element));
            }
            if (tfb && device.IsExtTransformFeedbackSupported()) {
                Decorate(id, spv::Decoration::XfbBuffer, static_cast<u32>(tfb->buffer));
                Decorate(id, spv::Decoration::XfbStride, static_cast<u32>(tfb->stride));
                Decorate(id, spv::Decoration::Offset, static_cast<u32>(tfb->offset));
            }

            element = static_cast<u8>(static_cast<std::size_t>(element) + num_components);
        }
    }

    /// Looks up the transform feedback entry for component `element` of attribute `index`.
    /// The lookup key is the flattened component slot `index * 4 + element`.
    /// Returns an empty optional when that component is not captured.
    std::optional<VaryingTFB> GetTransformFeedbackInfo(Attribute::Index index, u8 element = 0) {
        const auto slot = static_cast<u8>(static_cast<u32>(index) * 4 + element);
        if (const auto entry = transform_feedback.find(slot); entry != transform_feedback.end()) {
            return entry->second;
        }
        return std::nullopt;
    }

    /// Declares one Uniform block per constant buffer used by the IR, named "cbuf_<index>".
    /// Bindings are assigned consecutively from `binding`; returns the next free binding.
    u32 DeclareConstantBuffers(u32 binding) {
        // The block layout depends only on the device, so it is chosen once.
        const Id ubo_type = device.IsKhrUniformBufferStandardLayoutSupported()
                                ? t_cbuf_scalar_ubo
                                : t_cbuf_std140_ubo;
        for (const auto& [index, entry] : ir.GetConstantBuffers()) {
            const Id variable = OpVariable(ubo_type, spv::StorageClass::Uniform);
            AddGlobalVariable(Name(variable, fmt::format("cbuf_{}", index)));

            Decorate(variable, spv::Decoration::Binding, binding);
            Decorate(variable, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
            constant_buffers.emplace(index, variable);
            ++binding;
        }
        return binding;
    }

    /// Declares one storage buffer per global memory region the IR accesses, named
    /// "gmem_<cbuf index>_<cbuf offset>". Bindings are assigned consecutively from `binding`;
    /// returns the next free binding.
    u32 DeclareGlobalBuffers(u32 binding) {
        for (const auto& [base, usage] : ir.GetGlobalMemory()) {
            const Id variable = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer);
            const auto debug_name = fmt::format("gmem_{}_{}", base.cbuf_index, base.cbuf_offset);
            AddGlobalVariable(Name(variable, debug_name));

            Decorate(variable, spv::Decoration::Binding, binding);
            Decorate(variable, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
            global_buffers.emplace(base, variable);
            ++binding;
        }
        return binding;
    }

    /// Declares a sampled Buffer-dimension image for each buffer sampler, named
    /// "sampler_<index>". Bindings are assigned consecutively from `binding`; returns the next
    /// free binding.
    u32 DeclareTexelBuffers(u32 binding) {
        for (const auto& sampler : ir.GetSamplers()) {
            if (!sampler.IsBuffer()) {
                continue;
            }
            // Buffer samplers are expected to be neither arrayed nor shadow samplers.
            ASSERT(!sampler.IsArray());
            ASSERT(!sampler.IsShadow());

            constexpr int no_depth = 0;
            constexpr int not_arrayed = 0;
            constexpr bool single_sampled = false;
            constexpr int used_with_sampler = 1;
            const Id image_type =
                TypeImage(t_float, spv::Dim::Buffer, no_depth, not_arrayed, single_sampled,
                          used_with_sampler, spv::ImageFormat::Unknown);
            const Id variable =
                OpVariable(TypePointer(spv::StorageClass::UniformConstant, image_type),
                           spv::StorageClass::UniformConstant);
            AddGlobalVariable(Name(variable, fmt::format("sampler_{}", sampler.GetIndex())));
            Decorate(variable, spv::Decoration::Binding, binding++);
            Decorate(variable, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);

            texel_buffers.emplace(sampler.GetIndex(), TexelBuffer{image_type, variable});
        }
        return binding;
    }

    /// Declares a combined image sampler for each non-buffer sampler, named "sampler_<index>".
    /// Indexed samplers become an array of `Size()` sampled images. Bindings are assigned
    /// consecutively from `binding`; returns the next free binding.
    u32 DeclareSamplers(u32 binding) {
        for (const auto& sampler : ir.GetSamplers()) {
            if (sampler.IsBuffer()) {
                continue; // Declared by DeclareTexelBuffers.
            }
            const auto dim = GetSamplerDim(sampler);
            const int depth = sampler.IsShadow() ? 1 : 0;
            const int arrayed = sampler.IsArray() ? 1 : 0;
            constexpr bool not_multisampled = false;
            constexpr int used_with_sampler = 1;
            const Id image_type = TypeImage(t_float, dim, depth, arrayed, not_multisampled,
                                            used_with_sampler, spv::ImageFormat::Unknown);
            const Id sampled_image_type = TypeSampledImage(image_type);
            const Id sampled_image_pointer =
                TypePointer(spv::StorageClass::UniformConstant, sampled_image_type);

            Id variable_type = sampled_image_type;
            if (sampler.IsIndexed()) {
                variable_type = TypeArray(sampled_image_type, Constant(t_uint, sampler.Size()));
            }
            const Id variable =
                OpVariable(TypePointer(spv::StorageClass::UniformConstant, variable_type),
                           spv::StorageClass::UniformConstant);
            AddGlobalVariable(Name(variable, fmt::format("sampler_{}", sampler.GetIndex())));
            Decorate(variable, spv::Decoration::Binding, binding++);
            Decorate(variable, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);

            sampled_images.emplace(sampler.GetIndex(),
                                   SampledImage{image_type, sampled_image_type,
                                                sampled_image_pointer, variable});
        }
        return binding;
    }

    /// Declares a storage image (uint texel type, unknown format) for each image the IR uses,
    /// named "image_<index>".
    ///
    /// - Read-only images are decorated NonWritable; write-only images are decorated
    ///   NonReadable.
    /// - Bindings are assigned consecutively from `binding`; returns the next free binding.
    u32 DeclareImages(u32 binding) {
        for (const auto& image : ir.GetImages()) {
            const auto [dim, arrayed] = GetImageDim(image);
            constexpr int no_depth = 0;
            constexpr bool single_sampled = false;
            constexpr int storage_only = 2; // This won't be accessed with a sampler
            const Id image_type = TypeImage(t_uint, dim, no_depth, arrayed, single_sampled,
                                            storage_only, spv::ImageFormat::Unknown, {});
            const Id variable =
                OpVariable(TypePointer(spv::StorageClass::UniformConstant, image_type),
                           spv::StorageClass::UniformConstant);
            AddGlobalVariable(Name(variable, fmt::format("image_{}", image.GetIndex())));

            Decorate(variable, spv::Decoration::Binding, binding++);
            Decorate(variable, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
            const bool reads = image.IsRead();
            const bool writes = image.IsWritten();
            if (reads != writes) {
                Decorate(variable,
                         reads ? spv::Decoration::NonWritable : spv::Decoration::NonReadable);
            }

            images.emplace(static_cast<u32>(image.GetIndex()), StorageImage{image_type, variable});
        }
        return binding;
    }

    /// A render target counts as enabled when the header enables output for at least one of
    /// its four color components.
    bool IsRenderTargetEnabled(u32 rt) const {
        u32 component = 0;
        while (component < 4 && !header.ps.IsColorComponentOutputEnabled(rt, component)) {
            ++component;
        }
        return component < 4;
    }

    /// Generic inputs are arrayed per incoming vertex in the tessellation and geometry stages.
    bool IsInputAttributeArray() const {
        switch (stage) {
        case ShaderType::TesselationControl:
        case ShaderType::TesselationEval:
        case ShaderType::Geometry:
            return true;
        default:
            return false;
        }
    }

    /// Generic outputs are arrayed (one element per output vertex) only in the tessellation
    /// control stage.
    bool IsOutputAttributeArray() const {
        const bool is_tess_control = stage == ShaderType::TesselationControl;
        return is_tess_control;
    }

    /// Number of vertices in an arrayed input.
    ///
    /// - Geometry: the vertex count of the pipeline's primitive topology.
    /// - Tessellation: NumInputPatches.
    /// - Any other stage is unreachable and yields 1.
    u32 GetNumInputVertices() const {
        if (stage == ShaderType::Geometry) {
            return GetNumPrimitiveTopologyVertices(registry.GetGraphicsInfo().primitive_topology);
        }
        if (stage == ShaderType::TesselationControl || stage == ShaderType::TesselationEval) {
            return NumInputPatches;
        }
        UNREACHABLE();
        return 1;
    }

    /// Number of vertices in an arrayed output. Only the tessellation control stage has one,
    /// and it is taken from the header. Any other stage is unreachable and yields 1.
    u32 GetNumOutputVertices() const {
        if (stage != ShaderType::TesselationControl) {
            UNREACHABLE();
            return 1;
        }
        return header.common2.threads_per_input_primitive;
    }

    std::tuple<Id, VertexIndices> DeclareVertexStruct() {
        struct BuiltIn {
            Id type;
            spv::BuiltIn builtin;
            const char* name;
        };
        std::vector<BuiltIn> members;
        members.reserve(4);

        const auto AddBuiltIn = [&](Id type, spv::BuiltIn builtin, const char* name) {
            const auto index = static_cast<u32>(members.size());
            members.push_back(BuiltIn{type, builtin, name});
            return index;
        };

        VertexIndices indices;
        indices.position = AddBuiltIn(t_float4, spv::BuiltIn::Position, "position");

        if (ir.UsesLayer()) {
            if (stage != ShaderType::Vertex || device.IsExtShaderViewportIndexLayerSupported()) {
                indices.layer = AddBuiltIn(t_int, spv::BuiltIn::Layer, "layer");
            } else {
                LOG_ERROR(
                    Render_Vulkan,
                    "Shader requires Layer but it's not supported on this stage with this device.");
            }
        }

        if (ir.UsesViewportIndex()) {
            if (stage != ShaderType::Vertex || device.IsExtShaderViewportIndexLayerSupported()) {
                indices.viewport = AddBuiltIn(t_int, spv::BuiltIn::ViewportIndex, "viewport_index");
            } else {
                LOG_ERROR(Render_Vulkan, "Shader requires ViewportIndex but it's not supported on "
                                         "this stage with this device.");
            }
        }

        if (ir.UsesPointSize() || specialization.point_size) {
            indices.point_size = AddBuiltIn(t_float, spv::BuiltIn::PointSize, "point_size");
        }

        const auto& output_attributes = ir.GetOutputAttributes();
        const bool declare_clip_distances =
            std::any_of(output_attributes.begin(), output_attributes.end(), [](const auto& index) {
                return index == Attribute::Index::ClipDistances0123 ||
                       index == Attribute::Index::ClipDistances4567;
            });
        if (declare_clip_distances) {
            indices.clip_distances = AddBuiltIn(TypeArray(t_float, Constant(t_uint, 8)),
                                                spv::BuiltIn::ClipDistance, "clip_distances");
        }

        std::vector<Id> member_types;
        member_types.reserve(members.size());
        for (std::size_t i = 0; i < members.size(); ++i) {
            member_types.push_back(members[i].type);
        }
        const Id per_vertex_struct = Name(TypeStruct(member_types), "PerVertex");
        Decorate(per_vertex_struct, spv::Decoration::Block);

        for (std::size_t index = 0; index < members.size(); ++index) {
            const auto& member = members[index];
            MemberName(per_vertex_struct, static_cast<u32>(index), member.name);
            MemberDecorate(per_vertex_struct, static_cast<u32>(index), spv::Decoration::BuiltIn,
                           static_cast<u32>(member.builtin));
        }

        return {per_vertex_struct, indices};
    }

    void VisitBasicBlock(const NodeBlock& bb) {
        for (const auto& node : bb) {
            [[maybe_unused]] const Type type = Visit(node).type;
            ASSERT(type == Type::Void);
        }
    }

    /// Emits SPIR-V for one IR node and returns the produced value together with its type.
    /// Statement-like nodes (conditionals, comments) return an empty Expression (presumably
    /// Type::Void, which VisitBasicBlock asserts). Unsupported nodes are reported through the
    /// UNIMPLEMENTED / UNREACHABLE macros and yield a zero / empty value.
    Expression Visit(const Node& node) {
        if (const auto operation = std::get_if<OperationNode>(&*node)) {
            // Any amend node attached to the operation is emitted before the operation itself.
            if (const auto amend_index = operation->GetAmendIndex()) {
                [[maybe_unused]] const Type type = Visit(ir.GetAmendNode(*amend_index)).type;
                ASSERT(type == Type::Void);
            }
            // Dispatch through the member-function table indexed by the operation code.
            const auto operation_index = static_cast<std::size_t>(operation->GetCode());
            const auto decompiler = operation_decompilers[operation_index];
            if (decompiler == nullptr) {
                UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index);
            }
            return (this->*decompiler)(*operation);
        }

        if (const auto gpr = std::get_if<GprNode>(&*node)) {
            const u32 index = gpr->GetIndex();
            // The zero register always reads as 0.0f; it has no backing variable load.
            if (index == Register::ZeroIndex) {
                return {v_float_zero, Type::Float};
            }
            return {OpLoad(t_float, registers.at(index)), Type::Float};
        }

        if (const auto cv = std::get_if<CustomVarNode>(&*node)) {
            const u32 index = cv->GetIndex();
            return {OpLoad(t_float, custom_variables.at(index)), Type::Float};
        }

        if (const auto immediate = std::get_if<ImmediateNode>(&*node)) {
            return {Constant(t_uint, immediate->GetValue()), Type::Uint};
        }

        if (const auto predicate = std::get_if<PredicateNode>(&*node)) {
            // UnusedIndex reads as constant true and NeverExecute as constant false; any other
            // predicate is loaded from its variable. Negation is applied afterwards.
            const auto value = [&]() -> Id {
                switch (const auto index = predicate->GetIndex(); index) {
                case Tegra::Shader::Pred::UnusedIndex:
                    return v_true;
                case Tegra::Shader::Pred::NeverExecute:
                    return v_false;
                default:
                    return OpLoad(t_bool, predicates.at(index));
                }
            }();
            if (predicate->IsNegated()) {
                return {OpLogicalNot(t_bool, value), Type::Bool};
            }
            return {value, Type::Bool};
        }

        if (const auto abuf = std::get_if<AbufNode>(&*node)) {
            const auto attribute = abuf->GetIndex();
            const u32 element = abuf->GetElement();
            const auto& buffer = abuf->GetBuffer();

            // Builds an access chain into an input variable. When the node has a buffer operand
            // and this stage's inputs are arrays, that operand (presumably the vertex index —
            // confirm) is prepended to the constant indices.
            const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
                std::vector<Id> members;
                members.reserve(std::size(indices) + 1);

                if (buffer && IsInputAttributeArray()) {
                    members.push_back(AsUint(Visit(buffer)));
                }
                for (const u32 index : indices) {
                    members.push_back(Constant(t_uint, index));
                }
                return OpAccessChain(pointer_type, composite, members);
            };

            switch (attribute) {
            case Attribute::Index::Position: {
                // Fragment stage reads frag_coord, with the w component forced to 1.0f; other
                // stages read the position member of the input per-vertex block.
                if (stage == ShaderType::Fragment) {
                    if (element == 3) {
                        return {Constant(t_float, 1.0f), Type::Float};
                    }
                    return {OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)),
                            Type::Float};
                }
                const std::vector elements = {in_indices.position.value(), element};
                return {OpLoad(t_float, ArrayPass(t_in_float, in_vertex, elements)), Type::Float};
            }
            case Attribute::Index::PointCoord: {
                switch (element) {
                case 0:
                case 1:
                    return {OpCompositeExtract(t_float, OpLoad(t_float2, point_coord), element),
                            Type::Float};
                }
                UNIMPLEMENTED_MSG("Unimplemented point coord element={}", element);
                return {v_float_zero, Type::Float};
            }
            case Attribute::Index::TessCoordInstanceIDVertexID:
                // TODO(Subv): Find out what the values are for the first two elements when inside a
                // vertex shader, and what's the value of the fourth element when inside a Tess Eval
                // shader.
                // Instance and vertex ids are rebased by subtracting base_instance / base_vertex.
                switch (element) {
                case 0:
                case 1:
                    return {OpLoad(t_float, AccessElement(t_in_float, tess_coord, element)),
                            Type::Float};
                case 2:
                    return {
                        OpISub(t_int, OpLoad(t_int, instance_index), OpLoad(t_int, base_instance)),
                        Type::Int};
                case 3:
                    return {OpISub(t_int, OpLoad(t_int, vertex_index), OpLoad(t_int, base_vertex)),
                            Type::Int};
                }
                UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element);
                return {Constant(t_uint, 0U), Type::Uint};
            case Attribute::Index::FrontFacing:
                // TODO(Subv): Find out what the values are for the other elements.
                ASSERT(stage == ShaderType::Fragment);
                // Element 3 is -1 (all bits set) for front faces and 0 for back faces.
                if (element == 3) {
                    const Id is_front_facing = OpLoad(t_bool, front_facing);
                    const Id true_value = Constant(t_int, static_cast<s32>(-1));
                    const Id false_value = Constant(t_int, 0);
                    return {OpSelect(t_int, is_front_facing, true_value, false_value), Type::Int};
                }
                UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element);
                return {v_float_zero, Type::Float};
            default:
                // Generic attributes are loaded with the scalar type recorded for their location.
                if (IsGenericAttribute(attribute)) {
                    const u32 location = GetGenericAttributeLocation(attribute);
                    const auto type_descriptor = GetAttributeType(location);
                    const Type type = type_descriptor.type;
                    const Id attribute_id = input_attributes.at(attribute);
                    const std::vector elements = {element};
                    const Id pointer = ArrayPass(type_descriptor.scalar, attribute_id, elements);
                    return {OpLoad(GetTypeDefinition(type), pointer), type};
                }
                break;
            }
            UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute));
            return {v_float_zero, Type::Float};
        }

        if (const auto cbuf = std::get_if<CbufNode>(&*node)) {
            const Node& offset = cbuf->GetOffset();
            const Id buffer_id = constant_buffers.at(cbuf->GetIndex());

            Id pointer{};
            if (device.IsKhrUniformBufferStandardLayoutSupported()) {
                // Scalar layout: index a float array with the byte offset divided by 4.
                const Id buffer_offset =
                    OpShiftRightLogical(t_uint, AsUint(Visit(offset)), Constant(t_uint, 2U));
                pointer =
                    OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0U), buffer_offset);
            } else {
                // vec4 layout: select the vec4 (offset / 16) and then its component.
                Id buffer_index{};
                Id buffer_element{};
                if (const auto immediate = std::get_if<ImmediateNode>(&*offset)) {
                    // Direct access
                    const u32 offset_imm = immediate->GetValue();
                    ASSERT(offset_imm % 4 == 0);
                    buffer_index = Constant(t_uint, offset_imm / 16);
                    buffer_element = Constant(t_uint, (offset_imm / 4) % 4);
                } else if (std::holds_alternative<OperationNode>(*offset)) {
                    // Indirect access
                    // NOTE(review): the float index is wrapped modulo MaxConstBufferElements - 1;
                    // confirm this bound against the declared buffer size.
                    const Id offset_id = AsUint(Visit(offset));
                    const Id unsafe_offset = OpUDiv(t_uint, offset_id, Constant(t_uint, 4));
                    const Id final_offset =
                        OpUMod(t_uint, unsafe_offset, Constant(t_uint, MaxConstBufferElements - 1));
                    buffer_index = OpUDiv(t_uint, final_offset, Constant(t_uint, 4));
                    buffer_element = OpUMod(t_uint, final_offset, Constant(t_uint, 4));
                } else {
                    UNREACHABLE_MSG("Unmanaged offset node type");
                }
                pointer = OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), buffer_index,
                                        buffer_element);
            }
            return {OpLoad(t_float, pointer), Type::Float};
        }

        if (const auto gmem = std::get_if<GmemNode>(&*node)) {
            return {OpLoad(t_uint, GetGlobalMemoryPointer(*gmem)), Type::Uint};
        }

        if (const auto lmem = std::get_if<LmemNode>(&*node)) {
            // Byte address -> float index into the local_memory variable.
            Id address = AsUint(Visit(lmem->GetAddress()));
            address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
            const Id pointer = OpAccessChain(t_prv_float, local_memory, address);
            return {OpLoad(t_float, pointer), Type::Float};
        }

        if (const auto smem = std::get_if<SmemNode>(&*node)) {
            return {OpLoad(t_uint, GetSharedMemoryPointer(*smem)), Type::Uint};
        }

        if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) {
            const Id flag = internal_flags.at(static_cast<std::size_t>(internal_flag->GetFlag()));
            return {OpLoad(t_bool, flag), Type::Bool};
        }

        if (const auto conditional = std::get_if<ConditionalNode>(&*node)) {
            if (const auto amend_index = conditional->GetAmendIndex()) {
                [[maybe_unused]] const Type type = Visit(ir.GetAmendNode(*amend_index)).type;
                ASSERT(type == Type::Void);
            }
            // It's invalid to call conditional on nested nodes, use an operation instead
            // Emits a structured selection: branch to true_label when the condition holds,
            // otherwise fall to skip_label (also the merge block).
            const Id true_label = OpLabel();
            const Id skip_label = OpLabel();
            const Id condition = AsBool(Visit(conditional->GetCondition()));
            OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone);
            OpBranchConditional(condition, true_label, skip_label);
            AddLabel(true_label);

            conditional_branch_set = true;
            inside_branch = false;
            VisitBasicBlock(conditional->GetCode());
            conditional_branch_set = false;
            // If the body set inside_branch (presumably because it already emitted its own
            // terminating branch), no extra branch to the merge block is emitted.
            if (!inside_branch) {
                OpBranch(skip_label);
            } else {
                inside_branch = false;
            }
            AddLabel(skip_label);
            return {};
        }

        if (const auto comment = std::get_if<CommentNode>(&*node)) {
            // Comments become the debug name of an OpUndef so they appear in disassembly.
            Name(OpUndef(t_void), comment->GetText());
            return {};
        }

        UNREACHABLE();
        return {};
    }

    /// Emits a one-operand instruction via the Module member @p func. The operand is converted
    /// to @p type_a, the result is typed @p result_type, and precise operations are decorated
    /// with NoContraction.
    template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type>
    Expression Unary(Operation operation) {
        const Id result_type_id = GetTypeDefinition(result_type);
        const Id source = As(Visit(operation[0]), type_a);

        const Id emitted = (this->*func)(result_type_id, source);
        if (!IsPrecise(operation)) {
            return {emitted, result_type};
        }
        // Precise results must not be contracted (e.g. fused into an FMA) by the driver.
        Decorate(emitted, spv::Decoration::NoContraction);
        return {emitted, result_type};
    }

    /// Emits a two-operand instruction via the Module member @p func. Operands are converted
    /// to @p type_a / @p type_b, the result is typed @p result_type, and precise operations are
    /// decorated with NoContraction.
    template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type,
              Type type_b = type_a>
    Expression Binary(Operation operation) {
        const Id result_type_id = GetTypeDefinition(result_type);
        const Id lhs = As(Visit(operation[0]), type_a);
        const Id rhs = As(Visit(operation[1]), type_b);

        const Id emitted = (this->*func)(result_type_id, lhs, rhs);
        if (!IsPrecise(operation)) {
            return {emitted, result_type};
        }
        // Precise results must not be contracted (e.g. fused into an FMA) by the driver.
        Decorate(emitted, spv::Decoration::NoContraction);
        return {emitted, result_type};
    }

    /// Emits a three-operand instruction via the Module member @p func. Operands are converted
    /// to @p type_a / @p type_b / @p type_c, the result is typed @p result_type, and precise
    /// operations are decorated with NoContraction.
    template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type,
              Type type_b = type_a, Type type_c = type_b>
    Expression Ternary(Operation operation) {
        const Id result_type_id = GetTypeDefinition(result_type);
        const Id first = As(Visit(operation[0]), type_a);
        const Id second = As(Visit(operation[1]), type_b);
        const Id third = As(Visit(operation[2]), type_c);

        const Id emitted = (this->*func)(result_type_id, first, second, third);
        if (!IsPrecise(operation)) {
            return {emitted, result_type};
        }
        // Precise results must not be contracted (e.g. fused into an FMA) by the driver.
        Decorate(emitted, spv::Decoration::NoContraction);
        return {emitted, result_type};
    }

    /// Emits a four-operand instruction via the Module member @p func. Operands are converted
    /// to @p type_a .. @p type_d, the result is typed @p result_type, and precise operations
    /// are decorated with NoContraction.
    template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type,
              Type type_b = type_a, Type type_c = type_b, Type type_d = type_c>
    Expression Quaternary(Operation operation) {
        const Id result_type_id = GetTypeDefinition(result_type);
        const Id first = As(Visit(operation[0]), type_a);
        const Id second = As(Visit(operation[1]), type_b);
        const Id third = As(Visit(operation[2]), type_c);
        const Id fourth = As(Visit(operation[3]), type_d);

        const Id emitted = (this->*func)(result_type_id, first, second, third, fourth);
        if (!IsPrecise(operation)) {
            return {emitted, result_type};
        }
        // Precise results must not be contracted (e.g. fused into an FMA) by the driver.
        Decorate(emitted, spv::Decoration::NoContraction);
        return {emitted, result_type};
    }

    /// Stores operand 1 into the location described by operand 0 (register, output attribute,
    /// tessellation patch output, local/shared/global memory, or custom variable).
    /// Each destination kind resolves to a pointer plus the type the source is converted to.
    /// Unsupported or undeclared destinations resolve to an empty target and the store is
    /// skipped. Always returns an empty Expression.
    Expression Assign(Operation operation) {
        const Node& dest = operation[0];
        const Node& src = operation[1];

        Expression target{};
        if (const auto gpr = std::get_if<GprNode>(&*dest)) {
            if (gpr->GetIndex() == Register::ZeroIndex) {
                // Writing to Register::ZeroIndex is a no op
                return {};
            }
            target = {registers.at(gpr->GetIndex()), Type::Float};

        } else if (const auto abuf = std::get_if<AbufNode>(&*dest)) {
            const auto& buffer = abuf->GetBuffer();
            // Output counterpart of Visit's ArrayPass: the buffer operand is prepended only when
            // this stage's outputs are arrays.
            const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
                std::vector<Id> members;
                members.reserve(std::size(indices) + 1);

                if (buffer && IsOutputAttributeArray()) {
                    members.push_back(AsUint(Visit(buffer)));
                }
                for (const u32 index : indices) {
                    members.push_back(Constant(t_uint, index));
                }
                return OpAccessChain(pointer_type, composite, members);
            };

            target = [&]() -> Expression {
                const u32 element = abuf->GetElement();
                switch (const auto attribute = abuf->GetIndex(); attribute) {
                case Attribute::Index::Position: {
                    const u32 index = out_indices.position.value();
                    return {ArrayPass(t_out_float, out_vertex, {index, element}), Type::Float};
                }
                case Attribute::Index::LayerViewportPointSize:
                    // Layer / viewport writes are dropped when DeclareVertexStruct could not
                    // declare the built-in on this device.
                    switch (element) {
                    case 1: {
                        if (!out_indices.layer) {
                            return {};
                        }
                        const u32 index = out_indices.layer.value();
                        return {AccessElement(t_out_int, out_vertex, index), Type::Int};
                    }
                    case 2: {
                        if (!out_indices.viewport) {
                            return {};
                        }
                        const u32 index = out_indices.viewport.value();
                        return {AccessElement(t_out_int, out_vertex, index), Type::Int};
                    }
                    case 3: {
                        // NOTE(review): unlike layer/viewport, point_size is not checked before
                        // value(); this relies on the member always being declared when written.
                        const auto index = out_indices.point_size.value();
                        return {AccessElement(t_out_float, out_vertex, index), Type::Float};
                    }
                    default:
                        UNIMPLEMENTED_MSG("LayerViewportPoint element={}", abuf->GetElement());
                        return {};
                    }
                case Attribute::Index::ClipDistances0123: {
                    const u32 index = out_indices.clip_distances.value();
                    return {AccessElement(t_out_float, out_vertex, index, element), Type::Float};
                }
                case Attribute::Index::ClipDistances4567: {
                    // Second group maps to clip distance slots 4..7 of the same array.
                    const u32 index = out_indices.clip_distances.value();
                    return {AccessElement(t_out_float, out_vertex, index, element + 4),
                            Type::Float};
                }
                default:
                    if (IsGenericAttribute(attribute)) {
                        // Look up the varying description by (attribute * 4 + element); scalar
                        // varyings are addressed without a component index.
                        const u8 offset = static_cast<u8>(static_cast<u8>(attribute) * 4 + element);
                        const GenericVaryingDescription description = output_attributes.at(offset);
                        const Id composite = description.id;
                        std::vector<u32> indices;
                        if (!description.is_scalar) {
                            indices.push_back(element - description.first_element);
                        }
                        return {ArrayPass(t_out_float, composite, indices), Type::Float};
                    }
                    UNIMPLEMENTED_MSG("Unhandled output attribute: {}",
                                      static_cast<u32>(attribute));
                    return {};
                }
            }();

        } else if (const auto patch = std::get_if<PatchNode>(&*dest)) {
            // Patch offsets 0-3 are outer tessellation levels, 4-5 are inner levels.
            target = [&]() -> Expression {
                const u32 offset = patch->GetOffset();
                switch (offset) {
                case 0:
                case 1:
                case 2:
                case 3:
                    return {AccessElement(t_out_float, tess_level_outer, offset % 4), Type::Float};
                case 4:
                case 5:
                    return {AccessElement(t_out_float, tess_level_inner, offset % 4), Type::Float};
                }
                UNIMPLEMENTED_MSG("Unhandled patch output offset: {}", offset);
                return {};
            }();

        } else if (const auto lmem = std::get_if<LmemNode>(&*dest)) {
            // Byte address -> float index into the local_memory variable.
            Id address = AsUint(Visit(lmem->GetAddress()));
            address = OpUDiv(t_uint, address, Constant(t_uint, 4));
            target = {OpAccessChain(t_prv_float, local_memory, address), Type::Float};

        } else if (const auto smem = std::get_if<SmemNode>(&*dest)) {
            target = {GetSharedMemoryPointer(*smem), Type::Uint};

        } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) {
            target = {GetGlobalMemoryPointer(*gmem), Type::Uint};

        } else if (const auto cv = std::get_if<CustomVarNode>(&*dest)) {
            target = {custom_variables.at(cv->GetIndex()), Type::Float};

        } else {
            UNIMPLEMENTED();
        }

        if (!target.id) {
            // On failure we return a nullptr target.id, skip these stores.
            return {};
        }

        OpStore(target.id, As(Visit(src), target.type));
        return {};
    }

    /// Converts lane @p offset of the packed half pair in operand 0 to a 32-bit float.
    template <u32 offset>
    Expression FCastHalf(Operation operation) {
        const Id packed = AsHalfFloat(Visit(operation[0]));
        const Id half_scalar = OpCompositeExtract(t_scalar_half, packed, offset);
        return {GetFloatFromHalfScalar(half_scalar), Type::Float};
    }

    /// Computes op0 * mod_a + op1 * mod_b, where the modifier pair is chosen per thread.
    /// The selector is the 2-bit field of operand 2 at bit position 2 * (thread_id & 3), and it
    /// indexes two constant sign tables.
    Expression FSwizzleAdd(Operation operation) {
        const Id neg_one = Constant(t_float, -1.0f);
        const Id pos_one = v_float_one;
        const Id nil = v_float_zero;
        const Id table_a = ConstantComposite(t_float4, neg_one, pos_one, neg_one, nil);
        const Id table_b = ConstantComposite(t_float4, neg_one, neg_one, pos_one, neg_one);

        const Id thread = OpLoad(t_uint, thread_id);
        const Id thread_low_bits = OpBitwiseAnd(t_uint, thread, Constant(t_uint, 3));
        const Id shift_amount = OpShiftLeftLogical(t_uint, thread_low_bits, Constant(t_uint, 1));
        const Id swizzle_bits = OpShiftRightLogical(t_uint, AsUint(Visit(operation[2])), shift_amount);
        const Id selector = OpBitwiseAnd(t_uint, swizzle_bits, Constant(t_uint, 3));

        const Id factor_a = OpVectorExtractDynamic(t_float, table_a, selector);
        const Id factor_b = OpVectorExtractDynamic(t_float, table_b, selector);

        const Id scaled_a = OpFMul(t_float, AsFloat(Visit(operation[0])), factor_a);
        const Id scaled_b = OpFMul(t_float, AsFloat(Visit(operation[1])), factor_b);
        return {OpFAdd(t_float, scaled_a, scaled_b), Type::Float};
    }

    /// Conditionally negates each lane of the packed half pair in operand 0. Operands 1 and 2
    /// are booleans selecting negation of the low and high lane respectively.
    Expression HNegate(Operation operation) {
        // Bit patterns of -1.0 / 1.0: binary16 when native half floats are available,
        // binary32 otherwise.
        const bool native_f16 = device.IsFloat16Supported();
        const Id minus_one = Constant(t_scalar_half, native_f16 ? 0xbc00 : 0xbf800000);
        const Id plus_one = Constant(t_scalar_half, native_f16 ? 0x3c00 : 0x3f800000);
        const auto SignFor = [&](std::size_t operand_index) {
            const Id negate = AsBool(Visit(operation[operand_index]));
            return OpSelect(t_scalar_half, negate, minus_one, plus_one);
        };
        const Id signs = OpCompositeConstruct(t_half, SignFor(1), SignFor(2));
        const Id value = AsHalfFloat(Visit(operation[0]));
        return {OpFMul(t_half, value, signs), Type::HalfFloat};
    }

    /// Clamps both lanes of the packed half pair in operand 0. The bounds come from float
    /// operands 1 (min) and 2 (max), each converted to half and broadcast to both lanes.
    /// Precise operations get NoContraction.
    Expression HClamp(Operation operation) {
        const auto Broadcast = [&](std::size_t operand_index) {
            const Id half_scalar =
                GetHalfScalarFromFloat(AsFloat(Visit(operation[operand_index])));
            return OpCompositeConstruct(t_half, half_scalar, half_scalar);
        };
        const Id value = AsHalfFloat(Visit(operation[0]));
        const Id lower = Broadcast(1);
        const Id upper = Broadcast(2);

        const Id result = OpFClamp(t_half, value, lower, upper);
        if (!IsPrecise(operation)) {
            return {result, Type::HalfFloat};
        }
        Decorate(result, spv::Decoration::NoContraction);
        return {result, Type::HalfFloat};
    }

    /// Converts float operand 0 to half and packs it as (value, 0) in a half pair.
    Expression HCastFloat(Operation operation) {
        const Id low = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
        const Id high = Constant(t_scalar_half, 0);
        return {OpCompositeConstruct(t_half, low, high), Type::HalfFloat};
    }

    /// Reinterprets operand 0 as a packed half pair according to the HalfType meta:
    /// H0_H1 passes the operand through unchanged, F32 converts a float to half and broadcasts
    /// it, and H0_H0 / H1_H1 broadcast the low / high lane respectively.
    Expression HUnpack(Operation operation) {
        Expression operand = Visit(operation[0]);
        const auto type = std::get<Tegra::Shader::HalfType>(operation.GetMeta());
        if (type == Tegra::Shader::HalfType::H0_H1) {
            return operand;
        }
        const Id value = [&]() -> Id {
            switch (type) {
            case Tegra::Shader::HalfType::F32:
                return GetHalfScalarFromFloat(AsFloat(operand));
            case Tegra::Shader::HalfType::H0_H0:
                return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 0);
            case Tegra::Shader::HalfType::H1_H1:
                return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 1);
            default:
                UNREACHABLE();
                // Must be a scalar half like the other branches: it is broadcast into both
                // lanes of a t_half composite below.
                return ConstantNull(t_scalar_half);
            }
        }();
        return {OpCompositeConstruct(t_half, value, value), Type::HalfFloat};
    }

    /// Converts the low lane of the packed half pair in operand 0 to a 32-bit float.
    Expression HMergeF32(Operation operation) {
        const Id pair = AsHalfFloat(Visit(operation[0]));
        const Id low_half = OpCompositeExtract(t_scalar_half, pair, 0);
        return {GetFloatFromHalfScalar(low_half), Type::Float};
    }

    /// Returns operand 0's half pair with lane @p offset replaced by the same lane of
    /// operand 1's half pair.
    template <u32 offset>
    Expression HMergeHN(Operation operation) {
        const Id destination = AsHalfFloat(Visit(operation[0]));
        const Id donor = AsHalfFloat(Visit(operation[1]));
        const Id lane = OpCompositeExtract(t_scalar_half, donor, offset);
        return {OpCompositeInsert(t_half, lane, destination, offset), Type::HalfFloat};
    }

    /// Packs float operands 0 (low lane) and 1 (high lane) into a half pair.
    Expression HPack2(Operation operation) {
        const auto ToHalf = [&](std::size_t operand_index) {
            return GetHalfScalarFromFloat(AsFloat(Visit(operation[operand_index])));
        };
        const Id low = ToHalf(0);
        const Id high = ToHalf(1);
        return {OpCompositeConstruct(t_half, low, high), Type::HalfFloat};
    }

    /// Stores the boolean value of operand 1 into the predicate or internal flag named by
    /// operand 0. Writes to the constant predicates (NeverExecute / UnusedIndex) are dropped.
    /// Any other destination kind is reported as unimplemented and the store is skipped,
    /// rather than emitting OpStore with a null pointer id.
    Expression LogicalAssign(Operation operation) {
        const Node& dest = operation[0];
        const Node& src = operation[1];

        Id target{};
        if (const auto pred = std::get_if<PredicateNode>(&*dest)) {
            ASSERT_MSG(!pred->IsNegated(), "Negating logical assignment");

            const auto index = pred->GetIndex();
            switch (index) {
            case Tegra::Shader::Pred::NeverExecute:
            case Tegra::Shader::Pred::UnusedIndex:
                // Writing to these predicates is a no-op
                return {};
            default:
                break;
            }
            target = predicates.at(index);

        } else if (const auto flag = std::get_if<InternalFlagNode>(&*dest)) {
            target = internal_flags.at(static_cast<std::size_t>(flag->GetFlag()));

        } else {
            UNIMPLEMENTED();
        }

        if (!target) {
            // Unsupported destination: skip the store, matching Assign's handling.
            return {};
        }

        OpStore(target, AsBool(Visit(src)));
        return {};
    }

    /// Loads the combined image-sampler referenced by the operation's MetaTexture.
    /// For indexed samplers, an access chain into the sampler array is built first using
    /// meta.index.
    Id GetTextureSampler(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        ASSERT(!meta.sampler.IsBuffer());

        const auto& entry = sampled_images.at(meta.sampler.GetIndex());
        if (!meta.sampler.IsIndexed()) {
            return OpLoad(entry.sampler_type, entry.variable);
        }
        const Id array_index = AsInt(Visit(meta.index));
        const Id element =
            OpAccessChain(entry.sampler_pointer_type, entry.variable, array_index);
        return OpLoad(entry.sampler_type, element);
    }

    /// Returns the image object for a texture operation. For buffer samplers this is the loaded
    /// texel buffer; otherwise it is the image extracted from the loaded combined sampler.
    Id GetTextureImage(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        const u32 index = meta.sampler.GetIndex();
        if (!meta.sampler.IsBuffer()) {
            const auto& sampled = sampled_images.at(index);
            return OpImage(sampled.image_type, GetTextureSampler(operation));
        }
        const auto& texel_buffer = texel_buffers.at(index);
        return OpLoad(texel_buffer.image_type, texel_buffer.image);
    }

    /// Loads the image referenced by the operation's MetaImage.
    Id GetImage(Operation operation) {
        const auto& meta = std::get<MetaImage>(operation.GetMeta());
        // Bind by reference, like the texture helpers: the entry is only read, so copying it
        // out of the table is unnecessary.
        const auto& entry = images.at(meta.image.GetIndex());
        return OpLoad(entry.image_type, entry.image);
    }

    /// Combines scalar components of @p type into a vector of matching width. A single
    /// component is returned as-is, without a composite.
    Id AssembleVector(const std::vector<Id>& coords, Type type) {
        const Id vector_type = GetTypeVectorDefinitionLut(type).at(coords.size() - 1);
        if (coords.size() == 1) {
            return coords[0];
        }
        return OpCompositeConstruct(vector_type, coords);
    }

    /// Builds the coordinate vector for a texture or image operation from all its operands,
    /// converted to @p type. For array textures the layer index (meta.array) is appended; it is
    /// converted to float when @p type is Type::Float.
    Id GetCoordinates(Operation operation, Type type) {
        std::vector<Id> components;
        for (std::size_t operand = 0; operand < operation.GetOperandsCount(); ++operand) {
            components.push_back(As(Visit(operation[operand]), type));
        }
        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
        if (meta != nullptr && meta->sampler.IsArray()) {
            Id layer = AsInt(Visit(meta->array));
            if (type == Type::Float) {
                layer = OpConvertSToF(t_float, layer);
            }
            components.push_back(layer);
        }
        return AssembleVector(components, type);
    }

    /// Collects the texture offset operands (meta.aoffi) as integers into a scalar or vector.
    Id GetOffsetCoordinates(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        const auto& offset_nodes = meta.aoffi;
        std::vector<Id> offsets;
        offsets.reserve(offset_nodes.size());
        for (const auto& offset_node : offset_nodes) {
            const Id offset = AsInt(Visit(offset_node));
            offsets.push_back(offset);
        }
        return AssembleVector(offsets, Type::Int);
    }

    /// Splits the interleaved derivative operands (dx0, dy0, dx1, dy1, ...) from meta.derivates
    /// into two float vectors.
    /// @returns {dx, dy}
    std::pair<Id, Id> GetDerivatives(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        const auto& derivatives = meta.derivates;
        ASSERT(derivatives.size() % 2 == 0);

        std::vector<Id> dx;
        std::vector<Id> dy;
        dx.reserve(derivatives.size() / 2);
        dy.reserve(derivatives.size() / 2);
        for (std::size_t pair_start = 0; pair_start + 1 < derivatives.size(); pair_start += 2) {
            dx.push_back(AsFloat(Visit(derivatives.at(pair_start))));
            dy.push_back(AsFloat(Visit(derivatives.at(pair_start + 1))));
        }
        return {AssembleVector(dx, Type::Float), AssembleVector(dy, Type::Float)};
    }

    /// Extracts the component selected by meta.element from a sampled vector.
    Expression GetTextureElement(Operation operation, Id sample_value, Type type) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        const Id component_type = GetTypeDefinition(type);
        const Id component = OpCompositeExtract(component_type, sample_value, meta.element);
        return {component, type};
    }

    /// Emits a texture sample, optionally depth-compared, and returns the selected component.
    /// Fragment shaders use implicit-LOD sampling, with an optional bias. Other stages sample
    /// level 0 through explicit-LOD instructions.
    Expression Texture(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());

        const bool can_implicit = stage == ShaderType::Fragment;
        const Id sampler = GetTextureSampler(operation);
        const Id coords = GetCoordinates(operation, Type::Float);

        std::vector<Id> operands;
        spv::ImageOperandsMask mask{};
        // The SPIR-V Bias image operand is only valid on implicit-LOD instructions. Outside
        // fragment shaders it would be combined with Lod on an explicit-LOD sample, which is
        // invalid, so the bias is ignored there and level 0 is sampled.
        if (meta.bias && can_implicit) {
            mask = mask | spv::ImageOperandsMask::Bias;
            operands.push_back(AsFloat(Visit(meta.bias)));
        }

        if (!can_implicit) {
            mask = mask | spv::ImageOperandsMask::Lod;
            operands.push_back(v_float_zero);
        }

        if (!meta.aoffi.empty()) {
            mask = mask | spv::ImageOperandsMask::Offset;
            operands.push_back(GetOffsetCoordinates(operation));
        }

        if (meta.depth_compare) {
            // Depth sampling
            UNIMPLEMENTED_IF(meta.bias);
            const Id dref = AsFloat(Visit(meta.depth_compare));
            if (can_implicit) {
                return {
                    OpImageSampleDrefImplicitLod(t_float, sampler, coords, dref, mask, operands),
                    Type::Float};
            } else {
                return {
                    OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref, mask, operands),
                    Type::Float};
            }
        }

        Id texture;
        if (can_implicit) {
            texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, operands);
        } else {
            texture = OpImageSampleExplicitLod(t_float4, sampler, coords, mask, operands);
        }
        return GetTextureElement(operation, texture, Type::Float);
    }

    /// Samples with the explicit LOD in meta.lod, adding offsets when present. Shadow samplers
    /// use a depth-compared sample and return it directly; otherwise the selected component of
    /// the sampled vector is returned.
    Expression TextureLod(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());

        const Id sampler = GetTextureSampler(operation);
        const Id coords = GetCoordinates(operation, Type::Float);
        const Id lod = AsFloat(Visit(meta.lod));

        std::vector<Id> operands{lod};
        spv::ImageOperandsMask mask = spv::ImageOperandsMask::Lod;
        if (!meta.aoffi.empty()) {
            mask = mask | spv::ImageOperandsMask::Offset;
            operands.push_back(GetOffsetCoordinates(operation));
        }

        if (!meta.sampler.IsShadow()) {
            const Id texel = OpImageSampleExplicitLod(t_float4, sampler, coords, mask, operands);
            return GetTextureElement(operation, texel, Type::Float);
        }
        const Id dref = AsFloat(Visit(meta.depth_compare));
        const Id compared =
            OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref, mask, operands);
        return {compared, Type::Float};
    }

    /// Gathers one component from the four texels used by bilinear filtering (TLD4).
    /// Shadow samplers use OpImageDrefGather with meta.depth_compare as the reference; other
    /// samplers use OpImageGather with the component taken from meta.component (0 if absent).
    Expression TextureGather(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        UNIMPLEMENTED_IF(!meta.aoffi.empty());

        const Id coords = GetCoordinates(operation, Type::Float);
        Id texture{};
        if (meta.sampler.IsShadow()) {
            texture = OpImageDrefGather(t_float4, GetTextureSampler(operation), coords,
                                        AsFloat(Visit(meta.depth_compare)));
        } else {
            u32 component_value = 0;
            if (meta.component) {
                // The component is emitted as a constant below, so only immediates are handled.
                const auto component = std::get_if<ImmediateNode>(&*meta.component);
                ASSERT_MSG(component, "Component is not an immediate value");
                // Do not rely on the assertion aborting: never dereference a null pointer and
                // fall back to component 0 instead.
                if (component != nullptr) {
                    component_value = component->GetValue();
                }
            }
            texture = OpImageGather(t_float4, GetTextureSampler(operation), coords,
                                    Constant(t_uint, component_value));
        }
        return GetTextureElement(operation, texture, Type::Float);
    }

    /// Queries texture dimensions (TXQ).
    /// Element 3 returns the number of mip levels. Elements below the texture's coordinate count
    /// return that dimension's size at the LOD in operation[0]. Any other element yields zero.
    Expression TextureQueryDimensions(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        UNIMPLEMENTED_IF(!meta.aoffi.empty());
        UNIMPLEMENTED_IF(meta.depth_compare);

        const auto image_id = GetTextureImage(operation);
        if (meta.element == 3) {
            return {OpImageQueryLevels(t_int, image_id), Type::Int};
        }

        const Id lod = AsUint(Visit(operation[0]));
        const std::size_t coords_count = [&]() -> std::size_t {
            switch (const auto type = meta.sampler.GetType(); type) {
            case Tegra::Shader::TextureType::Texture1D:
                return 1;
            case Tegra::Shader::TextureType::Texture2D:
            case Tegra::Shader::TextureType::TextureCube:
                return 2;
            case Tegra::Shader::TextureType::Texture3D:
                return 3;
            default:
                UNREACHABLE_MSG("Invalid texture type={}", static_cast<u32>(type));
                return 2;
            }
        }();

        if (meta.element >= coords_count) {
            return {v_float_zero, Type::Float};
        }

        // OpImageQuerySizeLod's result must have one extra component (the layer count) for
        // arrayed images, so widen the queried type accordingly.
        const std::size_t query_count = coords_count + (meta.sampler.IsArray() ? 1 : 0);
        const std::array<Id, 4> types = {t_int, t_int2, t_int3, t_int4};
        const Id sizes = OpImageQuerySizeLod(types.at(query_count - 1), image_id, lod);
        if (query_count == 1) {
            // Non-arrayed 1D images return a scalar; OpCompositeExtract requires a composite.
            return {sizes, Type::Int};
        }
        const Id size = OpCompositeExtract(t_int, sizes, meta.element);
        return {size, Type::Int};
    }

    /// Queries the level of detail that sampling at the given coordinates would use (TMML).
    /// Both LOD values are multiplied by 256 and converted to signed integers, i.e. fixed point
    /// with 8 fractional bits. The requested element (0 or 1) is then extracted.
    Expression TextureQueryLod(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        UNIMPLEMENTED_IF(!meta.aoffi.empty());
        UNIMPLEMENTED_IF(meta.depth_compare);

        const bool valid_element = meta.element < 2;
        if (!valid_element) {
            UNREACHABLE_MSG("Invalid element");
            return {v_float_zero, Type::Float};
        }
        const auto sampler_id = GetTextureSampler(operation);

        const Id scale = Constant(t_float, 256.0f);
        const Id scale2 = ConstantComposite(t_float2, scale, scale);

        const Id coords = GetCoordinates(operation, Type::Float);
        const Id queried = OpImageQueryLod(t_float2, sampler_id, coords);
        const Id scaled = OpFMul(t_float2, queried, scale2);
        const Id fixed_point = OpConvertFToS(t_int2, scaled);
        return GetTextureElement(operation, fixed_point, Type::Int);
    }

    /// Fetches a single texel with integer coordinates (TLD/TLDS).
    /// An explicit Lod operand is attached only when meta.lod is set and the texture is not a
    /// buffer texture.
    Expression TexelFetch(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        UNIMPLEMENTED_IF(meta.depth_compare);

        const Id image = GetTextureImage(operation);
        const Id coords = GetCoordinates(operation, Type::Int);

        if (meta.lod && !meta.sampler.IsBuffer()) {
            const Id lod = AsInt(Visit(meta.lod));
            const Id texel =
                OpImageFetch(t_float4, image, coords, spv::ImageOperandsMask::Lod, lod);
            return GetTextureElement(operation, texel, Type::Float);
        }

        const Id texel = OpImageFetch(t_float4, image, coords);
        return GetTextureElement(operation, texel, Type::Float);
    }

    /// Samples a texture with explicit derivatives (TXD), passed through the Grad image operand.
    Expression TextureGradient(Operation operation) {
        const auto& meta = std::get<MetaTexture>(operation.GetMeta());
        UNIMPLEMENTED_IF(!meta.aoffi.empty());

        const Id sampler = GetTextureSampler(operation);
        const Id coords = GetCoordinates(operation, Type::Float);
        const auto [ddx, ddy] = GetDerivatives(operation);
        const std::vector<Id> gradients{ddx, ddy};

        const Id sample = OpImageSampleExplicitLod(t_float4, sampler, coords,
                                                   spv::ImageOperandsMask::Grad, gradients);
        return GetTextureElement(operation, sample, Type::Float);
    }

    /// Reads a texel from a storage image (SULD) and returns the meta.element component as uint.
    /// Devices without formatless image load support get a float zero instead.
    Expression ImageLoad(Operation operation) {
        if (device.IsFormatlessImageLoadSupported()) {
            const auto& meta{std::get<MetaImage>(operation.GetMeta())};
            const Id coords = GetCoordinates(operation, Type::Int);
            const Id texel = OpImageRead(t_uint4, GetImage(operation), coords);
            const Id component = OpCompositeExtract(t_uint, texel, meta.element);
            return {component, Type::Uint};
        }
        return {v_float_zero, Type::Float};
    }

    /// Writes a texel to a storage image (SUST).
    /// Each value in meta.values is reinterpreted as uint and packed into a uvec4 in order.
    /// OpCompositeConstruct needs exactly one constituent per vector component, so trailing
    /// components that the instruction does not provide are filled with zero.
    Expression ImageStore(Operation operation) {
        // Bind by reference: the metadata (and its vector of nodes) does not need copying.
        const auto& meta{std::get<MetaImage>(operation.GetMeta())};

        constexpr std::size_t num_components = 4;
        std::vector<Id> colors;
        colors.reserve(num_components);
        for (const auto& value : meta.values) {
            colors.push_back(AsUint(Visit(value)));
        }
        if (colors.size() < num_components) {
            colors.resize(num_components, Constant(t_uint, 0U));
        }

        const Id coords = GetCoordinates(operation, Type::Int);
        const Id texel = OpCompositeConstruct(t_uint4, colors);

        OpImageWrite(GetImage(operation), coords, texel, {});
        return {};
    }

    /// Image atomic add: not implemented in this backend; reports and emits no instructions.
    Expression AtomicImageAdd(Operation operation) {
        UNIMPLEMENTED();
        return Expression{};
    }

    /// Image atomic minimum: not implemented in this backend; reports and emits no instructions.
    Expression AtomicImageMin(Operation operation) {
        UNIMPLEMENTED();
        return Expression{};
    }

    /// Image atomic maximum: not implemented in this backend; reports and emits no instructions.
    Expression AtomicImageMax(Operation operation) {
        UNIMPLEMENTED();
        return Expression{};
    }

    /// Image atomic AND: not implemented in this backend; reports and emits no instructions.
    Expression AtomicImageAnd(Operation operation) {
        UNIMPLEMENTED();
        return Expression{};
    }

    /// Image atomic OR: not implemented in this backend; reports and emits no instructions.
    Expression AtomicImageOr(Operation operation) {
        UNIMPLEMENTED();
        return Expression{};
    }

    /// Image atomic XOR: not implemented in this backend; reports and emits no instructions.
    Expression AtomicImageXor(Operation operation) {
        UNIMPLEMENTED();
        return Expression{};
    }

    /// Image atomic exchange: not implemented in this backend; reports and emits no instructions.
    Expression AtomicImageExchange(Operation operation) {
        UNIMPLEMENTED();
        return Expression{};
    }

    /// Emits an atomic read-modify-write on shared (SmemNode) or global (GmemNode) memory.
    ///
    /// @tparam func        Sirit emitter for the atomic opcode (e.g. Module::OpAtomicIAdd).
    /// @tparam result_type Type reported for the returned expression.
    /// @tparam value_type  Kept for interface compatibility; the operand is always handled as
    ///                     uint (see below).
    /// @param operation    operation[0] is the memory node, operation[1] the operand value.
    /// @returns The value the atomic instruction returns, reinterpreted as result_type.
    template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type,
              Type value_type = result_type>
    Expression Atomic(Operation operation) {
        Id pointer;
        if (const auto smem = std::get_if<SmemNode>(&*operation[0])) {
            pointer = GetSharedMemoryPointer(*smem);
        } else if (const auto gmem = std::get_if<GmemNode>(&*operation[0])) {
            pointer = GetGlobalMemoryPointer(*gmem);
        } else {
            UNREACHABLE();
            return {Constant(GetTypeDefinition(result_type), 0), result_type};
        }

        // Both pointers above are built with uint element pointer types (t_smem_uint /
        // t_gmem_uint). SPIR-V requires an atomic's Result Type and Value type to equal the
        // pointee type, so the instruction always operates on uint. Signedness comes from the
        // opcode itself (e.g. OpAtomicSMin vs OpAtomicUMin).
        const Id value = AsUint(Visit(operation[1]));

        const Id scope = Constant(t_uint, static_cast<u32>(spv::Scope::Device));
        const Id semantics = Constant(t_uint, 0U);

        const Id result = (this->*func)(t_uint, pointer, scope, semantics, value);
        // Reinterpret the uint result so callers still see the type they asked for.
        return {As({result, Type::Uint}, result_type), result_type};
    }

    /// Direct branch (BRA): stores the immediate target address in jmp_to and branches to
    /// continue_label (presumably the head of the flow-dispatch loop — confirm at its definition).
    Expression Branch(Operation operation) {
        const auto target_address = std::get<ImmediateNode>(*operation[0]).GetValue();
        OpStore(jmp_to, Constant(t_uint, target_address));
        OpBranch(continue_label);
        inside_branch = true;
        if (conditional_branch_set) {
            return {};
        }
        // OpBranch terminated the current block; open a new one for any following code.
        AddLabel();
        return {};
    }

    /// Indirect branch (BRX): stores the run-time target address from operation[0] in jmp_to and
    /// branches to continue_label.
    Expression BranchIndirect(Operation operation) {
        const Id target_address = AsUint(Visit(operation[0]));
        OpStore(jmp_to, target_address);
        OpBranch(continue_label);
        inside_branch = true;
        if (conditional_branch_set) {
            return {};
        }
        // OpBranch terminated the current block; open a new one for any following code.
        AddLabel();
        return {};
    }

    /// Pushes an immediate address onto the flow stack (SSY or PBK) selected by the operation's
    /// MetaStackClass. No bounds check against the stack capacity is emitted.
    Expression PushFlowStack(Operation operation) {
        const auto& target = std::get<ImmediateNode>(*operation[0]);
        const auto [stack, stack_top] = GetFlowStack(operation);

        const Id top_index = OpLoad(t_uint, stack_top);
        const Id incremented = OpIAdd(t_uint, top_index, Constant(t_uint, 1));
        const Id slot = OpAccessChain(t_func_uint, stack, top_index);

        OpStore(slot, Constant(t_uint, target.GetValue()));
        OpStore(stack_top, incremented);
        return {};
    }

    /// Pops the most recent address from the selected flow stack (SYNC/BRK), stores it in
    /// jmp_to and branches to continue_label.
    Expression PopFlowStack(Operation operation) {
        const auto [stack, stack_top] = GetFlowStack(operation);
        const Id old_top = OpLoad(t_uint, stack_top);
        const Id new_top = OpISub(t_uint, old_top, Constant(t_uint, 1));
        const Id slot = OpAccessChain(t_func_uint, stack, new_top);
        const Id popped_address = OpLoad(t_uint, slot);

        OpStore(stack_top, new_top);
        OpStore(jmp_to, popped_address);
        OpBranch(continue_label);
        inside_branch = true;
        if (conditional_branch_set) {
            return {};
        }
        // OpBranch terminated the current block; open a new one for any following code.
        AddLabel();
        return {};
    }

    /// Emits stage-specific output fix-ups that must run right before the shader returns.
    ///
    /// - Vertex stage with specialization.ndc_minus_one_to_one: rewrites the position's z as
    ///   (z + w) * 0.5, which maps depth from the [-w, w] range to [0, w].
    /// - Fragment stage: copies shader registers into the enabled render-target color
    ///   components, and into frag_depth when the output map enables depth.
    void PreExit() {
        if (stage == ShaderType::Vertex && specialization.ndc_minus_one_to_one) {
            // Component 2 of the position output is z and component 3 is w.
            const u32 position_index = out_indices.position.value();
            const Id z_pointer = AccessElement(t_out_float, out_vertex, position_index, 2U);
            const Id w_pointer = AccessElement(t_out_float, out_vertex, position_index, 3U);
            Id depth = OpLoad(t_float, z_pointer);
            depth = OpFAdd(t_float, depth, OpLoad(t_float, w_pointer));
            depth = OpFMul(t_float, depth, Constant(t_float, 0.5f));
            OpStore(z_pointer, depth);
        }
        if (stage == ShaderType::Fragment) {
            // Registers the shader never declared (absent from `registers`) read as 0.0.
            const auto SafeGetRegister = [&](u32 reg) {
                // TODO(Rodrigo): Replace with contains once C++20 releases
                if (const auto it = registers.find(reg); it != registers.end()) {
                    return OpLoad(t_float, it->second);
                }
                return v_float_zero;
            };

            UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0,
                                 "Sample mask write is unimplemented");

            // TODO(Rodrigo): Alpha testing

            // Write the color outputs using the data in the shader registers, disabled
            // rendertargets/components are skipped in the register assignment.
            u32 current_reg = 0;
            for (u32 rt = 0; rt < Maxwell::NumRenderTargets; ++rt) {
                // TODO(Subv): Figure out how dual-source blending is configured in the Switch.
                for (u32 component = 0; component < 4; ++component) {
                    if (!header.ps.IsColorComponentOutputEnabled(rt, component)) {
                        continue;
                    }
                    const Id pointer = AccessElement(t_out_float, frag_colors[rt], component);
                    OpStore(pointer, SafeGetRegister(current_reg));
                    ++current_reg;
                }
            }
            if (header.ps.omap.depth) {
                // The depth output is always 2 registers after the last color output, and
                // current_reg already contains one past the last color register.
                OpStore(frag_depth, SafeGetRegister(current_reg + 1));
            }
        }
    }

    /// Ends the invocation (EXIT) after emitting the stage's final output writes (PreExit).
    Expression Exit(Operation operation) {
        PreExit();
        inside_branch = true;
        if (!conditional_branch_set) {
            // Put OpReturn in its own block, reached through an explicit branch, then open a
            // fresh label for whatever gets emitted afterwards.
            const Id return_block = OpLabel();
            OpBranch(return_block);
            AddLabel(return_block);
            OpReturn();
            AddLabel();
            return {};
        }
        OpReturn();
        return {};
    }

    /// Discards the current fragment (KIL) using OpKill.
    Expression Discard(Operation operation) {
        inside_branch = true;
        if (!conditional_branch_set) {
            // Put OpKill in its own block, reached through an explicit branch, then open a
            // fresh label for whatever gets emitted afterwards.
            const Id kill_block = OpLabel();
            OpBranch(kill_block);
            AddLabel(kill_block);
            OpKill();
            AddLabel();
            return {};
        }
        OpKill();
        return {};
    }

    /// Emits the current vertex (OpEmitVertex); produces no value.
    Expression EmitVertex(Operation) {
        OpEmitVertex();
        return Expression{};
    }

    /// Ends the current output primitive (OpEndPrimitive); produces no value.
    Expression EndPrimitive(Operation operation) {
        OpEndPrimitive();
        return Expression{};
    }

    /// Loads the invocation_id input variable as a signed integer.
    Expression InvocationId(Operation) {
        const Id value = OpLoad(t_int, invocation_id);
        return {value, Type::Int};
    }

    /// Stub for the Y-negate flag: logs a warning and always yields 1.0.
    Expression YNegate(Operation) {
        LOG_WARNING(Render_Vulkan, "(STUBBED)");
        const Id one = Constant(t_float, 1.0f);
        return {one, Type::Float};
    }

    /// Returns component `element` (0 = x, 1 = y, 2 = z) of local_invocation_id as uint.
    template <u32 element>
    Expression LocalInvocationId(Operation) {
        const Id invocation_vector = OpLoad(t_uint3, local_invocation_id);
        const Id component = OpCompositeExtract(t_uint, invocation_vector, element);
        return {component, Type::Uint};
    }

    /// Returns component `element` (0 = x, 1 = y, 2 = z) of workgroup_id as uint.
    template <u32 element>
    Expression WorkGroupId(Operation operation) {
        const Id group_vector = OpLoad(t_uint3, workgroup_id);
        const Id component = OpCompositeExtract(t_uint, group_vector, element);
        return {component, Type::Uint};
    }

    /// Subgroup ballot of the predicate in operation[0], returned as a 32-bit uint mask.
    Expression BallotThread(Operation operation) {
        const Id predicate = AsBool(Visit(operation[0]));
        const Id ballot = OpSubgroupBallotKHR(t_uint4, predicate);

        if (device.IsWarpSizePotentiallyBiggerThanGuest()) {
            // Select the 32-bit word of the ballot that holds the current thread's bit
            // (thread_id / 32). For example, thread 38 on a 64-wide device reads word 1.
            const Id tid = OpLoad(t_uint, thread_id);
            const Id word_index = OpShiftRightLogical(t_uint, tid, Constant(t_uint, 5));
            return {OpVectorExtractDynamic(t_uint, ballot, word_index), Type::Uint};
        }

        // Devices whose warp size matches the guest keep every bit in the first word.
        return {OpCompositeExtract(t_uint, ballot, 0U), Type::Uint};
    }

    /// Subgroup vote (all / any / all-equal, chosen by `func`) over the predicate operation[0].
    template <Id (Module::*func)(Id, Id)>
    Expression Vote(Operation operation) {
        // TODO(Rodrigo): Handle devices with different warp sizes
        const Id vote = (this->*func)(t_bool, AsBool(Visit(operation[0])));
        return {vote, Type::Bool};
    }

    /// Loads the thread_id variable as uint.
    Expression ThreadId(Operation) {
        const Id value = OpLoad(t_uint, thread_id);
        return {value, Type::Uint};
    }

    /// Reads the float value operation[0] from the subgroup invocation indexed by operation[1].
    Expression ShuffleIndexed(Operation operation) {
        const Id source_value = AsFloat(Visit(operation[0]));
        const Id source_lane = AsUint(Visit(operation[1]));
        const Id shuffled = OpSubgroupReadInvocationKHR(t_float, source_value, source_lane);
        return {shuffled, Type::Float};
    }

    /// Device-scope acquire/release memory barrier covering uniform, workgroup, atomic-counter
    /// and image memory.
    Expression MemoryBarrierGL(Operation) {
        using Semantics = spv::MemorySemanticsMask;
        const auto semantics = Semantics::AcquireRelease | Semantics::UniformMemory |
                               Semantics::WorkgroupMemory | Semantics::AtomicCounterMemory |
                               Semantics::ImageMemory;

        const Id scope_id = Constant(t_uint, static_cast<u32>(spv::Scope::Device));
        const Id semantics_id = Constant(t_uint, static_cast<u32>(semantics));
        OpMemoryBarrier(scope_id, semantics_id);
        return {};
    }

    /// Declares a named global variable decorated as `builtin`, registers it as an entry-point
    /// interface, and returns its id.
    Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, std::string name) {
        const Id variable = OpVariable(type, storage);
        Decorate(variable, spv::Decoration::BuiltIn, static_cast<u32>(builtin));
        Name(variable, std::move(name));
        AddGlobalVariable(variable);
        interfaces.push_back(variable);
        return variable;
    }

    /// Shorthand for DeclareBuiltIn with the Input storage class.
    Id DeclareInputBuiltIn(spv::BuiltIn builtin, Id type, std::string name) {
        constexpr auto input_storage = spv::StorageClass::Input;
        return DeclareBuiltIn(builtin, input_storage, type, std::move(name));
    }

    /// Builds an OpAccessChain into `composite` using constant indices.
    ///
    /// @param pointer_type Result pointer type of the access chain.
    /// @param composite    Base pointer being indexed.
    /// @param elements_    Constant indices, in order. Any integer types, mixed freely; each is
    ///                     emitted as a 32-bit uint constant.
    template <typename... Args>
    Id AccessElement(Id pointer_type, Id composite, Args... elements_) {
        std::vector<Id> members;
        members.reserve(sizeof...(elements_));
        // A fold expression (instead of deducing an initializer_list) accepts indices of
        // differing integer types and evaluates them left to right. The u32 cast keeps every
        // literal 32 bits wide to match t_uint.
        (members.push_back(Constant(t_uint, static_cast<u32>(elements_))), ...);
        return OpAccessChain(pointer_type, composite, members);
    }

    /// Converts `expr` to `wanted_type` through the matching As* helper.
    /// An unsupported target type hits UNREACHABLE and returns the id unchanged.
    Id As(Expression expr, Type wanted_type) {
        if (wanted_type == Type::Float) {
            return AsFloat(expr);
        }
        if (wanted_type == Type::Int) {
            return AsInt(expr);
        }
        if (wanted_type == Type::Uint) {
            return AsUint(expr);
        }
        if (wanted_type == Type::HalfFloat) {
            return AsHalfFloat(expr);
        }
        if (wanted_type == Type::Bool) {
            return AsBool(expr);
        }
        if (wanted_type == Type::Bool2) {
            return AsBool2(expr);
        }
        UNREACHABLE();
        return expr.id;
    }

    /// Returns the id of a boolean expression; no conversion is performed for other types.
    Id AsBool(Expression expr) {
        const bool is_bool = expr.type == Type::Bool;
        ASSERT(is_bool);
        return expr.id;
    }

    /// Returns the id of a bool2 expression; no conversion is performed for other types.
    Id AsBool2(Expression expr) {
        const bool is_bool2 = expr.type == Type::Bool2;
        ASSERT(is_bool2);
        return expr.id;
    }

    /// Reinterprets `expr` as a 32-bit float without changing its bits.
    /// Half-float pairs are bitcast directly when float16 is supported; otherwise they are first
    /// packed into a uint with PackHalf2x16.
    Id AsFloat(Expression expr) {
        switch (expr.type) {
        case Type::Float:
            return expr.id;
        case Type::Int:
        case Type::Uint:
            return OpBitcast(t_float, expr.id);
        case Type::HalfFloat: {
            const Id bits = device.IsFloat16Supported() ? expr.id
                                                        : OpPackHalf2x16(t_uint, expr.id);
            return OpBitcast(t_float, bits);
        }
        default:
            UNREACHABLE();
            return expr.id;
        }
    }

    /// Reinterprets `expr` as a signed 32-bit integer without changing its bits.
    Id AsInt(Expression expr) {
        switch (expr.type) {
        case Type::Int:
            return expr.id;
        case Type::Float:
        case Type::Uint:
            return OpBitcast(t_int, expr.id);
        case Type::HalfFloat:
            // Native float16 pairs bitcast directly; emulated ones are packed instead.
            return device.IsFloat16Supported() ? OpBitcast(t_int, expr.id)
                                               : OpPackHalf2x16(t_int, expr.id);
        default:
            UNREACHABLE();
            return expr.id;
        }
    }

    /// Reinterprets `expr` as an unsigned 32-bit integer without changing its bits.
    Id AsUint(Expression expr) {
        switch (expr.type) {
        case Type::Uint:
            return expr.id;
        case Type::Float:
        case Type::Int:
            return OpBitcast(t_uint, expr.id);
        case Type::HalfFloat:
            // Native float16 pairs bitcast directly; emulated ones are packed instead.
            return device.IsFloat16Supported() ? OpBitcast(t_uint, expr.id)
                                               : OpPackHalf2x16(t_uint, expr.id);
        default:
            UNREACHABLE();
            return expr.id;
        }
    }

    /// Reinterprets a 32-bit value as a pair of half floats (t_half).
    /// With native float16 this is a bitcast; otherwise the bits are unpacked with
    /// UnpackHalf2x16.
    Id AsHalfFloat(Expression expr) {
        switch (expr.type) {
        case Type::HalfFloat:
            return expr.id;
        case Type::Float:
        case Type::Int:
        case Type::Uint:
            if (!device.IsFloat16Supported()) {
                return OpUnpackHalf2x16(t_half, AsUint(expr));
            }
            return OpBitcast(t_half, expr.id);
        default:
            UNREACHABLE();
            return expr.id;
        }
    }

    /// Converts a float scalar to the half scalar type when float16 is supported; otherwise
    /// returns it unchanged.
    Id GetHalfScalarFromFloat(Id value) {
        return device.IsFloat16Supported() ? OpFConvert(t_scalar_half, value) : value;
    }

    /// Converts a half scalar back to float when float16 is supported; otherwise returns it
    /// unchanged.
    Id GetFloatFromHalfScalar(Id value) {
        return device.IsFloat16Supported() ? OpFConvert(t_float, value) : value;
    }

    /// Returns the expression type and input pointer types used for the attribute at
    /// `location`. Only vertex shaders consult the specialized attribute formats; every other
    /// stage, and every normalized/scaled/float format, is treated as float.
    AttributeType GetAttributeType(u32 location) const {
        const AttributeType float_attribute{Type::Float, t_in_float, t_in_float4};
        if (stage != ShaderType::Vertex) {
            return float_attribute;
        }
        switch (specialization.attribute_types.at(location)) {
        case Maxwell::VertexAttribute::Type::SignedInt:
            return {Type::Int, t_in_int, t_in_int4};
        case Maxwell::VertexAttribute::Type::UnsignedInt:
            return {Type::Uint, t_in_uint, t_in_uint4};
        case Maxwell::VertexAttribute::Type::SignedNorm:
        case Maxwell::VertexAttribute::Type::UnsignedNorm:
        case Maxwell::VertexAttribute::Type::UnsignedScaled:
        case Maxwell::VertexAttribute::Type::SignedScaled:
        case Maxwell::VertexAttribute::Type::Float:
            return float_attribute;
        default:
            UNREACHABLE();
            return float_attribute;
        }
    }

    /// Maps an expression Type to its SPIR-V scalar/vector type id.
    Id GetTypeDefinition(Type type) const {
        switch (type) {
        case Type::Float:
            return t_float;
        case Type::Int:
            return t_int;
        case Type::Uint:
            return t_uint;
        case Type::HalfFloat:
            return t_half;
        case Type::Bool:
            return t_bool;
        case Type::Bool2:
            return t_bool2;
        default:
            UNREACHABLE();
            return {};
        }
    }

    /// Returns the 1- to 4-component type ids for `type`, indexed by component count minus one.
    /// Only Float, Int and Uint are supported.
    std::array<Id, 4> GetTypeVectorDefinitionLut(Type type) const {
        if (type == Type::Float) {
            return {t_float, t_float2, t_float3, t_float4};
        }
        if (type == Type::Int) {
            return {t_int, t_int2, t_int3, t_int4};
        }
        if (type == Type::Uint) {
            return {t_uint, t_uint2, t_uint3, t_uint4};
        }
        UNIMPLEMENTED();
        return {};
    }

    /// Declares a function-local flow stack: a zero-initialized uint array and a top index
    /// initialized to 0. Returns {stack, top}.
    std::tuple<Id, Id> CreateFlowStack() {
        // TODO(Rodrigo): Figure out the actual depth of the flow stack, for now it seems unlikely
        // that shaders will use 20 nested SSYs and PBKs.
        constexpr u32 FLOW_STACK_SIZE = 20;
        constexpr auto storage_class = spv::StorageClass::Function;

        const Id array_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE));
        const Id array_pointer_type = TypePointer(storage_class, array_type);
        const Id null_array = ConstantNull(array_type);
        const Id stack = OpVariable(array_pointer_type, storage_class, null_array);
        const Id top = OpVariable(t_func_uint, storage_class, Constant(t_uint, 0));
        AddLocalVariable(stack);
        AddLocalVariable(top);
        return std::make_tuple(stack, top);
    }

    /// Returns the {stack, top} variables for the stack class (SSY or PBK) in the operation's
    /// metadata.
    std::pair<Id, Id> GetFlowStack(Operation operation) {
        const auto stack_class = std::get<MetaStackClass>(operation.GetMeta());
        if (stack_class == MetaStackClass::Ssy) {
            return {ssy_flow_stack, ssy_flow_stack_top};
        }
        if (stack_class == MetaStackClass::Pbk) {
            return {pbk_flow_stack, pbk_flow_stack_top};
        }
        UNREACHABLE();
        return {};
    }

    /// Returns a pointer to the uint word of the global buffer addressed by `gmem`.
    /// The word index is (real address - base address) / 4.
    Id GetGlobalMemoryPointer(const GmemNode& gmem) {
        const Id real_address = AsUint(Visit(gmem.GetRealAddress()));
        const Id base_address = AsUint(Visit(gmem.GetBaseAddress()));
        const Id byte_offset = OpISub(t_uint, real_address, base_address);
        const Id word_offset = OpShiftRightLogical(t_uint, byte_offset, Constant(t_uint, 2));
        const Id buffer = global_buffers.at(gmem.GetDescriptor());
        return OpAccessChain(t_gmem_uint, buffer, Constant(t_uint, 0), word_offset);
    }

    /// Returns a pointer to the uint word of shared memory addressed by `smem` (byte address
    /// / 4). Shared memory is only expected in compute shaders.
    Id GetSharedMemoryPointer(const SmemNode& smem) {
        ASSERT(stage == ShaderType::Compute);
        const Id byte_address = AsUint(Visit(smem.GetAddress()));
        const Id word_index = OpShiftRightLogical(t_uint, byte_address, Constant(t_uint, 2U));
        return OpAccessChain(t_smem_uint, shared_memory, word_index);
    }

    /// Dispatch table mapping each OperationCode (by index) to its decompiler method.
    /// Entries MUST appear in exactly the same order as OperationCode. The static_assert below
    /// only checks the entry count, not the order.
    /// NOTE(review): AtomicImageMin and AtomicImageMax are defined in this class but not listed
    /// here; presumably OperationCode has no matching entries — confirm.
    static constexpr std::array operation_decompilers = {
        &SPIRVDecompiler::Assign,

        &SPIRVDecompiler::Ternary<&Module::OpSelect, Type::Float, Type::Bool, Type::Float,
                                  Type::Float>,

        // 32-bit float arithmetic
        &SPIRVDecompiler::Binary<&Module::OpFAdd, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFMul, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFDiv, Type::Float>,
        &SPIRVDecompiler::Ternary<&Module::OpFma, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>,
        &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>,
        &SPIRVDecompiler::FCastHalf<0>,
        &SPIRVDecompiler::FCastHalf<1>,
        &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpSin, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpExp2, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpLog2, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpInverseSqrt, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpSqrt, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpRoundEven, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpFloor, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpCeil, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpTrunc, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpConvertSToF, Type::Float, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpConvertUToF, Type::Float, Type::Uint>,
        &SPIRVDecompiler::FSwizzleAdd,

        // Signed integer operations
        &SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpIMul, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSDiv, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpSNegate, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpSAbs, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSMin, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSMax, Type::Int>,

        &SPIRVDecompiler::Unary<&Module::OpConvertFToS, Type::Int, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Int, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Int, Type::Int, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Int, Type::Int, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Int, Type::Int, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpNot, Type::Int>,
        &SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Int>,
        &SPIRVDecompiler::Ternary<&Module::OpBitFieldSExtract, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpFindSMsb, Type::Int>,

        // Unsigned integer operations
        &SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpIMul, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUDiv, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUMin, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUMax, Type::Uint>,
        &SPIRVDecompiler::Unary<&Module::OpConvertFToU, Type::Uint, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>,
        // NOTE(review): unlike the signed section, the third shift slot (the arithmetic-shift
        // position) reuses OpShiftRightLogical; presumably intentional for unsigned operands.
        &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>,
        &SPIRVDecompiler::Unary<&Module::OpNot, Type::Uint>,
        &SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Uint>,
        &SPIRVDecompiler::Ternary<&Module::OpBitFieldUExtract, Type::Uint>,
        &SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Uint>,
        &SPIRVDecompiler::Unary<&Module::OpFindUMsb, Type::Uint>,

        // Half-float (paired) operations
        &SPIRVDecompiler::Binary<&Module::OpFAdd, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFMul, Type::HalfFloat>,
        &SPIRVDecompiler::Ternary<&Module::OpFma, Type::HalfFloat>,
        &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::HalfFloat>,
        &SPIRVDecompiler::HNegate,
        &SPIRVDecompiler::HClamp,
        &SPIRVDecompiler::HCastFloat,
        &SPIRVDecompiler::HUnpack,
        &SPIRVDecompiler::HMergeF32,
        &SPIRVDecompiler::HMergeHN<0>,
        &SPIRVDecompiler::HMergeHN<1>,
        &SPIRVDecompiler::HPack2,

        // Logical (bool / bool2) operations
        &SPIRVDecompiler::LogicalAssign,
        &SPIRVDecompiler::Binary<&Module::OpLogicalAnd, Type::Bool>,
        &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>,
        &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
        &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
        &SPIRVDecompiler::Binary<&Module::OpVectorExtractDynamic, Type::Bool, Type::Bool2,
                                 Type::Uint>,
        &SPIRVDecompiler::Unary<&Module::OpAll, Type::Bool, Type::Bool2>,

        // Float comparisons
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool, Type::Float>,

        // Signed integer comparisons
        &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSLessThanEqual, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSGreaterThan, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSGreaterThanEqual, Type::Bool, Type::Int>,

        // Unsigned integer comparisons
        &SPIRVDecompiler::Binary<&Module::OpULessThan, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpULessThanEqual, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUGreaterThan, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>,

        // Half-float comparisons producing bool2
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
        // TODO(Rodrigo): Should these use the OpFUnord* variants?
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,

        // Texture sampling and queries
        &SPIRVDecompiler::Texture,
        &SPIRVDecompiler::TextureLod,
        &SPIRVDecompiler::TextureGather,
        &SPIRVDecompiler::TextureQueryDimensions,
        &SPIRVDecompiler::TextureQueryLod,
        &SPIRVDecompiler::TexelFetch,
        &SPIRVDecompiler::TextureGradient,

        // Storage images and image atomics
        &SPIRVDecompiler::ImageLoad,
        &SPIRVDecompiler::ImageStore,
        &SPIRVDecompiler::AtomicImageAdd,
        &SPIRVDecompiler::AtomicImageAnd,
        &SPIRVDecompiler::AtomicImageOr,
        &SPIRVDecompiler::AtomicImageXor,
        &SPIRVDecompiler::AtomicImageExchange,

        // Memory atomics, unsigned
        &SPIRVDecompiler::Atomic<&Module::OpAtomicExchange, Type::Uint>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicIAdd, Type::Uint>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicUMin, Type::Uint>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicUMax, Type::Uint>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicAnd, Type::Uint>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicOr, Type::Uint>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicXor, Type::Uint>,

        // Memory atomics, signed
        &SPIRVDecompiler::Atomic<&Module::OpAtomicExchange, Type::Int>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicIAdd, Type::Int>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicSMin, Type::Int>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicSMax, Type::Int>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicAnd, Type::Int>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicOr, Type::Int>,
        &SPIRVDecompiler::Atomic<&Module::OpAtomicXor, Type::Int>,

        // Control flow
        &SPIRVDecompiler::Branch,
        &SPIRVDecompiler::BranchIndirect,
        &SPIRVDecompiler::PushFlowStack,
        &SPIRVDecompiler::PopFlowStack,
        &SPIRVDecompiler::Exit,
        &SPIRVDecompiler::Discard,

        // Geometry shader output
        &SPIRVDecompiler::EmitVertex,
        &SPIRVDecompiler::EndPrimitive,

        // Built-in inputs
        &SPIRVDecompiler::InvocationId,
        &SPIRVDecompiler::YNegate,
        &SPIRVDecompiler::LocalInvocationId<0>,
        &SPIRVDecompiler::LocalInvocationId<1>,
        &SPIRVDecompiler::LocalInvocationId<2>,
        &SPIRVDecompiler::WorkGroupId<0>,
        &SPIRVDecompiler::WorkGroupId<1>,
        &SPIRVDecompiler::WorkGroupId<2>,

        // Subgroup (warp) operations
        &SPIRVDecompiler::BallotThread,
        &SPIRVDecompiler::Vote<&Module::OpSubgroupAllKHR>,
        &SPIRVDecompiler::Vote<&Module::OpSubgroupAnyKHR>,
        &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,

        &SPIRVDecompiler::ThreadId,
        &SPIRVDecompiler::ShuffleIndexed,

        // Barriers
        &SPIRVDecompiler::MemoryBarrierGL,
    };
    static_assert(operation_decompilers.size() == static_cast<std::size_t>(OperationCode::Amount));

    const VKDevice& device;
    const ShaderIR& ir;
    const ShaderType stage;
    const Tegra::Shader::Header header;
    const Registry& registry;
    const Specialization& specialization;
    std::unordered_map<u8, VaryingTFB> transform_feedback;

    const Id t_void = Name(TypeVoid(), "void");

    const Id t_bool = Name(TypeBool(), "bool");
    const Id t_bool2 = Name(TypeVector(t_bool, 2), "bool2");

    const Id t_int = Name(TypeInt(32, true), "int");
    const Id t_int2 = Name(TypeVector(t_int, 2), "int2");
    const Id t_int3 = Name(TypeVector(t_int, 3), "int3");
    const Id t_int4 = Name(TypeVector(t_int, 4), "int4");

    const Id t_uint = Name(TypeInt(32, false), "uint");
    const Id t_uint2 = Name(TypeVector(t_uint, 2), "uint2");
    const Id t_uint3 = Name(TypeVector(t_uint, 3), "uint3");
    const Id t_uint4 = Name(TypeVector(t_uint, 4), "uint4");

    const Id t_float = Name(TypeFloat(32), "float");
    const Id t_float2 = Name(TypeVector(t_float, 2), "float2");
    const Id t_float3 = Name(TypeVector(t_float, 3), "float3");
    const Id t_float4 = Name(TypeVector(t_float, 4), "float4");

    const Id t_prv_bool = Name(TypePointer(spv::StorageClass::Private, t_bool), "prv_bool");
    const Id t_prv_float = Name(TypePointer(spv::StorageClass::Private, t_float), "prv_float");

    const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint");

    const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool");
    const Id t_in_int = Name(TypePointer(spv::StorageClass::Input, t_int), "in_int");
    const Id t_in_int4 = Name(TypePointer(spv::StorageClass::Input, t_int4), "in_int4");
    const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint");
    const Id t_in_uint3 = Name(TypePointer(spv::StorageClass::Input, t_uint3), "in_uint3");
    const Id t_in_uint4 = Name(TypePointer(spv::StorageClass::Input, t_uint4), "in_uint4");
    const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float");
    const Id t_in_float2 = Name(TypePointer(spv::StorageClass::Input, t_float2), "in_float2");
    const Id t_in_float3 = Name(TypePointer(spv::StorageClass::Input, t_float3), "in_float3");
    const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4");

    const Id t_out_int = Name(TypePointer(spv::StorageClass::Output, t_int), "out_int");

    const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float");
    const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4");

    const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float);
    const Id t_cbuf_std140 = Decorate(
        Name(TypeArray(t_float4, Constant(t_uint, MaxConstBufferElements)), "CbufStd140Array"),
        spv::Decoration::ArrayStride, 16U);
    const Id t_cbuf_scalar = Decorate(
        Name(TypeArray(t_float, Constant(t_uint, MaxConstBufferFloats)), "CbufScalarArray"),
        spv::Decoration::ArrayStride, 4U);
    const Id t_cbuf_std140_struct = MemberDecorate(
        Decorate(TypeStruct(t_cbuf_std140), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
    const Id t_cbuf_scalar_struct = MemberDecorate(
        Decorate(TypeStruct(t_cbuf_scalar), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
    const Id t_cbuf_std140_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_std140_struct);
    const Id t_cbuf_scalar_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_scalar_struct);

    Id t_smem_uint{};

    const Id t_gmem_uint = TypePointer(spv::StorageClass::StorageBuffer, t_uint);
    const Id t_gmem_array =
        Name(Decorate(TypeRuntimeArray(t_uint), spv::Decoration::ArrayStride, 4U), "GmemArray");
    const Id t_gmem_struct = MemberDecorate(
        Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
    const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct);

    const Id v_float_zero = Constant(t_float, 0.0f);
    const Id v_float_one = Constant(t_float, 1.0f);

    // Nvidia uses these defaults for varyings (e.g. position and generic attributes)
    const Id v_varying_default =
        ConstantComposite(t_float4, v_float_zero, v_float_zero, v_float_zero, v_float_one);

    const Id v_true = ConstantTrue(t_bool);
    const Id v_false = ConstantFalse(t_bool);

    Id t_scalar_half{};
    Id t_half{};

    Id out_vertex{};
    Id in_vertex{};
    std::map<u32, Id> registers;
    std::map<u32, Id> custom_variables;
    std::map<Tegra::Shader::Pred, Id> predicates;
    std::map<u32, Id> flow_variables;
    Id local_memory{};
    Id shared_memory{};
    std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
    std::map<Attribute::Index, Id> input_attributes;
    std::unordered_map<u8, GenericVaryingDescription> output_attributes;
    std::map<u32, Id> constant_buffers;
    std::map<GlobalMemoryBase, Id> global_buffers;
    std::map<u32, TexelBuffer> texel_buffers;
    std::map<u32, SampledImage> sampled_images;
    std::map<u32, StorageImage> images;

    Id instance_index{};
    Id vertex_index{};
    Id base_instance{};
    Id base_vertex{};
    std::array<Id, Maxwell::NumRenderTargets> frag_colors{};
    Id frag_depth{};
    Id frag_coord{};
    Id front_facing{};
    Id point_coord{};
    Id tess_level_outer{};
    Id tess_level_inner{};
    Id tess_coord{};
    Id invocation_id{};
    Id workgroup_id{};
    Id local_invocation_id{};
    Id thread_id{};

    VertexIndices in_indices;
    VertexIndices out_indices;

    std::vector<Id> interfaces;

    Id jmp_to{};
    Id ssy_flow_stack_top{};
    Id pbk_flow_stack_top{};
    Id ssy_flow_stack{};
    Id pbk_flow_stack{};
    Id continue_label{};
    std::map<u32, Id> labels;

    bool conditional_branch_set{};
    bool inside_branch{};
};

/// Lowers a control-flow expression tree (Expr) into a SPIR-V boolean value using the
/// owning SPIRVDecompiler. Instances are used as std::visit visitors over Expr nodes.
/// Sub-expressions are always emitted left operand first; named locals are used on purpose
/// so emission order does not depend on unspecified argument evaluation order.
class ExprDecompiler {
public:
    explicit ExprDecompiler(SPIRVDecompiler& decompiler) : decomp{decompiler} {}

    /// operand1 && operand2
    Id operator()(const ExprAnd& expr) {
        const Id bool_type = decomp.GetTypeDefinition(Type::Bool);
        const Id lhs = Visit(expr.operand1);
        const Id rhs = Visit(expr.operand2);
        return decomp.OpLogicalAnd(bool_type, lhs, rhs);
    }

    /// operand1 || operand2
    Id operator()(const ExprOr& expr) {
        const Id bool_type = decomp.GetTypeDefinition(Type::Bool);
        const Id lhs = Visit(expr.operand1);
        const Id rhs = Visit(expr.operand2);
        return decomp.OpLogicalOr(bool_type, lhs, rhs);
    }

    /// !operand1
    Id operator()(const ExprNot& expr) {
        const Id bool_type = decomp.GetTypeDefinition(Type::Bool);
        const Id value = Visit(expr.operand1);
        return decomp.OpLogicalNot(bool_type, value);
    }

    /// Loads the current value of a guest predicate register.
    Id operator()(const ExprPredicate& expr) {
        const auto predicate = static_cast<Tegra::Shader::Pred>(expr.predicate);
        const Id pointer = decomp.predicates.at(predicate);
        return decomp.OpLoad(decomp.t_bool, pointer);
    }

    /// Evaluates the IR node backing a condition code and converts it to bool.
    Id operator()(const ExprCondCode& expr) {
        const auto& cc_node = decomp.ir.GetConditionCode(expr.cc);
        return decomp.AsBool(decomp.Visit(cc_node));
    }

    /// Loads an AST flow variable.
    Id operator()(const ExprVar& expr) {
        const Id pointer = decomp.flow_variables.at(expr.var_index);
        return decomp.OpLoad(decomp.t_bool, pointer);
    }

    /// Boolean literal.
    Id operator()(const ExprBoolean& expr) {
        if (expr.value) {
            return decomp.v_true;
        }
        return decomp.v_false;
    }

    /// Compares the raw bits of a general-purpose register against an unsigned constant.
    /// Registers are stored as float, so the loaded value is bitcast to uint first.
    Id operator()(const ExprGprEqual& expr) {
        const Id expected = decomp.Constant(decomp.t_uint, expr.value);
        const Id as_float = decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr));
        const Id as_uint = decomp.OpBitcast(decomp.t_uint, as_float);
        return decomp.OpIEqual(decomp.t_bool, as_uint, expected);
    }

    /// Dispatches on the concrete expression kind held by `node`.
    Id Visit(const Expr& node) {
        return std::visit(*this, *node);
    }

private:
    SPIRVDecompiler& decomp;
};

/// Emits SPIR-V for the structured control-flow AST produced by the shader IR.
/// Instances are used as std::visit visitors over ASTNode inner data. The only state kept
/// between visits is the merge label of the innermost loop, which ASTBreak branches to.
class ASTDecompiler {
public:
    explicit ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}

    /// Visits every top-level node of the program in order.
    void operator()(const ASTProgram& ast) {
        ASTNode current = ast.nodes.GetFirst();
        while (current) {
            Visit(current);
            current = current->GetNext();
        }
    }

    /// Emits `if (condition) { nodes }` as a selection construct that merges at endif_label.
    void operator()(const ASTIfThen& ast) {
        ExprDecompiler expr_parser{decomp};
        const Id condition = expr_parser.Visit(ast.condition);
        const Id then_label = decomp.OpLabel();
        const Id endif_label = decomp.OpLabel();
        decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
        decomp.OpBranchConditional(condition, then_label, endif_label);
        decomp.AddLabel(then_label);
        ASTNode current = ast.nodes.GetFirst();
        while (current) {
            Visit(current);
            current = current->GetNext();
        }
        decomp.OpBranch(endif_label);
        decomp.AddLabel(endif_label);
    }

    // If/else nodes are not expected to reach SPIR-V emission.
    void operator()([[maybe_unused]] const ASTIfElse& ast) {
        UNREACHABLE();
    }

    // Still-encoded blocks are not expected to reach SPIR-V emission.
    void operator()([[maybe_unused]] const ASTBlockEncoded& ast) {
        UNREACHABLE();
    }

    /// Emits the instructions of a decoded basic block.
    void operator()(const ASTBlockDecoded& ast) {
        decomp.VisitBasicBlock(ast.nodes);
    }

    /// Stores the evaluated condition into the flow variable `ast.index`.
    void operator()(const ASTVarSet& ast) {
        ExprDecompiler expr_parser{decomp};
        const Id condition = expr_parser.Visit(ast.condition);
        decomp.OpStore(decomp.flow_variables.at(ast.index), condition);
    }

    void operator()([[maybe_unused]] const ASTLabel& ast) {
        // Do nothing
    }

    // Gotos are not expected to reach SPIR-V emission.
    void operator()([[maybe_unused]] const ASTGoto& ast) {
        UNREACHABLE();
    }

    /// Emits `do { nodes } while (condition)` as a structured loop:
    ///   loop_label (header, OpLoopMerge) -> loop_start_block (body)
    ///   -> loop_continue_block (evaluates condition, branches back to the header or to
    ///   endloop_label, the merge block).
    void operator()(const ASTDoWhile& ast) {
        const Id loop_label = decomp.OpLabel();
        const Id endloop_label = decomp.OpLabel();
        const Id loop_start_block = decomp.OpLabel();
        const Id loop_continue_block = decomp.OpLabel();

        // Breaks inside this body must target this loop's merge block. Remember the
        // enclosing loop's exit and restore it after the body. Otherwise a break that follows
        // this (nested) loop inside an outer loop would branch to this loop's merge block.
        const Id enclosing_loop_exit = current_loop_exit;
        current_loop_exit = endloop_label;

        decomp.OpBranch(loop_label);
        decomp.AddLabel(loop_label);
        decomp.OpLoopMerge(endloop_label, loop_continue_block, spv::LoopControlMask::MaskNone);
        decomp.OpBranch(loop_start_block);
        decomp.AddLabel(loop_start_block);
        ASTNode current = ast.nodes.GetFirst();
        while (current) {
            Visit(current);
            current = current->GetNext();
        }
        current_loop_exit = enclosing_loop_exit;

        decomp.OpBranch(loop_continue_block);
        decomp.AddLabel(loop_continue_block);
        ExprDecompiler expr_parser{decomp};
        const Id condition = expr_parser.Visit(ast.condition);
        decomp.OpBranchConditional(condition, loop_label, endloop_label);
        decomp.AddLabel(endloop_label);
    }

    /// Emits a (possibly conditional) return, or a discard when `ast.kills` is set.
    void operator()(const ASTReturn& ast) {
        if (!VideoCommon::Shader::ExprIsTrue(ast.condition)) {
            ExprDecompiler expr_parser{decomp};
            const Id condition = expr_parser.Visit(ast.condition);
            const Id then_label = decomp.OpLabel();
            const Id endif_label = decomp.OpLabel();
            decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
            decomp.OpBranchConditional(condition, then_label, endif_label);
            decomp.AddLabel(then_label);
            EmitKillOrReturn(ast.kills);
            decomp.AddLabel(endif_label);
        } else {
            // Unconditional: put the terminator in its own block.
            const Id next_block = decomp.OpLabel();
            decomp.OpBranch(next_block);
            decomp.AddLabel(next_block);
            EmitKillOrReturn(ast.kills);
            // Open a fresh, unreferenced block so that any instructions emitted after the
            // terminator still belong to a block.
            decomp.AddLabel(decomp.OpLabel());
        }
    }

    /// Emits a (possibly conditional) branch to the innermost loop's merge block.
    /// NOTE(review): current_loop_exit is an empty Id if no loop encloses the break;
    /// presumably the AST never produces that — confirm.
    void operator()(const ASTBreak& ast) {
        if (!VideoCommon::Shader::ExprIsTrue(ast.condition)) {
            ExprDecompiler expr_parser{decomp};
            const Id condition = expr_parser.Visit(ast.condition);
            const Id then_label = decomp.OpLabel();
            const Id endif_label = decomp.OpLabel();
            decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
            decomp.OpBranchConditional(condition, then_label, endif_label);
            decomp.AddLabel(then_label);
            decomp.OpBranch(current_loop_exit);
            decomp.AddLabel(endif_label);
        } else {
            const Id next_block = decomp.OpLabel();
            decomp.OpBranch(next_block);
            decomp.AddLabel(next_block);
            decomp.OpBranch(current_loop_exit);
            // Fresh block for whatever follows the unconditional break.
            decomp.AddLabel(decomp.OpLabel());
        }
    }

    /// Dispatches on the concrete kind of `node`. `*this` is passed by reference, so
    /// loop-exit state persists across nested visits.
    void Visit(const ASTNode& node) {
        std::visit(*this, *node->GetInnerData());
    }

private:
    /// Terminates the current block: OpKill when `kills` is set, otherwise calls PreExit()
    /// and then emits OpReturn.
    void EmitKillOrReturn(bool kills) {
        if (kills) {
            decomp.OpKill();
        } else {
            decomp.PreExit();
            decomp.OpReturn();
        }
    }

    SPIRVDecompiler& decomp;
    /// Merge label of the innermost loop currently being emitted; the target of ASTBreak.
    Id current_loop_exit{};
};

void SPIRVDecompiler::DecompileAST() {
    const u32 num_flow_variables = ir.GetASTNumVariables();
    for (u32 i = 0; i < num_flow_variables; i++) {
        const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
        Name(id, fmt::format("flow_var_{}", i));
        flow_variables.emplace(i, AddGlobalVariable(id));
    }

    DefinePrologue();

    const ASTNode program = ir.GetASTProgram();
    ASTDecompiler decompiler{*this};
    decompiler.Visit(program);

    const Id next_block = OpLabel();
    OpBranch(next_block);
    AddLabel(next_block);
}

} // Anonymous namespace

/// Collects the resource usage recorded in `ir` into a ShaderEntries description:
/// constant buffers, global memory buffers, samplers (split into texel buffers and regular
/// samplers), images, generic input attribute locations, clip distances, shader length and
/// warp usage.
ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir) {
    ShaderEntries entries;

    for (const auto& [index, cbuf] : ir.GetConstantBuffers()) {
        entries.const_buffers.emplace_back(cbuf, index);
    }

    for (const auto& [base, usage] : ir.GetGlobalMemory()) {
        entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset, usage.is_written);
    }

    // Buffer samplers are reported separately as texel buffers.
    for (const auto& sampler : ir.GetSamplers()) {
        if (!sampler.IsBuffer()) {
            entries.samplers.emplace_back(sampler);
            continue;
        }
        entries.texel_buffers.emplace_back(sampler);
    }

    for (const auto& image : ir.GetImages()) {
        entries.images.emplace_back(image);
    }

    // Only generic attributes contribute a location.
    for (const auto& attribute : ir.GetInputAttributes()) {
        if (!IsGenericAttribute(attribute)) {
            continue;
        }
        entries.attributes.insert(GetGenericAttributeLocation(attribute));
    }

    entries.clip_distances = ir.GetClipDistances();
    entries.shader_length = ir.GetLength();
    entries.uses_warps = ir.UsesWarps();
    return entries;
}

/// Translates `ir` for the given pipeline `stage` into a SPIR-V module.
/// Returns the assembled module as a vector of 32-bit words.
std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
                           ShaderType stage, const VideoCommon::Shader::Registry& registry,
                           const Specialization& specialization) {
    SPIRVDecompiler decompiler(device, ir, stage, registry, specialization);
    return decompiler.Assemble();
}

} // namespace Vulkan