| /* |
| * Copyright (c) 2023 Apple Inc. All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions |
| * are met: |
| * 1. Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * 2. Redistributions in binary form must reproduce the above copyright |
| * notice, this list of conditions and the following disclaimer in the |
| * documentation and/or other materials provided with the distribution. |
| * |
| * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY |
| * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR |
| * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR |
| * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, |
| * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, |
| * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR |
| * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY |
| * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| */ |
| |
| #include "config.h" |
| #include "GlobalVariableRewriter.h" |
| |
| #include "AST.h" |
| #include "ASTIdentifier.h" |
| #include "ASTVisitor.h" |
| #include "CallGraph.h" |
| #include "WGSL.h" |
| #include "WGSLShaderModule.h" |
| #include <wtf/DataLog.h> |
| #include <wtf/HashMap.h> |
| #include <wtf/HashSet.h> |
| #include <wtf/SetForScope.h> |
| |
| namespace WGSL { |
| |
| constexpr bool shouldLogGlobalVariableRewriting = false; |
| |
| class RewriteGlobalVariables : public AST::Visitor { |
| public: |
| RewriteGlobalVariables(CallGraph& callGraph, const HashMap<String, std::optional<PipelineLayout>>& pipelineLayouts, PrepareResult& result) |
| : AST::Visitor() |
| , m_callGraph(callGraph) |
| , m_result(result) |
| { |
| UNUSED_PARAM(pipelineLayouts); |
| } |
| |
| void run(); |
| |
| void visit(AST::Function&) override; |
| void visit(AST::Variable&) override; |
| |
| void visit(AST::CompoundStatement&) override; |
| void visit(AST::AssignmentStatement&) override; |
| |
| void visit(AST::Expression&) override; |
| |
| private: |
| enum class Context : uint8_t { Local, Global }; |
| |
| struct Global { |
| struct Resource { |
| unsigned group; |
| unsigned binding; |
| }; |
| |
| std::optional<Resource> resource; |
| AST::Variable* declaration; |
| }; |
| |
| template<typename Value> |
| using IndexMap = HashMap<unsigned, Value, WTF::IntHash<unsigned>, WTF::UnsignedWithZeroKeyHashTraits<unsigned>>; |
| |
| using UsedResources = IndexMap<IndexMap<Global*>>; |
| using UsedPrivateGlobals = Vector<Global*>; |
| |
| struct UsedGlobals { |
| UsedResources resources; |
| UsedPrivateGlobals privateGlobals; |
| }; |
| |
| struct Insertion { |
| AST::Statement* statement; |
| unsigned index; |
| }; |
| |
| static AST::Identifier argumentBufferParameterName(unsigned group); |
| static AST::Identifier argumentBufferStructName(unsigned group); |
| |
| void def(const AST::Identifier&, AST::Variable*); |
| |
| void collectGlobals(); |
| void visitEntryPoint(AST::Function&, AST::StageAttribute::Stage, PipelineLayout&); |
| void visitCallee(const CallGraph::Callee&); |
| UsedGlobals determineUsedGlobals(PipelineLayout&, AST::StageAttribute::Stage); |
| void usesOverride(AST::Variable&); |
| void insertStructs(const UsedResources&); |
| void insertParameters(AST::Function&, const UsedResources&); |
| void insertMaterializations(AST::Function&, const UsedResources&); |
| void insertLocalDefinitions(AST::Function&, const UsedPrivateGlobals&); |
| void readVariable(AST::IdentifierExpression&, AST::Variable&, Context); |
| void insertBeforeCurrentStatement(AST::Statement&); |
| |
| void packResource(AST::Variable&); |
| void packArrayResource(AST::Variable&, const Types::Array*); |
| void packStructResource(AST::Variable&, const Types::Struct*); |
| const Type* packStructType(const Types::Struct*); |
| void updateReference(AST::Variable&, AST::TypeName&); |
| |
| enum Packing : uint8_t { |
| Packed = 1 << 0, |
| Unpacked = 1 << 1, |
| Either = Packed | Unpacked, |
| }; |
| |
| Packing pack(Packing, AST::Expression&); |
| Packing getPacking(AST::IdentifierExpression&); |
| Packing getPacking(AST::FieldAccessExpression&); |
| Packing getPacking(AST::IndexAccessExpression&); |
| Packing getPacking(AST::BinaryExpression&); |
| Packing getPacking(AST::UnaryExpression&); |
| Packing getPacking(AST::CallExpression&); |
| Packing packingForType(const Type*); |
| |
| CallGraph& m_callGraph; |
| PrepareResult& m_result; |
| HashMap<String, Global> m_globals; |
| IndexMap<Vector<std::pair<unsigned, String>>> m_groupBindingMap; |
| IndexMap<const Type*> m_structTypes; |
| HashMap<String, AST::Variable*> m_defs; |
| HashSet<String> m_reads; |
| HashMap<AST::Function*, HashSet<String>> m_visitedFunctions; |
| Reflection::EntryPointInformation* m_entryPointInformation { nullptr }; |
| unsigned m_constantId { 0 }; |
| unsigned m_currentStatementIndex { 0 }; |
| Vector<Insertion> m_pendingInsertions; |
| HashMap<const Types::Struct*, const Type*> m_packedStructTypes; |
| }; |
| |
| void RewriteGlobalVariables::run() |
| { |
| dataLogLnIf(shouldLogGlobalVariableRewriting, "BEGIN: GlobalVariableRewriter"); |
| |
| collectGlobals(); |
| for (auto& entryPoint : m_callGraph.entrypoints()) { |
| PipelineLayout pipelineLayout; |
| auto it = m_result.entryPoints.find(entryPoint.function.name()); |
| RELEASE_ASSERT(it != m_result.entryPoints.end()); |
| m_entryPointInformation = &it->value; |
| |
| visitEntryPoint(entryPoint.function, entryPoint.stage, pipelineLayout); |
| |
| m_entryPointInformation->defaultLayout = WTFMove(pipelineLayout); |
| } |
| |
| dataLogLnIf(shouldLogGlobalVariableRewriting, "END: GlobalVariableRewriter"); |
| } |
| |
| void RewriteGlobalVariables::visitCallee(const CallGraph::Callee& callee) |
| { |
| const auto& updateCallee = [&] { |
| for (auto& read : m_reads) { |
| auto it = m_globals.find(read); |
| RELEASE_ASSERT(it != m_globals.end()); |
| auto& global = it->value; |
| m_callGraph.ast().append(callee.target->parameters(), m_callGraph.ast().astBuilder().construct<AST::Parameter>( |
| SourceSpan::empty(), |
| AST::Identifier::make(read), |
| *global.declaration->maybeReferenceType(), |
| AST::Attribute::List { }, |
| AST::ParameterRole::UserDefined |
| )); |
| } |
| }; |
| |
| const auto& updateCallSites = [&] { |
| for (auto& read : m_reads) { |
| for (auto& call : callee.callSites) { |
| m_callGraph.ast().append(call->arguments(), m_callGraph.ast().astBuilder().construct<AST::IdentifierExpression>( |
| SourceSpan::empty(), |
| AST::Identifier::make(read) |
| )); |
| } |
| } |
| }; |
| |
| auto it = m_visitedFunctions.find(callee.target); |
| if (it != m_visitedFunctions.end()) { |
| dataLogLnIf(shouldLogGlobalVariableRewriting, "> Already visited callee: ", callee.target->name()); |
| m_reads = it->value; |
| updateCallSites(); |
| return; |
| } |
| |
| dataLogLnIf(shouldLogGlobalVariableRewriting, "> Visiting callee: ", callee.target->name()); |
| |
| visit(*callee.target); |
| updateCallee(); |
| updateCallSites(); |
| |
| m_visitedFunctions.add(callee.target, m_reads); |
| } |
| |
| void RewriteGlobalVariables::visit(AST::Function& function) |
| { |
| HashSet<String> reads; |
| for (auto& callee : m_callGraph.callees(function)) { |
| visitCallee(callee); |
| reads.formUnion(WTFMove(m_reads)); |
| } |
| m_reads = WTFMove(reads); |
| m_defs.clear(); |
| |
| for (auto& parameter : function.parameters()) |
| def(parameter.name(), nullptr); |
| |
| // FIXME: detect when we shadow a global that a callee needs |
| visit(function.body()); |
| } |
| |
| void RewriteGlobalVariables::visit(AST::Variable& variable) |
| { |
| def(variable.name(), &variable); |
| AST::Visitor::visit(variable); |
| } |
| |
| void RewriteGlobalVariables::visit(AST::CompoundStatement& statement) |
| { |
| auto indexScope = SetForScope(m_currentStatementIndex, 0); |
| auto insertionScope = SetForScope(m_pendingInsertions, Vector<Insertion>()); |
| |
| for (auto& statement : statement.statements()) { |
| AST::Visitor::visit(statement); |
| ++m_currentStatementIndex; |
| } |
| |
| unsigned offset = 0; |
| for (auto& insertion : m_pendingInsertions) { |
| m_callGraph.ast().insert(statement.statements(), insertion.index + offset, AST::Statement::Ref(*insertion.statement)); |
| ++offset; |
| } |
| } |
| |
| void RewriteGlobalVariables::visit(AST::AssignmentStatement& statement) |
| { |
| Packing lhsPacking = pack(Packing::Either, statement.lhs()); |
| ASSERT(lhsPacking != Packing::Either); |
| Packing rhsPacking = pack(lhsPacking, statement.rhs()); |
| ASSERT_UNUSED(rhsPacking, lhsPacking == rhsPacking); |
| } |
| |
| void RewriteGlobalVariables::visit(AST::Expression& expression) |
| { |
| pack(Packing::Unpacked, expression); |
| } |
| |
| auto RewriteGlobalVariables::pack(Packing expectedPacking, AST::Expression& expression) -> Packing |
| { |
| const auto& visitAndReplace = [&](auto& expression) -> Packing { |
| auto packing = getPacking(expression); |
| if (expectedPacking & packing) |
| return packing; |
| |
| auto* type = expression.inferredType(); |
| if (auto* referenceType = std::get_if<Types::Reference>(type)) |
| type = referenceType->element; |
| ASCIILiteral operation; |
| if (std::holds_alternative<Types::Struct>(*type)) |
| operation = packing == Packing::Packed ? "__unpack"_s : "__pack"_s; |
| else if (std::holds_alternative<Types::Array>(*type)) { |
| if (packing == Packing::Packed) { |
| operation = "__unpack_array"_s; |
| m_callGraph.ast().setUsesUnpackArray(); |
| } else { |
| operation = "__pack_array"_s; |
| m_callGraph.ast().setUsesPackArray(); |
| } |
| } else { |
| ASSERT(std::holds_alternative<Types::Vector>(*type)); |
| auto& vector = std::get<Types::Vector>(*type); |
| ASSERT(std::holds_alternative<Types::Primitive>(*vector.element)); |
| switch (std::get<Types::Primitive>(*vector.element).kind) { |
| case Types::Primitive::AbstractInt: |
| case Types::Primitive::I32: |
| operation = packing == Packing::Packed ? "int3"_s : "packed_int3"_s; |
| break; |
| case Types::Primitive::U32: |
| operation = packing == Packing::Packed ? "uint3"_s : "packed_uint3"_s; |
| break; |
| case Types::Primitive::AbstractFloat: |
| case Types::Primitive::F32: |
| operation = packing == Packing::Packed ? "float3"_s : "packed_float3"_s; |
| break; |
| default: |
| RELEASE_ASSERT_NOT_REACHED(); |
| } |
| } |
| RELEASE_ASSERT(!operation.isNull()); |
| auto& callee = m_callGraph.ast().astBuilder().construct<AST::NamedTypeName>( |
| SourceSpan::empty(), |
| AST::Identifier::make(operation) |
| ); |
| callee.m_resolvedType = m_callGraph.ast().types().bottomType(); |
| auto& argument = m_callGraph.ast().astBuilder().construct<std::remove_cvref_t<decltype(expression)>>(expression); |
| auto& call = m_callGraph.ast().astBuilder().construct<AST::CallExpression>( |
| SourceSpan::empty(), |
| callee, |
| AST::Expression::List { argument } |
| ); |
| call.m_inferredType = argument.inferredType(); |
| m_callGraph.ast().replace(expression, call); |
| return static_cast<Packing>(Packing::Either ^ packing); |
| }; |
| |
| switch (expression.kind()) { |
| case AST::NodeKind::IdentifierExpression: |
| return visitAndReplace(downcast<AST::IdentifierExpression>(expression)); |
| case AST::NodeKind::FieldAccessExpression: |
| return visitAndReplace(downcast<AST::FieldAccessExpression>(expression)); |
| case AST::NodeKind::IndexAccessExpression: |
| return visitAndReplace(downcast<AST::IndexAccessExpression>(expression)); |
| case AST::NodeKind::BinaryExpression: |
| return visitAndReplace(downcast<AST::BinaryExpression>(expression)); |
| case AST::NodeKind::UnaryExpression: |
| return visitAndReplace(downcast<AST::UnaryExpression>(expression)); |
| case AST::NodeKind::CallExpression: |
| return visitAndReplace(downcast<AST::CallExpression>(expression)); |
| default: |
| AST::Visitor::visit(expression); |
| return Packing::Unpacked; |
| } |
| } |
| |
| auto RewriteGlobalVariables::getPacking(AST::IdentifierExpression& identifier) -> Packing |
| { |
| auto packing = Packing::Unpacked; |
| |
| auto def = m_defs.find(identifier.identifier()); |
| if (def != m_defs.end()) { |
| if (def->value) |
| readVariable(identifier, *def->value, Context::Local); |
| return packing; |
| } |
| |
| auto it = m_globals.find(identifier.identifier()); |
| if (it == m_globals.end()) |
| return packing; |
| readVariable(identifier, *it->value.declaration, Context::Global); |
| |
| if (it->value.resource.has_value()) |
| return packingForType(identifier.inferredType()); |
| |
| return packing; |
| } |
| |
| auto RewriteGlobalVariables::getPacking(AST::FieldAccessExpression& expression) -> Packing |
| { |
| auto basePacking = pack(Packing::Either, expression.base()); |
| if (basePacking & Packing::Unpacked) |
| return Packing::Unpacked; |
| auto* baseType = expression.base().inferredType(); |
| if (auto* referenceType = std::get_if<Types::Reference>(baseType)) |
| baseType = referenceType->element; |
| if (std::holds_alternative<Types::Vector>(*baseType)) |
| return Packing::Unpacked; |
| ASSERT(std::holds_alternative<Types::Struct>(*baseType)); |
| auto& structType = std::get<Types::Struct>(*baseType); |
| auto* fieldType = structType.fields.get(expression.fieldName()); |
| return packingForType(fieldType); |
| } |
| |
| auto RewriteGlobalVariables::getPacking(AST::IndexAccessExpression& expression) -> Packing |
| { |
| auto basePacking = pack(Packing::Either, expression.base()); |
| if (basePacking & Packing::Unpacked) |
| return Packing::Unpacked; |
| auto* baseType = expression.base().inferredType(); |
| if (auto* referenceType = std::get_if<Types::Reference>(baseType)) |
| baseType = referenceType->element; |
| if (std::holds_alternative<Types::Vector>(*baseType)) |
| return Packing::Unpacked; |
| ASSERT(std::holds_alternative<Types::Array>(*baseType)); |
| auto& arrayType = std::get<Types::Array>(*baseType); |
| return packingForType(arrayType.element); |
| } |
| |
| auto RewriteGlobalVariables::getPacking(AST::BinaryExpression& expression) -> Packing |
| { |
| pack(Packing::Unpacked, expression.leftExpression()); |
| pack(Packing::Unpacked, expression.rightExpression()); |
| return Packing::Unpacked; |
| } |
| |
| auto RewriteGlobalVariables::getPacking(AST::UnaryExpression& expression) -> Packing |
| { |
| pack(Packing::Unpacked, expression.expression()); |
| return Packing::Unpacked; |
| } |
| |
| auto RewriteGlobalVariables::getPacking(AST::CallExpression& call) -> Packing |
| { |
| for (auto& argument : call.arguments()) |
| pack(Packing::Unpacked, argument); |
| return Packing::Unpacked; |
| } |
| |
| auto RewriteGlobalVariables::packingForType(const Type* type) -> Packing |
| { |
| if (auto* referenceType = std::get_if<Types::Reference>(type)) |
| return packingForType(referenceType->element); |
| |
| if (auto* structType = std::get_if<Types::Struct>(type)) { |
| if (structType->structure.role() == AST::StructureRole::UserDefinedResource) |
| return Packing::Packed; |
| } else if (auto* vectorType = std::get_if<Types::Vector>(type)) { |
| if (vectorType->size == 3) |
| return Packing::Packed; |
| } else if (auto* arrayType = std::get_if<Types::Array>(type)) |
| return packingForType(arrayType->element); |
| |
| return Packing::Unpacked; |
| } |
| |
| void RewriteGlobalVariables::collectGlobals() |
| { |
| auto& globalVars = m_callGraph.ast().variables(); |
| for (auto& globalVar : globalVars) { |
| std::optional<unsigned> group; |
| std::optional<unsigned> binding; |
| for (auto& attribute : globalVar.attributes()) { |
| if (is<AST::GroupAttribute>(attribute)) { |
| group = { *AST::extractInteger(downcast<AST::GroupAttribute>(attribute).group()) }; |
| continue; |
| } |
| if (is<AST::BindingAttribute>(attribute)) { |
| binding = { *AST::extractInteger(downcast<AST::BindingAttribute>(attribute).binding()) }; |
| continue; |
| } |
| } |
| |
| std::optional<Global::Resource> resource; |
| if (group.has_value()) { |
| RELEASE_ASSERT(binding.has_value()); |
| resource = { *group, *binding }; |
| } |
| |
| dataLogLnIf(shouldLogGlobalVariableRewriting, "> Found global: ", globalVar.name(), ", isResource: ", resource.has_value() ? "yes" : "no"); |
| |
| auto result = m_globals.add(globalVar.name(), Global { |
| resource, |
| &globalVar |
| }); |
| ASSERT_UNUSED(result, result.isNewEntry); |
| |
| if (resource.has_value()) { |
| auto result = m_groupBindingMap.add(resource->group, Vector<std::pair<unsigned, String>>()); |
| result.iterator->value.append({ resource->binding, globalVar.name() }); |
| packResource(globalVar); |
| } |
| } |
| } |
| |
| void RewriteGlobalVariables::packResource(AST::Variable& global) |
| { |
| auto* maybeTypeName = global.maybeTypeName(); |
| ASSERT(maybeTypeName); |
| |
| auto* resolvedType = maybeTypeName->resolvedType(); |
| if (auto* arrayType = std::get_if<Types::Array>(resolvedType)) { |
| packArrayResource(global, arrayType); |
| return; |
| } |
| |
| if (auto* structType = std::get_if<Types::Struct>(resolvedType)) { |
| packStructResource(global, structType); |
| return; |
| } |
| } |
| |
| void RewriteGlobalVariables::packStructResource(AST::Variable& global, const Types::Struct* structType) |
| { |
| const Type* packedStructType = packStructType(structType); |
| auto& packedType = m_callGraph.ast().astBuilder().construct<AST::NamedTypeName>( |
| SourceSpan::empty(), |
| AST::Identifier::make(std::get<Types::Struct>(*packedStructType).structure.name().id()) |
| ); |
| packedType.m_resolvedType = packedStructType; |
| auto& namedTypeName = downcast<AST::NamedTypeName>(*global.maybeTypeName()); |
| m_callGraph.ast().replace(namedTypeName, packedType); |
| updateReference(global, packedType); |
| } |
| |
| void RewriteGlobalVariables::packArrayResource(AST::Variable& global, const Types::Array* arrayType) |
| { |
| auto* structType = std::get_if<Types::Struct>(arrayType->element); |
| if (!structType) |
| return; |
| |
| const Type* packedStructType = packStructType(structType); |
| auto& packedType = m_callGraph.ast().astBuilder().construct<AST::NamedTypeName>( |
| SourceSpan::empty(), |
| AST::Identifier::make(std::get<Types::Struct>(*packedStructType).structure.name().id()) |
| ); |
| packedType.m_resolvedType = packedStructType; |
| |
| auto& arrayTypeName = downcast<AST::ArrayTypeName>(*global.maybeTypeName()); |
| auto& packedArrayTypeName = m_callGraph.ast().astBuilder().construct<AST::ArrayTypeName>( |
| arrayTypeName.span(), |
| &packedType, |
| arrayTypeName.maybeElementCount() |
| ); |
| packedArrayTypeName.m_resolvedType = m_callGraph.ast().types().arrayType(packedStructType, arrayType->size); |
| |
| m_callGraph.ast().replace(arrayTypeName, packedArrayTypeName); |
| updateReference(global, packedArrayTypeName); |
| } |
| |
| void RewriteGlobalVariables::updateReference(AST::Variable& global, AST::TypeName& packedType) |
| { |
| auto* maybeReference = global.maybeReferenceType(); |
| ASSERT(maybeReference); |
| ASSERT(is<AST::ReferenceTypeName>(*maybeReference)); |
| auto& reference = downcast<AST::ReferenceTypeName>(*maybeReference); |
| auto* referenceType = std::get_if<Types::Reference>(reference.resolvedType()); |
| ASSERT(referenceType); |
| auto& packedTypeReference = m_callGraph.ast().astBuilder().construct<AST::ReferenceTypeName>( |
| SourceSpan::empty(), |
| packedType |
| ); |
| packedTypeReference.m_resolvedType = m_callGraph.ast().types().referenceType( |
| referenceType->addressSpace, |
| packedType.resolvedType(), |
| referenceType->accessMode |
| ); |
| m_callGraph.ast().replace(reference, packedTypeReference); |
| } |
| |
| const Type* RewriteGlobalVariables::packStructType(const Types::Struct* structType) |
| { |
| if (structType->structure.role() == AST::StructureRole::UserDefinedResource) |
| return m_packedStructTypes.get(structType); |
| |
| ASSERT(structType->structure.role() == AST::StructureRole::UserDefined); |
| m_callGraph.ast().replace(&structType->structure.role(), AST::StructureRole::UserDefinedResource); |
| |
| String packedStructName = makeString("__", structType->structure.name(), "_Packed"); |
| auto& packedStruct = m_callGraph.ast().astBuilder().construct<AST::Structure>( |
| SourceSpan::empty(), |
| AST::Identifier::make(packedStructName), |
| AST::StructureMember::List(structType->structure.members()), |
| AST::Attribute::List { }, |
| AST::StructureRole::PackedResource, |
| &structType->structure |
| ); |
| m_callGraph.ast().append(m_callGraph.ast().structures(), packedStruct); |
| const Type* packedStructType = m_callGraph.ast().types().structType(packedStruct); |
| m_packedStructTypes.add(structType, packedStructType); |
| return packedStructType; |
| } |
| |
| void RewriteGlobalVariables::visitEntryPoint(AST::Function& function, AST::StageAttribute::Stage stage, PipelineLayout& pipelineLayout) |
| { |
| m_reads.clear(); |
| m_structTypes.clear(); |
| |
| dataLogLnIf(shouldLogGlobalVariableRewriting, "> Visiting entrypoint: ", function.name()); |
| |
| visit(function); |
| if (m_reads.isEmpty()) |
| return; |
| |
| auto usedGlobals = determineUsedGlobals(pipelineLayout, stage); |
| insertStructs(usedGlobals.resources); |
| insertParameters(function, usedGlobals.resources); |
| insertMaterializations(function, usedGlobals.resources); |
| insertLocalDefinitions(function, usedGlobals.privateGlobals); |
| } |
| |
| static BindGroupLayoutEntry::BindingMember bindingMemberForGlobal(auto& global) |
| { |
| auto* variable = global.declaration; |
| ASSERT(variable); |
| auto* maybeReference = variable->maybeReferenceType(); |
| auto* type = variable->storeType(); |
| ASSERT(type); |
| auto addressSpace = [&]() { |
| if (maybeReference) { |
| auto& reference = downcast<AST::ReferenceTypeName>(*maybeReference); |
| auto* referenceType = std::get_if<Types::Reference>(reference.resolvedType()); |
| if (referenceType && referenceType->addressSpace == AddressSpace::Storage) |
| return BufferBindingType::Storage; |
| } |
| |
| return BufferBindingType::Uniform; |
| }; |
| |
| using namespace WGSL::Types; |
| return WTF::switchOn(*type, [&](const Primitive& primitive) -> BindGroupLayoutEntry::BindingMember { |
| switch (primitive.kind) { |
| case Types::Primitive::AbstractInt: |
| case Types::Primitive::I32: |
| case Types::Primitive::U32: |
| case Types::Primitive::AbstractFloat: |
| case Types::Primitive::F32: |
| case Types::Primitive::Void: |
| case Types::Primitive::Bool: |
| return BufferBindingLayout { |
| .type = addressSpace(), |
| .hasDynamicOffset = false, |
| .minBindingSize = 0 |
| }; |
| case Types::Primitive::Sampler: |
| return SamplerBindingLayout { |
| .type = SamplerBindingType::Filtering |
| }; |
| case Types::Primitive::TextureExternal: |
| return ExternalTextureBindingLayout { }; |
| } |
| }, [&](const Vector& vector) -> BindGroupLayoutEntry::BindingMember { |
| auto* primitive = std::get_if<Primitive>(vector.element); |
| UNUSED_PARAM(primitive); |
| return BufferBindingLayout { |
| .type = addressSpace(), |
| .hasDynamicOffset = false, |
| .minBindingSize = 0 |
| }; |
| }, [&](const Matrix& matrix) -> BindGroupLayoutEntry::BindingMember { |
| UNUSED_PARAM(matrix); |
| return BufferBindingLayout { |
| .type = addressSpace(), |
| .hasDynamicOffset = false, |
| .minBindingSize = 0 |
| }; |
| }, [&](const Array& array) -> BindGroupLayoutEntry::BindingMember { |
| UNUSED_PARAM(array); |
| return BufferBindingLayout { |
| .type = addressSpace(), |
| .hasDynamicOffset = false, |
| .minBindingSize = 0 |
| }; |
| }, [&](const Struct& structure) -> BindGroupLayoutEntry::BindingMember { |
| UNUSED_PARAM(structure); |
| return BufferBindingLayout { |
| .type = addressSpace(), |
| .hasDynamicOffset = false, |
| .minBindingSize = 0 |
| }; |
| }, [&](const Texture& texture) -> BindGroupLayoutEntry::BindingMember { |
| TextureViewDimension viewDimension; |
| bool multisampled = false; |
| bool isStorageTexture = false; |
| switch (texture.kind) { |
| case Types::Texture::Kind::Texture1d: |
| viewDimension = TextureViewDimension::OneDimensional; |
| break; |
| case Types::Texture::Kind::Texture2d: |
| viewDimension = TextureViewDimension::TwoDimensional; |
| break; |
| case Types::Texture::Kind::Texture2dArray: |
| viewDimension = TextureViewDimension::TwoDimensionalArray; |
| break; |
| case Types::Texture::Kind::Texture3d: |
| viewDimension = TextureViewDimension::ThreeDimensional; |
| break; |
| case Types::Texture::Kind::TextureCube: |
| viewDimension = TextureViewDimension::Cube; |
| break; |
| case Types::Texture::Kind::TextureCubeArray: |
| viewDimension = TextureViewDimension::CubeArray; |
| break; |
| case Types::Texture::Kind::TextureMultisampled2d: |
| viewDimension = TextureViewDimension::TwoDimensional; |
| multisampled = true; |
| break; |
| |
| case Types::Texture::Kind::TextureStorage1d: |
| isStorageTexture = true; |
| viewDimension = TextureViewDimension::OneDimensional; |
| break; |
| case Types::Texture::Kind::TextureStorage2d: |
| isStorageTexture = true; |
| viewDimension = TextureViewDimension::TwoDimensional; |
| break; |
| case Types::Texture::Kind::TextureStorage2dArray: |
| isStorageTexture = true; |
| viewDimension = TextureViewDimension::TwoDimensionalArray; |
| break; |
| case Types::Texture::Kind::TextureStorage3d: |
| isStorageTexture = true; |
| viewDimension = TextureViewDimension::ThreeDimensional; |
| break; |
| } |
| |
| if (isStorageTexture) { |
| return StorageTextureBindingLayout { |
| .viewDimension = viewDimension |
| }; |
| } |
| |
| return TextureBindingLayout { |
| .sampleType = TextureSampleType::Float, |
| .viewDimension = viewDimension, |
| .multisampled = multisampled |
| }; |
| }, [&](const Reference&) -> BindGroupLayoutEntry::BindingMember { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }, [&](const Function&) -> BindGroupLayoutEntry::BindingMember { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }, [&](const Bottom&) -> BindGroupLayoutEntry::BindingMember { |
| RELEASE_ASSERT_NOT_REACHED(); |
| }); |
| } |
| |
| auto RewriteGlobalVariables::determineUsedGlobals(PipelineLayout& pipelineLayout, AST::StageAttribute::Stage stage) -> UsedGlobals |
| { |
| UsedGlobals usedGlobals; |
| for (const auto& globalName : m_reads) { |
| auto it = m_globals.find(globalName); |
| RELEASE_ASSERT(it != m_globals.end()); |
| auto& global = it->value; |
| AST::Variable& variable = *global.declaration; |
| switch (variable.flavor()) { |
| case AST::VariableFlavor::Override: |
| usesOverride(variable); |
| break; |
| case AST::VariableFlavor::Var: |
| case AST::VariableFlavor::Let: |
| case AST::VariableFlavor::Const: |
| if (!global.resource.has_value()) { |
| usedGlobals.privateGlobals.append(&global); |
| continue; |
| } |
| break; |
| } |
| |
| auto group = global.resource->group; |
| auto result = usedGlobals.resources.add(group, IndexMap<Global*>()); |
| result.iterator->value.add(global.resource->binding, &global); |
| |
| if (pipelineLayout.bindGroupLayouts.size() <= group) |
| pipelineLayout.bindGroupLayouts.grow(group + 1); |
| |
| ShaderStage shaderStage; |
| switch (stage) { |
| case AST::StageAttribute::Stage::Compute: |
| shaderStage = ShaderStage::Compute; |
| break; |
| case AST::StageAttribute::Stage::Vertex: |
| shaderStage = ShaderStage::Vertex; |
| break; |
| case AST::StageAttribute::Stage::Fragment: |
| shaderStage = ShaderStage::Fragment; |
| break; |
| } |
| pipelineLayout.bindGroupLayouts[group].entries.append({ |
| .binding = global.resource->binding, |
| .visibility = shaderStage, |
| .bindingMember = bindingMemberForGlobal(global) |
| }); |
| } |
| return usedGlobals; |
| } |
| |
| void RewriteGlobalVariables::usesOverride(AST::Variable& variable) |
| { |
| Reflection::SpecializationConstantType constantType; |
| const Type* type = variable.storeType(); |
| ASSERT(std::holds_alternative<Types::Primitive>(*type)); |
| const auto& primitive = std::get<Types::Primitive>(*type); |
| switch (primitive.kind) { |
| case Types::Primitive::Bool: |
| constantType = Reflection::SpecializationConstantType::Boolean; |
| break; |
| case Types::Primitive::F32: |
| constantType = Reflection::SpecializationConstantType::Float; |
| break; |
| case Types::Primitive::I32: |
| constantType = Reflection::SpecializationConstantType::Int; |
| break; |
| case Types::Primitive::U32: |
| constantType = Reflection::SpecializationConstantType::Unsigned; |
| break; |
| case Types::Primitive::Void: |
| case Types::Primitive::AbstractInt: |
| case Types::Primitive::AbstractFloat: |
| case Types::Primitive::Sampler: |
| case Types::Primitive::TextureExternal: |
| RELEASE_ASSERT_NOT_REACHED(); |
| } |
| m_entryPointInformation->specializationConstants.add(variable.name(), Reflection::SpecializationConstant { String(), constantType }); |
| } |
| |
| void RewriteGlobalVariables::insertStructs(const UsedResources& usedResources) |
| { |
| for (auto& groupBinding : m_groupBindingMap) { |
| unsigned group = groupBinding.key; |
| |
| auto usedResource = usedResources.find(group); |
| if (usedResource == usedResources.end()) |
| continue; |
| |
| const auto& bindingGlobalMap = groupBinding.value; |
| const IndexMap<Global*>& usedBindings = usedResource->value; |
| |
| AST::Identifier structName = argumentBufferStructName(group); |
| AST::StructureMember::List structMembers; |
| |
| for (auto [binding, globalName] : bindingGlobalMap) { |
| if (!usedBindings.contains(binding)) |
| continue; |
| |
| auto it = m_globals.find(globalName); |
| RELEASE_ASSERT(it != m_globals.end()); |
| auto& global = it->value; |
| ASSERT(global.declaration->maybeTypeName()); |
| auto span = global.declaration->span(); |
| structMembers.append(m_callGraph.ast().astBuilder().construct<AST::StructureMember>( |
| span, |
| AST::Identifier::make(global.declaration->name()), |
| *global.declaration->maybeReferenceType(), |
| AST::Attribute::List { |
| m_callGraph.ast().astBuilder().construct<AST::BindingAttribute>( |
| span, |
| m_callGraph.ast().astBuilder().construct<AST::AbstractIntegerLiteral>(span, binding) |
| ) |
| } |
| )); |
| } |
| |
| m_callGraph.ast().append(m_callGraph.ast().structures(), m_callGraph.ast().astBuilder().construct<AST::Structure>( |
| SourceSpan::empty(), |
| WTFMove(structName), |
| WTFMove(structMembers), |
| AST::Attribute::List { }, |
| AST::StructureRole::BindGroup |
| )); |
| m_structTypes.add(groupBinding.key, m_callGraph.ast().types().structType(m_callGraph.ast().structures().last())); |
| } |
| } |
| |
| void RewriteGlobalVariables::insertParameters(AST::Function& function, const UsedResources& usedResources) |
| { |
| auto span = function.span(); |
| for (auto& it : usedResources) { |
| unsigned group = it.key; |
| auto& type = m_callGraph.ast().astBuilder().construct<AST::NamedTypeName>(span, argumentBufferStructName(group)); |
| type.m_resolvedType = m_structTypes.get(group); |
| m_callGraph.ast().append(function.parameters(), m_callGraph.ast().astBuilder().construct<AST::Parameter>( |
| span, |
| argumentBufferParameterName(group), |
| type, |
| AST::Attribute::List { |
| m_callGraph.ast().astBuilder().construct<AST::GroupAttribute>( |
| span, |
| m_callGraph.ast().astBuilder().construct<AST::AbstractIntegerLiteral>(span, group) |
| ) |
| }, |
| AST::ParameterRole::BindGroup |
| )); |
| } |
| } |
| |
| void RewriteGlobalVariables::insertMaterializations(AST::Function& function, const UsedResources& usedResources) |
| { |
| auto span = function.span(); |
| for (auto& [group, bindings] : usedResources) { |
| auto& argument = m_callGraph.ast().astBuilder().construct<AST::IdentifierExpression>( |
| span, |
| AST::Identifier::make(argumentBufferParameterName(group)) |
| ); |
| |
| for (auto& [_, global] : bindings) { |
| auto& name = global->declaration->name(); |
| String fieldName = name; |
| auto* storeType = global->declaration->storeType(); |
| if (isPrimitive(storeType, Types::Primitive::TextureExternal)) { |
| fieldName = makeString("__", name); |
| m_callGraph.ast().setUsesExternalTextures(); |
| } |
| auto& access = m_callGraph.ast().astBuilder().construct<AST::FieldAccessExpression>( |
| SourceSpan::empty(), |
| argument, |
| AST::Identifier::make(WTFMove(fieldName)) |
| ); |
| auto& variable = m_callGraph.ast().astBuilder().construct<AST::Variable>( |
| SourceSpan::empty(), |
| AST::VariableFlavor::Let, |
| AST::Identifier::make(name), |
| nullptr, |
| global->declaration->maybeReferenceType(), |
| &access, |
| AST::Attribute::List { } |
| ); |
| auto& variableStatement = m_callGraph.ast().astBuilder().construct<AST::VariableStatement>( |
| SourceSpan::empty(), |
| variable |
| ); |
| m_callGraph.ast().insert(function.body().statements(), 0, AST::Statement::Ref(variableStatement)); |
| } |
| } |
| } |
| |
| void RewriteGlobalVariables::insertLocalDefinitions(AST::Function& function, const UsedPrivateGlobals& usedPrivateGlobals) |
| { |
| for (auto* global : usedPrivateGlobals) { |
| auto& variable = *global->declaration; |
| auto& variableStatement = m_callGraph.ast().astBuilder().construct<AST::VariableStatement>(SourceSpan::empty(), variable); |
| m_callGraph.ast().insert(function.body().statements(), 0, std::reference_wrapper<AST::Statement>(variableStatement)); |
| } |
| } |
| |
| void RewriteGlobalVariables::def(const AST::Identifier& name, AST::Variable* variable) |
| { |
| dataLogLnIf(shouldLogGlobalVariableRewriting, "> def: ", name, " at line:", name.span().line, " column: ", name.span().lineOffset); |
| m_defs.add(name, variable); |
| } |
| |
| void RewriteGlobalVariables::readVariable(AST::IdentifierExpression& identifier, AST::Variable& variable, Context context) |
| { |
| if (variable.flavor() != AST::VariableFlavor::Const) { |
| if (context == Context::Global) { |
| dataLogLnIf(shouldLogGlobalVariableRewriting, "> read global: ", identifier.identifier(), " at line:", identifier.span().line, " column: ", identifier.span().lineOffset); |
| m_reads.add(identifier.identifier()); |
| } |
| return; |
| } |
| |
| String newName = makeString("__const", String::number(++m_constantId)); |
| auto& newInitializer = m_callGraph.ast().astBuilder().construct<AST::IdentityExpression>( |
| variable.maybeInitializer()->span(), |
| *variable.maybeInitializer() |
| ); |
| newInitializer.m_inferredType = identifier.inferredType(); |
| auto& newVariable = m_callGraph.ast().astBuilder().construct<AST::Variable>( |
| variable.span(), |
| AST::VariableFlavor::Let, |
| AST::Identifier::make(newName), |
| nullptr, |
| variable.maybeTypeName(), |
| &newInitializer, |
| AST::Attribute::List { } |
| ); |
| |
| m_callGraph.ast().replace(&identifier.identifier(), AST::Identifier::make(newName)); |
| |
| auto& statement = m_callGraph.ast().astBuilder().construct<AST::VariableStatement>( |
| SourceSpan::empty(), |
| newVariable |
| ); |
| insertBeforeCurrentStatement(statement); |
| } |
| |
| void RewriteGlobalVariables::insertBeforeCurrentStatement(AST::Statement& statement) |
| { |
| m_pendingInsertions.append({ &statement, m_currentStatementIndex }); |
| } |
| |
| AST::Identifier RewriteGlobalVariables::argumentBufferParameterName(unsigned group) |
| { |
| return AST::Identifier::make(makeString("__ArgumentBufer_", String::number(group))); |
| } |
| |
| AST::Identifier RewriteGlobalVariables::argumentBufferStructName(unsigned group) |
| { |
| return AST::Identifier::make(makeString("__ArgumentBuferT_", String::number(group))); |
| } |
| |
| void rewriteGlobalVariables(CallGraph& callGraph, const HashMap<String, std::optional<PipelineLayout>>& pipelineLayouts, PrepareResult& result) |
| { |
| RewriteGlobalVariables(callGraph, pipelineLayouts, result).run(); |
| } |
| |
| } // namespace WGSL |