blob: 779e268d10b4d3f3d2c2a96f8d27a923b28e17be [file] [log] [blame] [edit]
/*
* Copyright (c) 2023 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "config.h"
#include "ConstantRewriter.h"
#include "AST.h"
#include "ASTStringDumper.h"
#include "ASTVisitor.h"
#include "ConstantFunctions.h"
#include "ContextProviderInlines.h"
#include "Overload.h"
#include "Types.h"
#include "WGSLShaderModule.h"
#include <wtf/DataLog.h>
#include <wtf/FixedVector.h>
namespace WGSL {
constexpr bool shouldDumpConstantValues = false;
class ConstantRewriter : public AST::Visitor, public ContextProvider<std::optional<ConstantValue>> {
public:
ConstantRewriter(ShaderModule&);
std::optional<FailedCheck> rewrite();
// Declarations
void visit(AST::Structure&) override;
void visit(AST::Variable&) override;
void visit(AST::Function&) override;
// Statements
void visit(AST::CompoundStatement&) override;
// Expressions
void visit(AST::IdentifierExpression&) override;
void visit(AST::CallExpression&) override;
// Literals
void visit(AST::AbstractFloatLiteral&) override;
void visit(AST::AbstractIntegerLiteral&) override;
void visit(AST::BoolLiteral&) override;
void visit(AST::Float32Literal&) override;
void visit(AST::Signed32Literal&) override;
void visit(AST::Unsigned32Literal&) override;
private:
template<typename... Arguments>
void error(const SourceSpan&, Arguments&&...);
std::optional<ConstantValue> evaluate(AST::Expression&);
template<typename Node>
void evaluated(Node&, ConstantValue);
template<typename Node>
void materialize(Node&, const ConstantValue&);
using ConstantFunction = ConstantValue(*)(const FixedVector<ConstantValue>&);
ShaderModule& m_shaderModule;
Vector<Error> m_errors;
HashMap<String, ConstantFunction> m_constantFunctions;
std::optional<ConstantValue> m_evaluationResult;
};
ConstantRewriter::ConstantRewriter(ShaderModule& shaderModule)
: m_shaderModule(shaderModule)
{
m_constantFunctions.add("pow"_s, constantPow);
}
std::optional<FailedCheck> ConstantRewriter::rewrite()
{
for (auto& structure : m_shaderModule.structures())
visit(structure);
for (auto& variable : m_shaderModule.variables())
visit(variable);
for (auto& function : m_shaderModule.functions())
visit(function);
if (m_errors.isEmpty())
return std::nullopt;
Vector<Warning> warnings { };
return FailedCheck { WTFMove(m_errors), WTFMove(warnings) };
}
// Declarations
void ConstantRewriter::visit(AST::Structure& structure)
{
// FIXME: implement
UNUSED_PARAM(structure);
}
void ConstantRewriter::visit(AST::Variable& variable)
{
if (variable.flavor() != AST::VariableFlavor::Const) {
introduceVariable(variable.name(), std::nullopt);
AST::Visitor::visit(variable);
return;
}
ASSERT(variable.maybeInitializer());
auto value = evaluate(*variable.maybeInitializer());
ASSERT(value.has_value());
introduceVariable(variable.name(), *value);
}
void ConstantRewriter::visit(AST::Function& function)
{
ContextScope functionScope(this);
for (auto& parameter : function.parameters())
introduceVariable(parameter.name(), std::nullopt);
AST::Visitor::visit(function);
}
// Statements
void ConstantRewriter::visit(AST::CompoundStatement& statement)
{
ContextScope blockScope(this);
AST::Visitor::visit(statement);
}
// Expressions
void ConstantRewriter::visit(AST::IdentifierExpression& identifier)
{
auto* constant = readVariable(identifier.identifier());
if (!constant || !constant->has_value())
return;
if (shouldDumpConstantValues)
dataLogLn("> Replacing identifier '", identifier.identifier(), "' with constant: ", *constant);
materialize(identifier, constant->value());
}
// Private helpers
std::optional<ConstantValue> ConstantRewriter::evaluate(AST::Expression& expression)
{
AST::Visitor::visit(expression);
return m_evaluationResult;
}
void ConstantRewriter::visit(AST::CallExpression& call)
{
unsigned argumentCount = call.arguments().size();
FixedVector<ConstantValue> arguments(argumentCount);
for (unsigned i = 0; i < argumentCount; ++i) {
auto value = evaluate(call.arguments()[i]);
if (!value.has_value())
return;
arguments[i] = *value;
}
auto& target = call.target();
bool isNamedType = is<AST::NamedTypeName>(target);
bool isParameterizedType = !isNamedType && is<AST::ParameterizedTypeName>(target);
if (isNamedType || isParameterizedType) {
std::optional<AST::ParameterizedTypeName::Base> base;
if (isParameterizedType)
base = downcast<AST::ParameterizedTypeName>(target).base();
else {
auto& targetName = downcast<AST::NamedTypeName>(target).name().id();
base = AST::ParameterizedTypeName::stringViewToKind(targetName);
}
if (base) {
switch (*base) {
case AST::ParameterizedTypeName::Base::Vec2:
case AST::ParameterizedTypeName::Base::Vec3:
case AST::ParameterizedTypeName::Base::Vec4: {
evaluated(call, { call.inferredType(), ConstantVector(WTFMove(arguments)) });
return;
}
case AST::ParameterizedTypeName::Base::Mat2x2:
case AST::ParameterizedTypeName::Base::Mat2x3:
case AST::ParameterizedTypeName::Base::Mat2x4:
case AST::ParameterizedTypeName::Base::Mat3x2:
case AST::ParameterizedTypeName::Base::Mat3x3:
case AST::ParameterizedTypeName::Base::Mat3x4:
case AST::ParameterizedTypeName::Base::Mat4x2:
case AST::ParameterizedTypeName::Base::Mat4x3:
case AST::ParameterizedTypeName::Base::Mat4x4:
case AST::ParameterizedTypeName::Base::Texture1d:
case AST::ParameterizedTypeName::Base::Texture2d:
case AST::ParameterizedTypeName::Base::Texture2dArray:
case AST::ParameterizedTypeName::Base::Texture3d:
case AST::ParameterizedTypeName::Base::TextureCube:
case AST::ParameterizedTypeName::Base::TextureCubeArray:
case AST::ParameterizedTypeName::Base::TextureMultisampled2d:
case AST::ParameterizedTypeName::Base::TextureStorage1d:
case AST::ParameterizedTypeName::Base::TextureStorage2d:
case AST::ParameterizedTypeName::Base::TextureStorage2dArray:
case AST::ParameterizedTypeName::Base::TextureStorage3d:
return;
}
}
ASSERT(isNamedType);
auto it = m_constantFunctions.find(downcast<AST::NamedTypeName>(target).name().id());
if (it != m_constantFunctions.end()) {
materialize(call, it->value(arguments));
return;
}
}
if (is<AST::ArrayTypeName>(target)) {
evaluated(call, { call.inferredType(), ConstantArray(WTFMove(arguments)) });
return;
}
}
// Literals
void ConstantRewriter::visit(AST::BoolLiteral& literal)
{
evaluated(literal, { literal.inferredType(), literal.value() });
}
void ConstantRewriter::visit(AST::Signed32Literal& literal)
{
evaluated(literal, { literal.inferredType(), literal.value() });
}
void ConstantRewriter::visit(AST::Float32Literal& literal)
{
evaluated(literal, { literal.inferredType(), literal.value() });
}
void ConstantRewriter::visit(AST::Unsigned32Literal& literal)
{
evaluated(literal, { literal.inferredType(), literal.value() });
}
void ConstantRewriter::visit(AST::AbstractIntegerLiteral& literal)
{
evaluated(literal, { literal.inferredType(), literal.value() });
}
void ConstantRewriter::visit(AST::AbstractFloatLiteral& literal)
{
evaluated(literal, { literal.inferredType(), literal.value() });
}
template<typename Node>
void ConstantRewriter::materialize(Node& expression, const ConstantValue& value)
{
using namespace Types;
evaluated(expression, value);
const auto& replace = [&]<typename Literal, typename Value>() {
auto& node = m_shaderModule.astBuilder().construct<Literal>(SourceSpan::empty(), std::get<Value>(value));
node.m_inferredType = expression.inferredType();
m_shaderModule.replace(expression, node);
};
return WTF::switchOn(*value.type,
[&](const Primitive& primitive) {
switch (primitive.kind) {
case Primitive::AbstractInt:
replace.template operator()<AST::AbstractIntegerLiteral, int64_t>();
break;
case Primitive::I32:
replace.template operator()<AST::Signed32Literal, int64_t>();
break;
case Primitive::U32:
replace.template operator()<AST::Unsigned32Literal, int64_t>();
break;
case Primitive::AbstractFloat:
replace.template operator()<AST::AbstractFloatLiteral, double>();
break;
case Primitive::F32:
replace.template operator()<AST::Float32Literal, double>();
break;
case Primitive::Bool:
replace.template operator()<AST::BoolLiteral, bool>();
break;
case Primitive::Void:
case Primitive::Sampler:
case Primitive::TextureExternal:
RELEASE_ASSERT_NOT_REACHED();
}
},
[&](const Vector&) {
// Do not materialize complex types
},
[&](const Matrix&) {
// Do not materialize complex types
},
[&](const Array&) {
// Do not materialize complex types
},
[&](const Struct&) {
// Do not materialize complex types
},
[&](const Reference&) {
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Function&) {
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Texture&) {
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Bottom&) {
RELEASE_ASSERT_NOT_REACHED();
});
}
template<typename Node>
void ConstantRewriter::evaluated(Node& expression, ConstantValue value)
{
m_evaluationResult = value;
if (shouldDumpConstantValues) {
dataLog("> Evaluated expression: ");
dumpNode(WTF::dataFile(), expression);
dataLog(" = ");
dataLogLn(value);
}
}
template<typename... Arguments>
void ConstantRewriter::error(const SourceSpan& span, Arguments&&... arguments)
{
m_errors.append({ makeString(std::forward<Arguments>(arguments)...), span });
}
std::optional<FailedCheck> rewriteConstants(ShaderModule& shaderModule)
{
return ConstantRewriter(shaderModule).rewrite();
}
} // namespace WGSL