blob: 0f894945621364075ec09f2cf282a225d5d616d3 [file] [log] [blame]
/* Copyright (c) 2024-2026 LunarG, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pass.h"
#include <vulkan/vulkan_core.h>
#include <cstdint>
#include <spirv/unified1/spirv.hpp>
#include "cooperative_matrix.h"
#include "function_basic_block.h"
#include "generated/spirv_grammar_helper.h"
#include "link.h"
#include "state_tracker/shader_instruction.h"
#include "module.h"
#include "gpuav/shaders/gpuav_error_codes.h"
#include "type_manager.h"
namespace gpuav {
namespace spirv {
// Each pass operates on a single Module; |offline| describes the pre-compiled
// instrumentation functions this pass may later need to have linked in.
Pass::Pass(Module& module, const OfflineModule& offline)
    : module_(module), type_manager_(module_.type_manager_), link_info_(offline) {}
// Executes the pass and reports whether the module was modified by it.
bool Pass::Run() {
    const bool was_modified = Instrument();

    // Debug printing is requested via module settings and happens regardless of modification
    if (module_.settings_.print_debug_info) {
        PrintDebugInfo();
    }

    // Only register link info when the pass actually changed something and also
    // pulled in offline functions that still need to be linked into the module.
    const bool needs_linking = was_modified && !link_info_.functions.empty();
    if (needs_linking) {
        module_.link_infos_.emplace_back(link_info_);
    }

    return was_modified;
}
// Returns the Input-storage-class variable decorated with |built_in|, creating the
// OpDecorate and/or OpVariable instructions if the module does not already contain them,
// and making sure the variable appears in the target entry point's interface list.
const Variable& Pass::GetBuiltInVariable(uint32_t built_in) {
    // First search existing decorations for a variable already tagged with this BuiltIn
    uint32_t variable_id = 0;
    for (const auto& annotation : module_.annotations_) {
        if (annotation->Opcode() == spv::OpDecorate && annotation->Word(2) == spv::DecorationBuiltIn &&
            annotation->Word(3) == built_in) {
            variable_id = annotation->Word(1);
            break;
        }
    }

    if (variable_id == 0) {
        // No decoration found: reserve a fresh id and decorate it now; the matching
        // OpVariable for that id is created in the !built_in_variable branch below.
        variable_id = module_.TakeNextId();
        auto new_inst = std::make_unique<Instruction>(4, spv::OpDecorate);
        new_inst->Fill({variable_id, spv::DecorationBuiltIn, built_in});
        module_.annotations_.emplace_back(std::move(new_inst));
    }

    // Currently we only ever needed Input variables and the built-ins we are using are not those that can be used by both Input and
    // Output storage classes
    const Variable* built_in_variable = type_manager_.FindVariableById(variable_id);
    if (!built_in_variable) {
        // Create the OpVariable (Input storage class) and add it to the entry point interface
        const Type& pointer_type = type_manager_.GetTypePointerBuiltInInput(spv::BuiltIn(built_in));
        auto new_inst = std::make_unique<Instruction>(4, spv::OpVariable);
        new_inst->Fill({pointer_type.Id(), variable_id, spv::StorageClassInput});
        built_in_variable = &type_manager_.AddVariable(std::move(new_inst), pointer_type);
        module_.AddInterfaceVariables(built_in_variable->Id(), spv::StorageClassInput);
    } else {
        // Slang with the --preserve-params option will leave built-in variables that aren't in any interface.
        // Scan the target entry point's interface operands and add the variable if it is missing.
        const uint32_t built_in_variable_id = built_in_variable->Id();
        const Instruction* entry_point = module_.GetTargetEntryPoint();
        bool found_variable = false;
        uint32_t word = entry_point->GetEntryPointInterfaceStart();
        const uint32_t total_words = entry_point->Length();
        for (; word < total_words; word++) {
            const uint32_t interface_id = entry_point->Word(word);
            if (interface_id == built_in_variable_id) {
                found_variable = true;
                break;
            }
        }
        if (!found_variable) {
            module_.AddInterfaceVariables(variable_id, spv::StorageClassInput);
        }
    }

    return *built_in_variable;
}
// Special function to map to the internal representation of the execution models used for GenerateStageMessage()
static uint32_t GetNormalizedExecutionModel(VkShaderStageFlagBits shader_stage) {
    // Default to Unknown; only stages we explicitly recognize get a real model.
    uint32_t execution_model = glsl::kExecutionModel_Unknown;
    switch (shader_stage) {
        case VK_SHADER_STAGE_VERTEX_BIT:
            execution_model = glsl::kExecutionModel_Vertex;
            break;
        case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT:
            execution_model = glsl::kExecutionModel_TessellationControl;
            break;
        case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT:
            execution_model = glsl::kExecutionModel_TessellationEvaluation;
            break;
        case VK_SHADER_STAGE_GEOMETRY_BIT:
            execution_model = glsl::kExecutionModel_Geometry;
            break;
        case VK_SHADER_STAGE_FRAGMENT_BIT:
            execution_model = glsl::kExecutionModel_Fragment;
            break;
        case VK_SHADER_STAGE_COMPUTE_BIT:
            execution_model = glsl::kExecutionModel_GLCompute;
            break;
        case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
            execution_model = glsl::kExecutionModel_RayGenerationKHR;
            break;
        case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
            execution_model = glsl::kExecutionModel_AnyHitKHR;
            break;
        case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
            execution_model = glsl::kExecutionModel_ClosestHitKHR;
            break;
        case VK_SHADER_STAGE_MISS_BIT_KHR:
            execution_model = glsl::kExecutionModel_MissKHR;
            break;
        case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
            execution_model = glsl::kExecutionModel_IntersectionKHR;
            break;
        case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
            execution_model = glsl::kExecutionModel_CallableKHR;
            break;
        case VK_SHADER_STAGE_TASK_BIT_EXT:
            execution_model = glsl::kExecutionModel_TaskEXT;
            break;
        case VK_SHADER_STAGE_MESH_BIT_EXT:
            execution_model = glsl::kExecutionModel_MeshEXT;
            break;
        // Stages we have no internal representation for fall through to Unknown
        case VK_SHADER_STAGE_SUBPASS_SHADING_BIT_HUAWEI:
        case VK_SHADER_STAGE_CLUSTER_CULLING_BIT_HUAWEI:
        case VK_SHADER_STAGE_ALL_GRAPHICS:
        case VK_SHADER_STAGE_ALL:
            break;
    }
    return execution_model;
}
// To reduce having to load this information everytime we do a OpFunctionCall, instead just create it once per Function block and
// reference it each time
uint32_t Pass::GetStageInfo(Function& function, const BasicBlock& target_block_it, InstructionIt& out_inst_it) {
    // Cached so only need to compute this once
    if (function.stage_info_id_ != 0) {
        return function.stage_info_id_;
    }

    // Save original for later to restore
    const Instruction& target_instruction = *out_inst_it->get();

    // All the loads/casts below are injected at the top of the function's first block
    BasicBlock& block = function.GetFirstBlock();
    InstructionIt inst_it = block.GetFirstInjectableInstrution();

    // Stage info is always passed in as a uvec4
    const Type& uint32_type = type_manager_.GetTypeInt(32, false);
    const Type& uvec4_type = type_manager_.GetTypeVector(uint32_type, 4);

    // Slots default to zero; slot 0 is always the normalized execution model,
    // slots 1-3 are filled per-stage below
    const uint32_t uint32_0_id = type_manager_.GetConstantZeroUint32().Id();
    uint32_t stage_info[4] = {uint32_0_id, uint32_0_id, uint32_0_id, uint32_0_id};

    const uint32_t execution_model = GetNormalizedExecutionModel(module_.interface_.entry_point_stage);
    stage_info[0] = type_manager_.GetConstantUInt32(execution_model).Id();

    // Gets BuiltIn variable and creates a valid OpLoad of it
    auto create_load = [this, &block, &inst_it](spv::BuiltIn built_in) {
        const Variable& variable = GetBuiltInVariable(built_in);
        const Type* pointer_type = variable.PointerType(type_manager_);
        const uint32_t load_id = module_.TakeNextId();
        block.CreateInstruction(spv::OpLoad, {pointer_type->Id(), load_id, variable.Id()}, &inst_it);
        return load_id;
    };

    switch (module_.interface_.entry_point_stage) {
        case VK_SHADER_STAGE_VERTEX_BIT: {
            // slot 1 = VertexIndex, slot 2 = InstanceIndex (both cast to uint if signed)
            uint32_t load_id = create_load(spv::BuiltInVertexIndex);
            stage_info[1] = CastToUint32(load_id, block, &inst_it);

            load_id = create_load(spv::BuiltInInstanceIndex);
            stage_info[2] = CastToUint32(load_id, block, &inst_it);
        } break;
        case VK_SHADER_STAGE_FRAGMENT_BIT: {
            // slots 1-2 = FragCoord.xy (bitcast from float)
            const uint32_t load_id = create_load(spv::BuiltInFragCoord);

            // convert vec4 to uvec4
            const uint32_t bitcast_id = module_.TakeNextId();
            block.CreateInstruction(spv::OpBitcast, {uvec4_type.Id(), bitcast_id, load_id}, &inst_it);

            for (uint32_t i = 0; i < 2; i++) {
                const uint32_t extract_id = module_.TakeNextId();
                block.CreateInstruction(spv::OpCompositeExtract, {uint32_type.Id(), extract_id, bitcast_id, i}, &inst_it);
                stage_info[i + 1] = extract_id;
            }
        } break;
        case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
        case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
        case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
        case VK_SHADER_STAGE_MISS_BIT_KHR:
        case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
        case VK_SHADER_STAGE_CALLABLE_BIT_KHR: {
            // slots 1-3 = LaunchId.xyz
            const uint32_t load_id = create_load(spv::BuiltInLaunchIdKHR);

            for (uint32_t i = 0; i < 3; i++) {
                const uint32_t extract_id = module_.TakeNextId();
                block.CreateInstruction(spv::OpCompositeExtract, {uint32_type.Id(), extract_id, load_id, i}, &inst_it);
                stage_info[i + 1] = extract_id;
            }
        } break;
        case VK_SHADER_STAGE_COMPUTE_BIT:
        case VK_SHADER_STAGE_TASK_BIT_NV:
        case VK_SHADER_STAGE_MESH_BIT_NV: {
            // slots 1-3 = GlobalInvocationId.xyz
            // This can be both a uvec3 or ivec3 so need to cast if ivec3
            const Variable& variable = GetBuiltInVariable(spv::BuiltInGlobalInvocationId);
            const Type* pointer_type = variable.PointerType(type_manager_);
            const uint32_t load_id = module_.TakeNextId();
            block.CreateInstruction(spv::OpLoad, {pointer_type->Id(), load_id, variable.Id()}, &inst_it);

            uint32_t final_load_id = load_id;
            if (pointer_type->IsIVec3(type_manager_)) {
                // note: vec3_type here is built from uint32 components (a uvec3)
                const Type& vec3_type = type_manager_.GetTypeVector(uint32_type, 3);
                final_load_id = module_.TakeNextId();
                block.CreateInstruction(spv::OpBitcast, {vec3_type.Id(), final_load_id, load_id}, &inst_it);
            }

            for (uint32_t i = 0; i < 3; i++) {
                const uint32_t extract_id = module_.TakeNextId();
                block.CreateInstruction(spv::OpCompositeExtract, {uint32_type.Id(), extract_id, final_load_id, i}, &inst_it);
                stage_info[i + 1] = extract_id;
            }
        } break;
        case VK_SHADER_STAGE_GEOMETRY_BIT: {
            // slot 1 = PrimitiveId, slot 2 = InvocationId
            const uint32_t primitive_id = create_load(spv::BuiltInPrimitiveId);
            stage_info[1] = CastToUint32(primitive_id, block, &inst_it);

            const uint32_t load_id = create_load(spv::BuiltInInvocationId);
            stage_info[2] = CastToUint32(load_id, block, &inst_it);
        } break;
        case VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT: {
            // slot 1 = InvocationId, slot 2 = PrimitiveId
            const uint32_t load_id = create_load(spv::BuiltInInvocationId);
            stage_info[1] = CastToUint32(load_id, block, &inst_it);

            const uint32_t primitive_id = create_load(spv::BuiltInPrimitiveId);
            stage_info[2] = CastToUint32(primitive_id, block, &inst_it);
        } break;
        case VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT: {
            // slot 1 = PrimitiveId, slots 2-3 = TessCoord.uv (bitcast from float)
            const uint32_t primitive_id = create_load(spv::BuiltInPrimitiveId);
            stage_info[1] = CastToUint32(primitive_id, block, &inst_it);

            // convert vec3 to uvec3
            const Type& uvec3_type = type_manager_.GetTypeVector(uint32_type, 3);
            const uint32_t load_id = create_load(spv::BuiltInTessCoord);
            const uint32_t bitcast_id = module_.TakeNextId();
            block.CreateInstruction(spv::OpBitcast, {uvec3_type.Id(), bitcast_id, load_id}, &inst_it);

            // TessCoord.uv values from it
            for (uint32_t i = 0; i < 2; i++) {
                const uint32_t extract_id = module_.TakeNextId();
                block.CreateInstruction(spv::OpCompositeExtract, {uint32_type.Id(), extract_id, bitcast_id, i}, &inst_it);
                stage_info[i + 2] = extract_id;
            }
        } break;
        default:
            module_.InternalError(Name(), "GetStageInfo has unsupported stage");
            break;
    }

    // Construct the final uvec4 and cache the piece ids on the Function
    function.stage_info_id_ = module_.TakeNextId();
    block.CreateInstruction(spv::OpCompositeConstruct,
                            {uvec4_type.Id(), function.stage_info_id_, stage_info[0], stage_info[1], stage_info[2], stage_info[3]},
                            &inst_it);
    function.stage_info_x_id_ = stage_info[0];
    function.stage_info_y_id_ = stage_info[1];
    function.stage_info_z_id_ = stage_info[2];
    function.stage_info_w_id_ = stage_info[3];

    // because we are injecting things in the first block, there is a chance we just destroyed the iterator if the target
    // instruction was also in the first block, so need to regain it for the caller
    if (target_block_it.GetLabelId() == block.GetLabelId()) {
        out_inst_it = FindTargetInstruction(block, target_instruction);
    }

    return function.stage_info_id_;
}
// Returns the OpDecorate instruction that applies |decoration| to |id|, or nullptr if none exists.
const Instruction* Pass::GetDecoration(uint32_t id, spv::Decoration decoration) const {
    // OpDecorate layout: Word(1) = target id, Word(2) = decoration
    const auto is_match = [id, decoration](const Instruction& annotation) {
        return annotation.Opcode() == spv::OpDecorate && annotation.Word(1) == id &&
               spv::Decoration(annotation.Word(2)) == decoration;
    };
    for (const auto& annotation : module_.annotations_) {
        if (is_match(*annotation)) {
            return annotation.get();
        }
    }
    return nullptr;
}
// Returns the OpMemberDecorate instruction that applies |decoration| to member |member_index|
// of struct |id|, or nullptr if none exists.
const Instruction* Pass::GetMemberDecoration(uint32_t id, uint32_t member_index, spv::Decoration decoration) const {
    // OpMemberDecorate layout: Word(1) = struct id, Word(2) = member index, Word(3) = decoration
    const auto is_match = [id, member_index, decoration](const Instruction& annotation) {
        return annotation.Opcode() == spv::OpMemberDecorate && annotation.Word(1) == id &&
               annotation.Word(2) == member_index && spv::Decoration(annotation.Word(3)) == decoration;
    };
    for (const auto& annotation : module_.annotations_) {
        if (is_match(*annotation)) {
            return annotation.get();
        }
    }
    return nullptr;
}
// In an ideal world, this would be baked into the Type class when we construct it. The core issue is OpTypeMatrix size can be
// different depending where it is used. Because of this, we need to have a higher level view what is going on in order to correctly
// figure out the size of a given type.
//
// |matrix_stride|/|col_major|/|in_matrix| carry the layout decorations from the enclosing
// struct, since matrix layout lives on OpMemberDecorate, not on the matrix type itself.
uint32_t Pass::FindTypeByteSize(uint32_t type_id, uint32_t matrix_stride, bool col_major, bool in_matrix) const {
    const Type& type = *type_manager_.FindTypeById(type_id);
    switch (type.spv_type_) {
        case SpvType::kPointer:
            return 8;  // Assuming PhysicalStorageBuffer pointer
        case SpvType::kMatrix: {
            if (matrix_stride == 0) {
                module_.InternalError("FindTypeByteSize", "missing matrix stride");
            }
            if (col_major) {
                // column count * stride between columns
                return type.inst_.Word(3) * matrix_stride;
            } else {
                // row-major: row count (the column type's component count) * stride between rows
                const Type* vector_type = type_manager_.FindTypeById(type.inst_.Word(2));
                return vector_type->inst_.Word(3) * matrix_stride;
            }
        }
        case SpvType::kVector: {
            uint32_t size = type.inst_.Word(3);
            const Type* component_type = type_manager_.FindTypeById(type.inst_.Word(2));
            // if vector in row major matrix, the vector is strided so return the number of bytes spanned by the vector
            if (in_matrix && !col_major && matrix_stride > 0) {
                return (size - 1) * matrix_stride + FindTypeByteSize(component_type->Id());
            } else if (component_type->spv_type_ == SpvType::kFloat || component_type->spv_type_ == SpvType::kInt) {
                const uint32_t width = component_type->inst_.Word(2);
                size *= width;  // size is now in bits
            } else {
                module_.InternalError("FindTypeByteSize", "unexpected vector type");
            }
            return size / 8;  // bits to bytes
        }
        case SpvType::kFloat:
        case SpvType::kInt: {
            // Word(2) is the bit width for both OpTypeFloat and OpTypeInt
            const uint32_t width = type.inst_.Word(2);
            return width / 8;
        }
        case SpvType::kArray: {
            const uint32_t array_stride = GetDecoration(type_id, spv::DecorationArrayStride)->Word(3);
            const Constant* count = type_manager_.FindConstantById(type.inst_.Operand(1));
            // TODO - Need to handle spec constant here, for now return one to have things not blowup
            assert(count && !count->is_spec_constant_);
            const uint32_t array_length = (count && !count->is_spec_constant_) ? count->inst_.Operand(0) : 1;
            return array_length * array_stride;
        }
        case SpvType::kStruct: {
            const uint32_t struct_length = type.inst_.Length() - 2;
            const uint32_t struct_id = type.inst_.ResultId();
            // We do our best to find the "size" of the struct (see https://gitlab.khronos.org/spirv/SPIR-V/-/issues/763)
            // Strategy: find the member with the highest Offset decoration, then add that member's size.
            uint32_t highest_element_index = 0;
            uint32_t highest_element_offset = 0;
            for (uint32_t i = 0; i < struct_length; i++) {
                for (const auto& annotation : module_.annotations_) {
                    if (annotation->Opcode() == spv::OpMemberDecorate && annotation->Word(1) == struct_id &&
                        annotation->Word(2) == i && spv::Decoration(annotation->Word(3)) == spv::DecorationOffset) {
                        const uint32_t member_offset = annotation->Word(4);
                        if (member_offset > highest_element_offset) {
                            highest_element_index = i;
                            highest_element_offset = member_offset;
                        }
                        break;
                    }
                }
            }
            const uint32_t last_offset_id = type.inst_.Operand(highest_element_index);
            const Type* last_offset_type = type_manager_.FindTypeById(last_offset_id);
            uint32_t highest_element_size = 0;
            if (last_offset_type->spv_type_ == SpvType::kMatrix) {
                // TODO - We need a better way to handle Matrix at the end of structs
                // Matrix layout info lives on the enclosing struct's member decorations
                const Instruction* decoration_matrix_stride =
                    GetMemberDecoration(struct_id, highest_element_index, spv::DecorationMatrixStride);
                matrix_stride = decoration_matrix_stride ? decoration_matrix_stride->Word(4) : 0;

                const Instruction* decoration_col_major =
                    GetMemberDecoration(struct_id, highest_element_index, spv::DecorationColMajor);
                col_major = decoration_col_major != nullptr;

                highest_element_size = FindTypeByteSize(last_offset_id, matrix_stride, col_major, true);
            } else {
                highest_element_size = FindTypeByteSize(last_offset_id);
            }
            return highest_element_offset + highest_element_size;
        }
        default:
            break;
    }
    return 1;  // fallback for unhandled types
}
// Find outermost buffer type and its access chain index.
// Because access chains indexes can be runtime values, we need to build arithmetic logic in the SPIR-V to get the runtime value of
// the indexing
//
// Returns the id of a 32-bit uint holding the byte index of the last byte the access touches.
uint32_t Pass::GetLastByte(const Type& descriptor_type, const std::vector<const Instruction*>& access_chain_insts,
                           const CooperativeMatrixAccess& coop_mat_access, BasicBlock& block, InstructionIt* inst_it) {
    assert(!access_chain_insts.empty());
    uint32_t current_type_id = 0;
    const uint32_t reset_ac_word = 4;  // points to first "Index" operand of an OpAccessChain
    uint32_t ac_word_index = reset_ac_word;

    // Resolve the type we start walking down from
    if (descriptor_type.IsArray()) {
        current_type_id = descriptor_type.inst_.Operand(0);
        ac_word_index++;  // this jumps over the array of descriptors so we first start on the descriptor itself
    } else if (descriptor_type.spv_type_ == SpvType::kStruct) {
        current_type_id = descriptor_type.Id();
    } else {
        module_.InternalError(Name(), "GetLastByte has unexpected descriptor type");
        return 0;
    }

    const uint32_t uint32_type_id = type_manager_.GetTypeInt(32, false).Id();

    // instruction that will have calculated the sum of the byte offset
    uint32_t sum_id = 0;

    // Matrix layout state, picked up from OpMemberDecorate on enclosing structs as we walk
    uint32_t matrix_stride = 0;
    bool col_major = false;
    uint32_t matrix_stride_id = 0;
    bool in_matrix = false;

    // This loop gets use to the last element, so if we have something like
    //
    // Struct foo {
    //     uint a;     // 4 bytes
    //     vec4 b;     // 16 bytes
    //     float c; <--- accessing
    // }
    //
    // it will get us to 20 bytes
    auto access_chain_iter = access_chain_insts.rbegin();

    // This occurs in things like Slang where they have a single OpAccessChain for the descriptor
    // (GLSL/HLSL will combine 2 indexes into the last OpAccessChain)
    if (ac_word_index >= (*access_chain_iter)->Length()) {
        ++access_chain_iter;
        ac_word_index = reset_ac_word;
    }

    // Walk every Index operand of every access chain, emitting SPIR-V arithmetic that
    // accumulates the byte offset of each indexing step into |sum_id|
    while (access_chain_iter != access_chain_insts.rend()) {
        const uint32_t ac_index_id = (*access_chain_iter)->Word(ac_word_index);
        uint32_t current_offset_id = 0;

        const Type* current_type = type_manager_.FindTypeById(current_type_id);
        switch (current_type->spv_type_) {
            case SpvType::kArray:
            case SpvType::kRuntimeArray: {
                // Get array stride and multiply by current index
                const uint32_t array_stride = GetDecoration(current_type_id, spv::DecorationArrayStride)->Word(3);
                const uint32_t array_stride_id = type_manager_.GetConstantUInt32(array_stride).Id();
                const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block, inst_it);

                current_offset_id = module_.TakeNextId();
                block.CreateInstruction(spv::OpIMul, {uint32_type_id, current_offset_id, array_stride_id, ac_index_id_32}, inst_it);

                // Get element type for next step
                current_type_id = current_type->inst_.Operand(0);
            } break;
            case SpvType::kMatrix: {
                if (matrix_stride == 0) {
                    module_.InternalError(Name(), "GetLastByte is missing matrix stride");
                }
                matrix_stride_id = type_manager_.GetConstantUInt32(matrix_stride).Id();
                uint32_t vec_type_id = current_type->inst_.Operand(0);

                // If column major, multiply column index by matrix stride, otherwise by vector component size and save matrix
                // stride for vector (row) index
                uint32_t col_stride_id = 0;
                if (col_major) {
                    col_stride_id = matrix_stride_id;
                } else {
                    const uint32_t component_type_id = type_manager_.FindTypeById(vec_type_id)->inst_.Operand(0);
                    const uint32_t col_stride = FindTypeByteSize(component_type_id);
                    col_stride_id = type_manager_.GetConstantUInt32(col_stride).Id();
                }

                const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block, inst_it);
                current_offset_id = module_.TakeNextId();
                block.CreateInstruction(spv::OpIMul, {uint32_type_id, current_offset_id, col_stride_id, ac_index_id_32}, inst_it);

                // Get element type for next step
                current_type_id = vec_type_id;
                in_matrix = true;
            } break;
            case SpvType::kVector: {
                // If inside a row major matrix type, multiply index by matrix stride,
                // else multiply by component size
                const uint32_t component_type_id = current_type->inst_.Operand(0);
                const uint32_t ac_index_id_32 = ConvertTo32(ac_index_id, block, inst_it);
                if (in_matrix && !col_major) {
                    current_offset_id = module_.TakeNextId();
                    block.CreateInstruction(spv::OpIMul, {uint32_type_id, current_offset_id, matrix_stride_id, ac_index_id_32},
                                            inst_it);
                } else {
                    const uint32_t component_type_size = FindTypeByteSize(component_type_id);
                    const uint32_t size_id = type_manager_.GetConstantUInt32(component_type_size).Id();
                    current_offset_id = module_.TakeNextId();
                    block.CreateInstruction(spv::OpIMul, {uint32_type_id, current_offset_id, size_id, ac_index_id_32}, inst_it);
                }
                // Get element type for next step
                current_type_id = component_type_id;
            } break;
            case SpvType::kStruct: {
                // Get buffer byte offset for the referenced member (struct indexes must be OpConstant)
                const Constant* member_constant = type_manager_.FindConstantById(ac_index_id);
                assert(!member_constant->is_spec_constant_);
                uint32_t member_index = member_constant->inst_.Operand(0);
                uint32_t member_offset = GetMemberDecoration(current_type_id, member_index, spv::DecorationOffset)->Word(4);
                current_offset_id = type_manager_.GetConstantUInt32(member_offset).Id();

                // Look for matrix stride for this member if there is one. The matrix
                // stride is not on the matrix type, but in a OpMemberDecorate on the
                // enclosing struct type at the member index. If none found, reset
                // stride to 0.
                const Instruction* decoration_matrix_stride =
                    GetMemberDecoration(current_type_id, member_index, spv::DecorationMatrixStride);
                matrix_stride = decoration_matrix_stride ? decoration_matrix_stride->Word(4) : 0;

                const Instruction* decoration_col_major =
                    GetMemberDecoration(current_type_id, member_index, spv::DecorationColMajor);
                col_major = decoration_col_major != nullptr;

                // Get element type for next step
                current_type_id = current_type->inst_.Operand(member_index);
            } break;
            default: {
                module_.InternalError(Name(), "GetLastByte has unexpected non-composite type");
            } break;
        }

        // Fold this step's offset into the running sum
        if (sum_id == 0) {
            sum_id = current_offset_id;
        } else {
            const uint32_t new_sum_id = module_.TakeNextId();
            block.CreateInstruction(spv::OpIAdd, {uint32_type_id, new_sum_id, sum_id, current_offset_id}, inst_it);
            sum_id = new_sum_id;
        }

        // Advance to next Index operand, moving to the next access chain when this one is exhausted
        ac_word_index++;
        if (ac_word_index >= (*access_chain_iter)->Length()) {
            ++access_chain_iter;
            ac_word_index = reset_ac_word;
        }
    }

    // Add in offset of last byte of referenced object.
    uint32_t accessed_type_size = 0;
    // For CooperativeMatrix the |current_type_id| will be an an Int or Float as that is the element type being accessed
    if (coop_mat_access.used) {
        // The stride here could be constant, so if it is, just use it, otherwise will need to build it via SPIR-V
        if (coop_mat_access.stride_value != 0) {
            accessed_type_size = coop_mat_access.Size();
        } else if (coop_mat_access.is_row_major) {
            // equation: ((rows - 1) * stride + columns) * component_size
            const uint32_t rows_m1_id = type_manager_.GetConstantUInt32(coop_mat_access.rows - 1).Id();
            const uint32_t columns_id = type_manager_.GetConstantUInt32(coop_mat_access.columns).Id();
            const uint32_t component_size_id = type_manager_.GetConstantUInt32(coop_mat_access.component_size).Id();
            uint32_t x1 = module_.TakeNextId();
            uint32_t x2 = module_.TakeNextId();
            // NOTE(review): |sum_id| is overwritten here, discarding the access-chain byte offset
            // accumulated above — confirm the dynamic-stride coop-mat path intends to ignore it.
            sum_id = module_.TakeNextId();
            block.CreateInstruction(spv::OpIMul, {uint32_type_id, x1, rows_m1_id, coop_mat_access.stride_id}, inst_it);
            block.CreateInstruction(spv::OpIAdd, {uint32_type_id, x2, x1, columns_id}, inst_it);
            block.CreateInstruction(spv::OpIMul, {uint32_type_id, sum_id, x2, component_size_id}, inst_it);
        } else {
            // equation: ((columns - 1) * stride_value + rows) * component_size;
            const uint32_t columns_m1_id = type_manager_.GetConstantUInt32(coop_mat_access.columns - 1).Id();
            const uint32_t row_id = type_manager_.GetConstantUInt32(coop_mat_access.rows).Id();
            const uint32_t component_size_id = type_manager_.GetConstantUInt32(coop_mat_access.component_size).Id();
            uint32_t x1 = module_.TakeNextId();
            uint32_t x2 = module_.TakeNextId();
            sum_id = module_.TakeNextId();
            block.CreateInstruction(spv::OpIMul, {uint32_type_id, x1, columns_m1_id, coop_mat_access.stride_id}, inst_it);
            block.CreateInstruction(spv::OpIAdd, {uint32_type_id, x2, x1, row_id}, inst_it);
            block.CreateInstruction(spv::OpIMul, {uint32_type_id, sum_id, x2, component_size_id}, inst_it);
        }
    } else {
        accessed_type_size = FindTypeByteSize(current_type_id, matrix_stride, col_major, in_matrix);
    }

    // When the dynamic coop-mat branches above ran, |accessed_type_size| is still 0, so
    // |last_byte_index| wraps to 0xFFFFFFFF and the modular OpIAdd below effectively
    // subtracts 1 from the size expression built into |sum_id|.
    const uint32_t last_byte_index = accessed_type_size - 1;
    const uint32_t last_byte_index_id = type_manager_.GetConstantUInt32(last_byte_index).Id();

    const uint32_t new_sum_id = module_.TakeNextId();
    block.CreateInstruction(spv::OpIAdd, {uint32_type_id, new_sum_id, sum_id, last_byte_index_id}, inst_it);
    return new_sum_id;
}
// Finds the upper bound offset into the struct an instruction would access
// If it is a non-constant value, will return zero to indicate its a runtime value
//
// If shader looks for 'b' in a descriptor like
//
// struct X {
//     uint a;
//     uint b;
//     uint c;
// }
//
// it will return `7` because it covers [4, 7] bytes of the descriptor
// (This matches the GetLastByte() check)
uint32_t Pass::FindOffsetInStruct(uint32_t struct_id, const CooperativeMatrixAccess* coop_mat_access, bool is_descriptor_array,
                                  const std::vector<const Instruction*>& access_chain_insts) const {
    assert(!access_chain_insts.empty());
    uint32_t last_byte_offset = 0;
    const uint32_t reset_ac_word = 4;  // points to first "Index" operand of an OpAccessChain
    uint32_t ac_word_index = reset_ac_word;
    if (is_descriptor_array) {
        ac_word_index++;  // this jumps over the array of descriptors so we first start on the descriptor itself
    }

    // CooperativeMatrix accesses have their size pre-computed; a size of 0 means dynamic
    uint32_t known_access_size = 0;
    if (coop_mat_access && coop_mat_access->used) {
        known_access_size = coop_mat_access->Size();
        if (known_access_size == 0) {
            // If size is known to be dynamic don't spend more time
            return 0;
        }
    }

    // Matrix layout state, picked up from OpMemberDecorate on enclosing structs as we walk
    uint32_t matrix_stride = 0;
    bool col_major = false;
    bool in_matrix = false;

    auto access_chain_iter = access_chain_insts.rbegin();
    // This occurs in things like Slang where they have a single OpAccessChain for the descriptor
    // (GLSL/HLSL will combine 2 indexes into the last OpAccessChain)
    if (ac_word_index >= (*access_chain_iter)->Length()) {
        ++access_chain_iter;
        ac_word_index = reset_ac_word;
    }

    uint32_t current_type_id = struct_id;
    // Walk down access chains to build up the offset
    while (access_chain_iter != access_chain_insts.rend()) {
        const uint32_t ac_index_id = (*access_chain_iter)->Word(ac_word_index);
        const Constant* index_constant = type_manager_.FindConstantById(ac_index_id);
        if (!index_constant || index_constant->inst_.Opcode() != spv::OpConstant) {
            return 0;  // Access Chain has dynamic value
        }
        const uint32_t constant_value = index_constant->GetValueUint32();
        uint32_t current_offset = 0;

        const Type* current_type = type_manager_.FindTypeById(current_type_id);
        switch (current_type->spv_type_) {
            case SpvType::kArray:
            case SpvType::kRuntimeArray: {
                // Get array stride and multiply by current index
                const uint32_t array_stride = GetDecoration(current_type_id, spv::DecorationArrayStride)->Word(3);
                current_offset = constant_value * array_stride;
                current_type_id = current_type->inst_.Operand(0);  // Get element type for next step
            } break;
            case SpvType::kMatrix: {
                if (matrix_stride == 0) {
                    module_.InternalError(Name(), "FindOffsetInStruct is missing matrix stride");
                }
                in_matrix = true;
                uint32_t vec_type_id = current_type->inst_.Operand(0);

                // If column major, multiply column index by matrix stride, otherwise by vector component size and save matrix
                // stride for vector (row) index
                uint32_t col_stride = 0;
                if (col_major) {
                    col_stride = matrix_stride;
                } else {
                    const uint32_t component_type_id = type_manager_.FindTypeById(vec_type_id)->inst_.Operand(0);
                    col_stride = FindTypeByteSize(component_type_id);
                }
                current_offset = constant_value * col_stride;

                current_type_id = vec_type_id;  // Get element type for next step
            } break;
            case SpvType::kVector: {
                // If inside a row major matrix type, multiply index by matrix stride,
                // else multiply by component size
                const uint32_t component_type_id = current_type->inst_.Operand(0);
                if (in_matrix && !col_major) {
                    current_offset = constant_value * matrix_stride;
                } else {
                    const uint32_t component_type_size = FindTypeByteSize(component_type_id);
                    current_offset = constant_value * component_type_size;
                }
                current_type_id = component_type_id;  // Get element type for next step
            } break;
            case SpvType::kStruct: {
                // Get buffer byte offset for the referenced member
                current_offset = GetMemberDecoration(current_type_id, constant_value, spv::DecorationOffset)->Word(4);

                // Look for matrix stride for this member if there is one. The matrix
                // stride is not on the matrix type, but in a OpMemberDecorate on the
                // enclosing struct type at the member index. If none is found, reset
                // stride to 0.
                const Instruction* decoration_matrix_stride =
                    GetMemberDecoration(current_type_id, constant_value, spv::DecorationMatrixStride);
                matrix_stride = decoration_matrix_stride ? decoration_matrix_stride->Word(4) : 0;

                const Instruction* decoration_col_major =
                    GetMemberDecoration(current_type_id, constant_value, spv::DecorationColMajor);
                col_major = decoration_col_major != nullptr;

                current_type_id = current_type->inst_.Operand(constant_value);  // Get element type for next step
            } break;
            case SpvType::kPointer: {
                // So what this means is we have a pointer of structs, found in Slang doing something like
                //
                // struct PerFrame {};
                // PerFrame* bufPerFrame;
                //
                // And the first PtrAccessChain is going to try and index into.
                // There should be ArrayStride to know how large the Struct is
                const uint32_t array_stride = GetDecoration(current_type_id, spv::DecorationArrayStride)->Word(3);
                current_offset = constant_value * array_stride;
                current_type_id = current_type->inst_.Word(3);  // Get pointer type
            } break;
            default: {
                module_.InternalError(Name(), "FindOffsetInStruct has unexpected non-composite type");
            } break;
        }

        last_byte_offset += current_offset;

        // Advance to next Index operand, moving to the next access chain when this one is exhausted
        ac_word_index++;
        if (ac_word_index >= (*access_chain_iter)->Length()) {
            ++access_chain_iter;
            ac_word_index = reset_ac_word;
        }
    }

    // Add in offset of last byte of referenced object
    // Things like CooperativeMatrix will have the size calculated already
    const uint32_t accessed_type_size =
        (known_access_size != 0) ? known_access_size : FindTypeByteSize(current_type_id, matrix_stride, col_major, in_matrix);
    const uint32_t last_byte_index = accessed_type_size - 1;
    last_byte_offset += last_byte_index;

    return last_byte_offset;
}
// Unlike a normal load/store where we get the size by looking at the type that is loaded/stored,
// With CoopMat, we need both the OpTypeCooperativeMatrixKHR and the OpCooperativeMatrixLoadKHR/OpCooperativeMatrixStoreKHR together
// to calculate the access size.
CooperativeMatrixAccess Pass::GetCooperativeMatrixAccess(const Instruction& inst, const Function& function) const {
    CooperativeMatrixAccess info;
    // TODO - When adding Coop Mat to Descriptor Indexing, will likely want a better way to signal things than this bool
    info.used = true;
    info.is_load = inst.Opcode() == spv::OpCooperativeMatrixLoadKHR;  // else is store

    // For stores, we assume the Object operand points to a load to get the type
    uint32_t coop_mat_type_id = info.is_load ? inst.TypeId() : function.FindInstruction(inst.Word(2))->TypeId();
    info.type = type_manager_.FindTypeById(coop_mat_type_id);
    assert(info.type && info.type->spv_type_ == SpvType::kCooperativeMatrixKHR);

    // Currently we don't save/cache the size of each type because we still need to extract the rows/column info. This the tradeoff
    // of having a simplified single Type class
    // OpTypeCooperativeMatrixKHR operands used here: Word(2) = component type, Word(4) = rows, Word(5) = columns
    const Type* component_type = type_manager_.FindTypeById(info.type->inst_.Word(2));
    info.component_size = type_manager_.TypeLength(*component_type);

    const Constant* rows_const = type_manager_.FindConstantById(info.type->inst_.Word(4));
    const Constant* columns_const = type_manager_.FindConstantById(info.type->inst_.Word(5));
    // TODO - Need to handle spec constant here, for now return zero to have things not blowup
    assert(rows_const && !rows_const->is_spec_constant_ && columns_const && !columns_const->is_spec_constant_);
    info.rows = rows_const->inst_.Operand(0);
    info.columns = columns_const->inst_.Operand(0);

    // Stride operand position differs between load (Word 5) and store (Word 4)
    info.stride_id = info.is_load ? inst.Word(5) : inst.Word(4);
    if (const Constant* stride = type_manager_.FindConstantById(info.stride_id)) {
        info.stride_value = stride->inst_.Operand(0);
    } else {
        info.stride_value = 0;  // stride is a runtime value
    }

    // MemoryLayout operand position also differs between load (Word 4) and store (Word 3)
    const uint32_t memory_layout_id = info.is_load ? inst.Word(4) : inst.Word(3);
    const Constant* memory_layout = type_manager_.FindConstantById(memory_layout_id);
    assert(memory_layout && !memory_layout->is_spec_constant_);
    const uint32_t memory_layout_value = memory_layout->inst_.Operand(0);
    info.is_row_major = memory_layout_value == spv::CooperativeMatrixLayoutRowMajorKHR;
    assert(info.is_row_major || memory_layout_value == spv::CooperativeMatrixLayoutColumnMajorKHR);

    return info;
}
// Generate code to convert integer id to 32bit, if needed.
uint32_t Pass::ConvertTo32(uint32_t id, BasicBlock& block, InstructionIt* inst_it) const {
// Find type doing the indexing into the access chain
const Type* type = nullptr;
const Constant* constant = type_manager_.FindConstantById(id);
if (constant) {
type = &constant->type_;
} else {
const Instruction* inst = block.function_->FindInstruction(id);
if (inst) {
type = type_manager_.FindTypeById(inst->TypeId());
}
}
if (!type) {
return id;
}
assert(type->spv_type_ == SpvType::kInt);
if (type->inst_.Word(2) == 32) {
return id;
}
const bool is_signed = type->inst_.Word(3) != 0;
const uint32_t new_id = module_.TakeNextId();
const Type& uint32_type = type_manager_.GetTypeInt(32, false);
if (is_signed) {
block.CreateInstruction(spv::OpSConvert, {uint32_type.Id(), new_id, id}, inst_it);
} else {
block.CreateInstruction(spv::OpUConvert, {uint32_type.Id(), new_id, id}, inst_it);
}
return new_id; // Return an id to the 32bit equivalent.
}
// Generate code to cast integer it to 32bit unsigned, if needed.
uint32_t Pass::CastToUint32(uint32_t id, BasicBlock& block, InstructionIt* inst_it) const {
// Convert value to 32-bit if necessary
uint32_t int32_id = ConvertTo32(id, block, inst_it);
const Type* type = nullptr;
const Constant* constant = type_manager_.FindConstantById(int32_id);
if (constant) {
type = &constant->type_;
} else {
const Instruction* inst = block.function_->FindInstruction(int32_id);
if (inst) {
type = type_manager_.FindTypeById(inst->TypeId());
}
}
if (!type) {
return int32_id;
}
assert(type->spv_type_ == SpvType::kInt);
const bool is_signed = type->inst_.Word(3) != 0;
if (!is_signed) {
return int32_id;
}
const Type& uint32_type = type_manager_.GetTypeInt(32, false);
const uint32_t new_id = module_.TakeNextId();
block.CreateInstruction(spv::OpBitcast, {uint32_type.Id(), new_id, int32_id}, inst_it);
return new_id; // Return an id to the Uint equivalent.
}
// Locates |target_instruction| inside |block| and returns an iterator to it.
// Reports an internal error and returns end() when the instruction is not found.
InstructionIt Pass::FindTargetInstruction(BasicBlock& block, const Instruction& target_instruction) const {
    const uint32_t target_id = target_instruction.ResultId();
    for (auto candidate_it = block.instructions_.begin(); candidate_it != block.instructions_.end(); ++candidate_it) {
        // Cheap ResultId comparison first so we avoid a deep compare on every instruction in the block
        if ((*candidate_it)->ResultId() != target_id) {
            continue;
        }
        // Instructions like OpStore have a result id of zero, so a full instruction comparison
        // is still needed to disambiguate
        if (*(*candidate_it) == target_instruction) {
            return candidate_it;
        }
    }
    module_.InternalError(Name(), "failed to find instruction");
    return block.instructions_.end();
}
// Returns true once the configured instrumentation budget has been reached.
// A limit of zero means "unlimited" and always returns false.
bool Pass::IsMaxInstrumentationsCount() const {
    const auto limit = module_.settings_.max_instrumentations_count;
    if (limit == 0) {
        return false;  // no cap configured
    }
    return instrumentations_count_ >= limit;
}
// A type of common pass that will inject a function call and link it up later,
// We will have wrap the checks to be safe from bad values crashing things
// For OpStore we will just ignore the store if it is invalid, example:
// Before:
// bda.data[index] = value;
// After:
// if (isValid(bda.data, index)) {
// bda.data[index] = value;
// }
//
// For OpLoad we replace the value with Zero (via Phi node) if it is invalid, example
// Before:
// int X = bda.data[index];
// int Y = bda.data[X];
// After:
// if (isValid(bda.data, index)) {
// int X = bda.data[index];
// } else {
// int X = 0;
// }
// if (isValid(bda.data, X)) {
// int Y = bda.data[X];
// } else {
// int Y = 0;
// }
//
// Splits the block holding |inst_it| so the targeted instruction only executes on the "valid" path.
// The caller is expected to emit the validity check into the (now truncated) original block and then
// call InjectFunctionPost() to terminate it with the conditional branch.
InjectConditionalData Pass::InjectFunctionPre(Function& function, const BasicBlockIt original_block_it, InstructionIt inst_it) {
    // We turn the block into 4 separate blocks
    BasicBlock& original_block = **original_block_it;
    const uint32_t original_label = original_block.GetLabelId();
    // Where we call targeted instruction if it is valid
    BasicBlockIt valid_block_it = function.InsertNewBlock(original_block_it);
    BasicBlock& valid_block = **valid_block_it;
    const uint32_t valid_block_label = valid_block.GetLabelId();
    // will be an empty block, used for the Phi node, even if no result, create for simplicity
    BasicBlockIt invalid_block_it = function.InsertNewBlock(valid_block_it);
    BasicBlock& invalid_block = **invalid_block_it;
    const uint32_t invalid_block_label = invalid_block.GetLabelId();
    // All the remaining block instructions after targeted instruction
    BasicBlockIt merge_block_it = function.InsertNewBlock(invalid_block_it);
    BasicBlock& merge_block = **merge_block_it;
    const uint32_t merge_block_label = merge_block.GetLabelId();
    // need to preserve the control-flow of how things, like a OpPhi, are accessed from a predecessor block
    function.ReplaceAllUsesWith(original_label, merge_block_label);
    // Move the targeted instruction to a valid block
    const Instruction& target_inst = *valid_block.instructions_.emplace_back(std::move(*inst_it));
    inst_it = original_block.instructions_.erase(inst_it);
    valid_block.CreateInstruction(spv::OpBranch, {merge_block_label});
    // If there is a result, we need to create an additional BasicBlock to hold the |else| case, then after we create a Phi node to
    // hold the result
    const uint32_t target_inst_id = target_inst.ResultId();
    if (target_inst_id != 0) {
        const uint32_t phi_id = module_.TakeNextId();
        const Type& phi_type = *type_manager_.FindTypeById(target_inst.TypeId());
        uint32_t null_id = 0;
        // Can't create ConstantNull of pointer type, so convert uint64 zero to pointer
        if (phi_type.spv_type_ == SpvType::kPointer) {
            const Type& uint64_type = type_manager_.GetTypeInt(64, false);
            const Constant& null_constant = type_manager_.GetConstantNull(uint64_type);
            null_id = module_.TakeNextId();
            // We need to put any intermittent instructions here so Phi is first in the merge block
            invalid_block.CreateInstruction(spv::OpConvertUToPtr, {phi_type.Id(), null_id, null_constant.Id()});
            module_.AddCapability(spv::CapabilityInt64);
        } else {
            if ((phi_type.spv_type_ == SpvType::kInt || phi_type.spv_type_ == SpvType::kFloat) && phi_type.inst_.Word(2) < 32) {
                // You can't make a constant of a 8-int, 16-int, 16-float without having the capability
                // The only way this situation occurs if they use something like
                // OpCapability StorageBuffer8BitAccess
                // but there is not explicit Int8
                // It should be more than safe to inject it for them
                spv::Capability capability = (phi_type.spv_type_ == SpvType::kFloat) ? spv::CapabilityFloat16
                                             : (phi_type.inst_.Word(2) == 16)       ? spv::CapabilityInt16
                                                                                    : spv::CapabilityInt8;
                module_.AddCapability(capability);
            }
            null_id = type_manager_.GetConstantNull(phi_type).Id();
        }
        // replace before creating instruction, otherwise will over-write itself
        function.ReplaceAllUsesWith(target_inst_id, phi_id);
        merge_block.CreateInstruction(spv::OpPhi,
                                      {phi_type.Id(), phi_id, target_inst_id, valid_block_label, null_id, invalid_block_label});
    }
    // When skipping some instructions, we need something valid to replace it
    if (target_inst.Opcode() == spv::OpRayQueryInitializeKHR) {
        // Currently assume the RayQuery and AS object were valid already
        const uint32_t uint32_0_id = type_manager_.GetConstantZeroUint32().Id();
        const uint32_t float32_0_id = type_manager_.GetConstantZeroFloat32().Id();
        const uint32_t vec3_0_id = type_manager_.GetConstantZeroVec3().Id();
        invalid_block.CreateInstruction(spv::OpRayQueryInitializeKHR,
                                        {target_inst.Operand(0), target_inst.Operand(1), uint32_0_id, uint32_0_id, vec3_0_id,
                                         float32_0_id, vec3_0_id, float32_0_id});
    } else if (target_inst.Opcode() == spv::OpSetMeshOutputsEXT) {
        // TODO - Setup a callback system so MeshShading pass can set this as we should use the max values from the ExecutionMode
        // instead, but don't want to be storing everything in the Pass class as a global
        const uint32_t three_id = type_manager_.CreateConstantUInt32(3).Id();
        const uint32_t one_id = type_manager_.GetConstantOneUint32().Id();
        invalid_block.CreateInstruction(spv::OpSetMeshOutputsEXT, {three_id, one_id});
    }
    invalid_block.CreateInstruction(spv::OpBranch, {merge_block_label});
    // move all remaining instructions to the newly created merge block
    merge_block.instructions_.insert(merge_block.instructions_.end(), std::make_move_iterator(inst_it),
                                     std::make_move_iterator(original_block.instructions_.end()));
    original_block.instructions_.erase(inst_it, original_block.instructions_.end());
    // function_result_id (the 4th field) is filled in by the caller once the check call exists
    return InjectConditionalData{merge_block_label, valid_block_label, invalid_block_label, 0, merge_block_it};
}
// Terminates |original_block| after the validity-check call has been emitted: declares the
// structured merge point, then branches on the check's result into the valid/invalid blocks.
void Pass::InjectFunctionPost(BasicBlock& original_block, const InjectConditionalData& ic_data) {
    const uint32_t merge_label = ic_data.merge_block_label;
    original_block.CreateInstruction(spv::OpSelectionMerge, {merge_label, spv::SelectionControlMaskNone});
    original_block.CreateInstruction(
        spv::OpBranchConditional, {ic_data.function_result_id, ic_data.valid_block_label, ic_data.invalid_block_label});
}
// Tracks whether the walk is currently inside a loop construct.
// Entering a loop header sets the state; reaching its merge target clears it.
void Pass::ControlFlow::Update(const BasicBlock& block) {
    if (!in_loop) {
        // Not in a loop yet - a loop header block starts one
        if (block.IsLoopHeader()) {
            in_loop = true;
            merge_target_id = block.loop_header_merge_target_;
        }
        return;
    }
    // Inside a loop - hitting the merge target means we have exited it
    if (block.GetLabelId() == merge_target_id) {
        in_loop = false;
        merge_target_id = 0;
    }
}
// Helper for passes with multiple linked functions they may grab
// Pass in cached link_function_id and only update it the first time
uint32_t Pass::GetLinkFunction(uint32_t& link_function_id, const OfflineFunction& offline) {
    const bool first_use = (link_function_id == 0);
    if (first_use) {
        // Lazily allocate the id and queue the offline function to be linked in later
        link_function_id = module_.TakeNextId();
        link_info_.functions.emplace_back(LinkFunction{offline, link_function_id});
    }
    return link_function_id;
}
// Detects the SPIR-V pattern of a descriptor index being read out of a Push Constant and caches
// the relevant ids, so later instrumentation can recognize aliases of the same PC-sourced index.
void DescriptroIndexPushConstantAccess::Update(const Module& module, InstructionIt inst_it) {
    if (!(*inst_it)->IsNonPtrAccessChain()) {
        return;
    }
    const Variable* pc_variable = module.type_manager_.FindPushConstantVariable();
    if (!pc_variable) {
        return;  // shader doesn't use Push Constant
    }
    if ((*inst_it)->Operand(0) != pc_variable->Id()) {
        return;  // Access chain is not aimed at the Push Constant
    }
    const Constant* member_index_constant = module.type_manager_.FindConstantById((*inst_it)->Operand(1));
    if (!member_index_constant) {
        return;  // dynamic access into Push Constant (which is crazy and not likely)
    }
    // NOTE(review): this is the constant's result id, not its literal value - it is only used as a
    // stable identity in the comparison below, so the id is sufficient for that purpose
    const uint32_t found_member_index = member_index_constant->Id();
    // We save memory/time tracking every instruction and know from viewing SPIR-V this pattern always will look like
    // %a = OpAccessChain %ptr %pc %uint_x
    // %b = OpLoad %uint %a
    // %c = OpIAdd %uint %b %uint_y (optional)
    //
    // We use this and just do a quick look ahead for load
    const uint32_t access_chain_id = (*inst_it)->ResultId();
    // Look-ahead assumes more instructions follow in the block (an access chain cannot be a
    // block terminator, so at least the terminator follows)
    inst_it++;
    if ((*inst_it)->Opcode() != spv::OpLoad || (*inst_it)->Operand(0) != access_chain_id) {
        return;
    }
    const Type* access_type = module.type_manager_.FindTypeById((*inst_it)->TypeId());
    if (!access_type || access_type->spv_type_ != SpvType::kInt) {
        return;  // might be grabbing a uvec2 or float instead we want to ignore
    }
    uint32_t found_descriptor_index_id = (*inst_it)->ResultId();
    uint32_t found_add_id_value = 0;
    inst_it++;
    // Optional constant offset applied on top of the loaded index
    if ((*inst_it)->Opcode() == spv::OpIAdd) {
        const uint32_t add_0_id = (*inst_it)->Operand(0);
        const uint32_t add_1_id = (*inst_it)->Operand(1);
        // Might be (pc + constant) or (constant + pc)
        if (add_0_id == found_descriptor_index_id) {
            found_add_id_value = add_1_id;
        } else if (add_1_id == found_descriptor_index_id) {
            found_add_id_value = add_0_id;
        } else {
            return;  // we have hit a strange case and rather be safe and exit
        }
        // The IAdd result is now the id that represents the descriptor index
        found_descriptor_index_id = (*inst_it)->ResultId();
    }
    next_alias_id = found_descriptor_index_id;
    if (add_id_value != found_add_id_value || member_index != found_member_index) {
        // First time seeing the Push Constant, set starting values.
        // Also if found a new uint being used, need to reset.
        descriptor_index_id = found_descriptor_index_id;
        add_id_value = found_add_id_value;
        member_index = found_member_index;
    }
}
// Records |hash| for |block| and reports whether the same instrumentation was already present,
// either in this block or in one of the predecessors that post-dominate it.
bool FunctionDuplicateTracker::FindAndUpdate(BlockDuplicateTracker& block, uint32_t hash) {
    // Subtle, but important, if you have
    //
    // inst_post_process(hash) A
    // if (x)
    // inst_post_process(hash) B
    // if (x)
    // inst_post_process(hash) C
    //
    // A, B, and C are the same hash. B's hash is still recorded here even though its call is
    // skipped, so C then detects it via B's block. This produces a post-dominated chain effect
    // without having to maintain any explicit list.
    const bool already_in_block = !block.hashes.insert(hash).second;
    if (already_in_block) {
        return true;  // found in this block
    }
    // Look back and see if this block is post-dominated by something with same instrumentation already
    const uint32_t predecessor_ids[] = {block.merge_select_predecessor, block.branch_conditional_predecessor,
                                        block.switch_cases_predecessor};
    for (const uint32_t predecessor_id : predecessor_ids) {
        if (predecessor_id == 0) {
            continue;  // no predecessor of this kind
        }
        BlockDuplicateTracker& predecessor_tracker = blocks_[predecessor_id];
        if (predecessor_tracker.hashes.count(hash) != 0) {
            return true;
        }
    }
    return false;
}
// If the block is terminating, mark the post-dominated blocks
BlockDuplicateTracker& FunctionDuplicateTracker::GetAndUpdate(BasicBlock& block) {
    const uint32_t current_block_id = block.GetLabelId();
    // Record this block as the given kind of predecessor of |successor_id| (zero means "none")
    auto mark_successor = [this, current_block_id](uint32_t successor_id, uint32_t BlockDuplicateTracker::* predecessor_field) {
        if (successor_id != 0) {
            blocks_[successor_id].*predecessor_field = current_block_id;
        }
    };
    mark_successor(block.selection_merge_target_, &BlockDuplicateTracker::merge_select_predecessor);
    mark_successor(block.branch_conditional_true_, &BlockDuplicateTracker::branch_conditional_predecessor);
    mark_successor(block.branch_conditional_false_, &BlockDuplicateTracker::branch_conditional_predecessor);
    mark_successor(block.switch_default_, &BlockDuplicateTracker::switch_cases_predecessor);
    for (const uint32_t switch_case_id : block.switch_cases_) {
        blocks_[switch_case_id].switch_cases_predecessor = current_block_id;
    }
    return blocks_[current_block_id];
}
} // namespace spirv
} // namespace gpuav