blob: 52d0a6ea3982a8af15fabc92dde0f5bd3403ec87 [file] [log] [blame] [edit]
/*
* Copyright (c) 2023 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "config.h"
#include "EntryPointRewriter.h"
#include "AST.h"
#include "ASTVisitor.h"
#include "CallGraph.h"
#include "WGSL.h"
#include "WGSLShaderModule.h"
namespace WGSL {
// Rewrites one WGSL entry point function. Non-builtin inputs (top-level parameters and
// non-builtin members of struct parameters) are moved into a synthesized struct that is
// appended to the module. Builtin inputs become plain function parameters. Statements are
// generated that rebuild the original parameter variables inside the function body.
class EntryPointRewriter {
public:
EntryPointRewriter(ShaderModule&, AST::Function&, AST::StageAttribute::Stage);
// Rewrites the entry point's parameter list and body in place (see the definition for details).
void rewrite();
// Moves the reflection record (seeded by the constructor) out to the caller.
Reflection::EntryPointInformation takeEntryPointInformation();
private:
// A single flattened input: either an original function parameter or a member of a
// struct-typed parameter.
struct MemberOrParameter {
AST::Identifier m_name;
AST::TypeName::Ref m_type;
AST::Attribute::List m_attributes;
};
// Tells materialize() where a value is read back from: a hoisted builtin parameter (Yes)
// or a field of the synthesized input struct (No).
enum class IsBuiltin {
No = 0,
Yes = 1,
};
// Follows resolved NamedTypeName references to the underlying type.
static AST::TypeName& getResolvedType(AST::TypeName&);
void collectParameters();
void checkReturnType();
void constructInputStruct();
// `path` holds the names from the top-level parameter down to the struct containing the current member.
void materialize(Vector<String>& path, MemberOrParameter&, IsBuiltin);
void visit(Vector<String>& path, MemberOrParameter&&);
void appendBuiltins();
AST::StageAttribute::Stage m_stage;
ShaderModule& m_shaderModule;
AST::Function& m_function;
// Inputs carrying a BuiltinAttribute; re-added to the function as plain parameters.
Vector<MemberOrParameter> m_builtins;
// Non-builtin inputs; become members of the synthesized input struct.
Vector<MemberOrParameter> m_parameters;
// Statements rebuilding the original parameter variables, meant to be inserted by rewrite()
// at the start of the function body.
AST::Statement::List m_materializations;
// "__<function>_inT": name of the synthesized input struct type.
String m_structTypeName;
// "__<function>_in": name of the parameter of that struct type.
String m_structParameterName;
Reflection::EntryPointInformation m_information;
};
// Binds the rewriter to one entry point of the module. It also seeds the reflection record
// with the stage-specific payload that takeEntryPointInformation() later hands out.
EntryPointRewriter::EntryPointRewriter(ShaderModule& shaderModule, AST::Function& function, AST::StageAttribute::Stage stage)
    : m_stage(stage)
    , m_shaderModule(shaderModule)
    , m_function(function)
{
    switch (stage) {
    case AST::StageAttribute::Stage::Compute:
        m_information.typedEntryPoint = Reflection::Compute { 1, 1, 1 };
        return;
    case AST::StageAttribute::Stage::Vertex:
        m_information.typedEntryPoint = Reflection::Vertex { false };
        return;
    case AST::StageAttribute::Stage::Fragment:
        m_information.typedEntryPoint = Reflection::Fragment { };
        return;
    }
}
// Returns the type that `type` ultimately names. Resolved NamedTypeName references are
// followed repeatedly. An unresolved name, or any other kind of type, is returned as-is.
AST::TypeName& EntryPointRewriter::getResolvedType(AST::TypeName& type)
{
    AST::TypeName* current = &type;
    while (is<AST::NamedTypeName>(*current)) {
        auto* target = downcast<AST::NamedTypeName>(*current).maybeResolvedReference();
        if (!target)
            break;
        current = target;
    }
    return *current;
}
// Rewrites m_function in place:
//  - non-builtin inputs are gathered into a synthesized struct `__<fn>_inT` (appended to the
//    module) and passed as a new parameter `__<fn>_in`;
//  - builtin inputs, including builtins hoisted out of struct parameters, become plain
//    parameters;
//  - statements that rebuild the original parameter variables are prepended to the body.
void EntryPointRewriter::rewrite()
{
    m_structTypeName = makeString("__", m_function.name(), "_inT");
    m_structParameterName = makeString("__", m_function.name(), "_in");

    collectParameters();
    checkReturnType();

    // Captured before constructInputStruct() moves the contents out of m_parameters.
    bool hasInputStruct = !m_parameters.isEmpty();

    if (hasInputStruct)
        constructInputStruct();

    appendBuiltins();

    if (hasInputStruct) {
        // Append the struct parameter: `${m_structParameterName} : ${m_structTypeName}`.
        m_function.parameters().append(makeUniqueRef<AST::Parameter>(
            SourceSpan::empty(),
            AST::Identifier::make(m_structParameterName),
            adoptRef(*new AST::NamedTypeName(SourceSpan::empty(), AST::Identifier::make(m_structTypeName))),
            AST::Attribute::List { },
            AST::ParameterRole::StageIn
        ));
    }

    // The materializations must be inserted even when there is no input struct. A struct
    // parameter made only of builtin members still needs its local declared and its members
    // reassigned from the hoisted builtin parameters.
    // Inserting the last statement at index 0 first keeps m_materializations' order.
    while (m_materializations.size())
        m_function.body().statements().insert(0, m_materializations.takeLast());
}
// Transfers the reflection record to the caller; m_information is left in a moved-from state.
Reflection::EntryPointInformation EntryPointRewriter::takeEntryPointInformation()
{
    Reflection::EntryPointInformation information = WTFMove(m_information);
    return information;
}
void EntryPointRewriter::collectParameters()
{
while (m_function.parameters().size()) {
auto parameter = m_function.parameters().takeLast();
Vector<String> path;
visit(path, MemberOrParameter { parameter->name(), parameter->typeName(), WTFMove(parameter->attributes()) });
}
}
// For vertex entry points whose return type resolves to a struct, tags that struct with
// StructureRole::VertexOutput. Other stages, and non-struct or absent return types, are left alone.
void EntryPointRewriter::checkReturnType()
{
    if (m_stage != AST::StageAttribute::Stage::Vertex)
        return;

    auto* maybeReturnType = m_function.maybeReturnType();
    if (!maybeReturnType)
        return;

    auto& returnType = getResolvedType(*maybeReturnType);
    if (!is<AST::StructTypeName>(returnType))
        return;

    // FIXME: we might have to duplicate this struct if it has other uses
    auto& structDecl = downcast<AST::StructTypeName>(returnType).structure();
    // Several vertex entry points may return the same struct. The first one already set the
    // role, so VertexOutput is accepted here as well and re-setting it is a no-op.
    ASSERT(structDecl.role() == AST::StructureRole::UserDefined || structDecl.role() == AST::StructureRole::VertexOutput);
    structDecl.setRole(AST::StructureRole::VertexOutput);
}
void EntryPointRewriter::constructInputStruct()
{
// insert `var ${parameter.name()} = ${structName}.${parameter.name()}`
AST::StructureMember::List structMembers;
for (auto& parameter : m_parameters) {
structMembers.append(makeUniqueRef<AST::StructureMember>(
SourceSpan::empty(),
WTFMove(parameter.m_name),
WTFMove(parameter.m_type),
WTFMove(parameter.m_attributes)
));
}
AST::StructureRole role;
switch (m_stage) {
case AST::StageAttribute::Stage::Compute:
role = AST::StructureRole::ComputeInput;
break;
case AST::StageAttribute::Stage::Vertex:
role = AST::StructureRole::VertexInput;
break;
case AST::StageAttribute::Stage::Fragment:
role = AST::StructureRole::FragmentInput;
break;
}
m_shaderModule.structures().append(makeUniqueRef<AST::Structure>(
SourceSpan::empty(),
AST::Identifier::make(m_structTypeName),
WTFMove(structMembers),
AST::Attribute::List { },
role
));
}
// Records a statement that restores `data` from its post-rewrite location:
//  - empty path: `var ${data.m_name} : ${data.m_type} = <source>`;
//  - non-empty path: `${path[0]}.${path[1]}...${data.m_name} = <source>`.
// <source> is the hoisted builtin parameter `${data.m_name}` when isBuiltin is Yes, and
// `${m_structParameterName}.${data.m_name}` otherwise.
void EntryPointRewriter::materialize(Vector<String>& path, MemberOrParameter& data, IsBuiltin isBuiltin)
{
    std::unique_ptr<AST::Expression> source;
    if (isBuiltin == IsBuiltin::No) {
        source = makeUnique<AST::FieldAccessExpression>(
            SourceSpan::empty(),
            makeUniqueRef<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make(m_structParameterName)),
            AST::Identifier::make(data.m_name));
    } else {
        source = makeUnique<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make(data.m_name));
    }

    if (path.isEmpty()) {
        auto variable = makeUniqueRef<AST::Variable>(
            SourceSpan::empty(),
            AST::VariableFlavor::Var,
            AST::Identifier::make(data.m_name),
            nullptr, // TODO: do we need a VariableQualifier?
            data.m_type.copyRef(),
            WTFMove(source),
            AST::Attribute::List { });
        m_materializations.append(makeUniqueRef<AST::VariableStatement>(SourceSpan::empty(), WTFMove(variable)));
        return;
    }

    // Build the field-access chain rooted at the enclosing local, ending at this member.
    UniqueRef<AST::Expression> target = makeUniqueRef<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make(path[0]));
    for (size_t index = 1; index < path.size(); ++index)
        target = makeUniqueRef<AST::FieldAccessExpression>(SourceSpan::empty(), WTFMove(target), AST::Identifier::make(path[index]));
    target = makeUniqueRef<AST::FieldAccessExpression>(SourceSpan::empty(), WTFMove(target), AST::Identifier::make(data.m_name));

    m_materializations.append(makeUniqueRef<AST::AssignmentStatement>(
        SourceSpan::empty(),
        WTFMove(target),
        makeUniqueRefFromNonNullUniquePtr(WTFMove(source))));
}
// Classifies one input (a parameter when `path` is empty, otherwise a struct member reached
// through `path`):
//  - struct-typed inputs are flattened by visiting each member. A top-level struct parameter
//    is first re-declared as a local `var ${data.m_name} : ${type}`;
//  - inputs with a BuiltinAttribute go to m_builtins. A builtin found inside a struct also gets
//    `${path}.${data.m_name} = ${data.m_name}` to rebuild the struct;
//  - every other input goes to m_parameters (the synthesized input struct) and gets a
//    materialization that reloads it from that struct.
void EntryPointRewriter::visit(Vector<String>& path, MemberOrParameter&& data)
{
    auto& type = getResolvedType(data.m_type);
    if (is<AST::StructTypeName>(type)) {
        // Only a top-level struct parameter needs a local of its own. A nested struct member is
        // written through its parent (`outer.inner.x = ...`). Declaring a separate local for it
        // would leave an unused variable that may clash with another name in the function.
        if (path.isEmpty()) {
            m_materializations.append(makeUniqueRef<AST::VariableStatement>(
                SourceSpan::empty(),
                makeUniqueRef<AST::Variable>(
                    SourceSpan::empty(),
                    AST::VariableFlavor::Var,
                    AST::Identifier::make(data.m_name),
                    nullptr,
                    &type,
                    nullptr,
                    AST::Attribute::List { }
                )
            ));
        }
        path.append(data.m_name);
        for (auto& member : downcast<AST::StructTypeName>(type).structure().members())
            visit(path, MemberOrParameter { member.name(), member.type(), member.attributes() });
        path.removeLast();
        return;
    }

    bool isBuiltin = false;
    for (auto& attribute : data.m_attributes) {
        if (is<AST::BuiltinAttribute>(attribute)) {
            isBuiltin = true;
            break;
        }
    }

    if (isBuiltin) {
        // If path is empty, the builtin was already a parameter and there is nothing to restore.
        // Otherwise it is hoisted out of a struct into a parameter, and the struct field is
        // rebuilt: ${path}.${data.m_name} = ${data.m_name}
        if (!path.isEmpty())
            materialize(path, data, IsBuiltin::Yes);
        m_builtins.append(WTFMove(data));
        return;
    }

    // The input moves into the synthesized struct, so reload it from there:
    // ${path}.${data.m_name} = ${struct}.${data.m_name} (or a `var` declaration at top level).
    materialize(path, data, IsBuiltin::No);
    m_parameters.append(WTFMove(data));
}
void EntryPointRewriter::appendBuiltins()
{
for (auto& data : m_builtins) {
m_function.parameters().append(makeUniqueRef<AST::Parameter>(
SourceSpan::empty(),
AST::Identifier::make(data.m_name),
WTFMove(data.m_type),
WTFMove(data.m_attributes),
AST::ParameterRole::UserDefined
));
}
}
// Runs EntryPointRewriter on every entry point in the call graph. Each entry point's reflection
// record is stored in result.entryPoints, keyed by the function name; names are expected to be
// unique (asserted).
void rewriteEntryPoints(CallGraph& callGraph, PrepareResult& result)
{
    for (auto& entryPoint : callGraph.entrypoints()) {
        auto& function = entryPoint.m_function;
        EntryPointRewriter rewriter(callGraph.ast(), function, entryPoint.m_stage);
        rewriter.rewrite();

        auto information = rewriter.takeEntryPointInformation();
        auto addResult = result.entryPoints.add(function.name().id(), WTFMove(information));
        ASSERT_UNUSED(addResult, addResult.isNewEntry);
    }
}
} // namespace WGSL