blob: 10106a75658a4ea094b356e3483323d08cc09c83 [file] [log] [blame] [edit]
/*
* Copyright 2022 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef wasm_ir_function_h
#define wasm_ir_function_h
#include <functional>
#include <variant>
#include <vector>

#include "ir/debuginfo.h"
#include "ir/type-updating.h"
#include "wasm.h"
namespace wasm::CallUtils {
// What we know about the target of an indirect call. An IndirectCallInfo holds
// exactly one of:
//  * Unknown: nothing is known about this call.
//  * Trap:    the call target is invalid, so the call will trap at runtime.
//  * Known:   the call goes to a known static call target, given in |target|.
// Unknown is the first alternative, so a default-constructed IndirectCallInfo
// holds Unknown.
struct Unknown : std::monostate {};
struct Trap : std::monostate {};
struct Known {
  // The function this call is known to reach.
  Name target;
};
using IndirectCallInfo = std::variant<Unknown, Trap, Known>;
// Converts an indirect call whose target is a select between two values into
// an if over direct calls. For example, consider this input:
//
//  (call_ref
//    (select
//      (ref.func A)
//      (ref.func B)
//      (..condition..)
//    )
//  )
//
// We'll check if the input falls into such a pattern, and if so, return the new
// form:
//
//  (if
//    (..condition..)
//    (call $A)
//    (call $B)
//  )
//
// If we fail to find the expected pattern, or we decide it is not worth
// optimizing it for some reason, we return nullptr. In that case nothing has
// been modified: every bail-out happens before any local is added.
//
// If this returns the new form, it will modify the function as necessary,
// adding new locals etc., which later passes should optimize. The returned
// expression takes over |curr|'s operands and the select's condition, so it is
// meant to replace |curr|.
//
// |getCallInfo| is given one of the arms of the select and should return an
// IndirectCallInfo that says what we know about it. We may know nothing, or
// that it will trap, or that it will go to a known static target. A trapping
// arm becomes an (unreachable); a known arm becomes a direct call.
//
// NOTE: The select's arms are NOT kept in the output (only its condition is),
// so |getCallInfo| must return Unknown for any arm whose evaluation has effects
// that need to be preserved.
template<typename T>
inline Expression*
convertToDirectCalls(T* curr,
                     std::function<IndirectCallInfo(Expression*)> getCallInfo,
                     Function& func,
                     Module& wasm) {
  auto* select = curr->target->template dynCast<Select>();
  if (!select) {
    return nullptr;
  }

  if (select->type == Type::unreachable) {
    // Leave this for DCE.
    return nullptr;
  }

  // Check if we can find useful info for both arms: either known call targets,
  // or traps.
  // TODO: support more than 2 targets (with nested selects)
  auto ifTrueCallInfo = getCallInfo(select->ifTrue);
  auto ifFalseCallInfo = getCallInfo(select->ifFalse);
  if (std::holds_alternative<Unknown>(ifTrueCallInfo) ||
      std::holds_alternative<Unknown>(ifFalseCallInfo)) {
    // We know nothing about at least one arm, so give up.
    // TODO: Perhaps emitting a direct call for one arm is enough even if the
    //       other remains indirect?
    return nullptr;
  }

  auto& operands = curr->operands;

  // The operand values are needed by both calls, and the condition must now be
  // evaluated before either call (as the if's condition). So we first store
  // every operand in a local, which keeps the operands executing before the
  // condition, and read the locals back in each arm. Before changing anything,
  // check that this is possible: if any operand is unreachable, leave this for
  // DCE; if any operand's type cannot be stored in a local, give up.
  for (auto* operand : operands) {
    if (operand->type == Type::unreachable ||
        !TypeUpdating::canHandleAsLocal(operand->type)) {
      return nullptr;
    }
  }

  Builder builder(wasm);
  std::vector<Expression*> blockContents;
  // One local.set per operand, plus the final if.
  blockContents.reserve(operands.size() + 1);

  // None of the types are a problem, so from here on we commit: add new vars as
  // needed and perform this optimization.
  std::vector<Index> operandLocals;
  operandLocals.reserve(operands.size());
  for (auto* operand : operands) {
    auto currLocal = builder.addVar(&func, operand->type);
    operandLocals.push_back(currLocal);
    blockContents.push_back(builder.makeLocalSet(currLocal, operand));
  }

  // Build the calls. Each call gets its own freshly created local.gets rather
  // than sharing expression nodes with the other arm.
  auto numOperands = operands.size();
  auto getOperands = [&]() {
    std::vector<Expression*> newOperands(numOperands);
    for (Index i = 0; i < numOperands; i++) {
      newOperands[i] =
        builder.makeLocalGet(operandLocals[i], operands[i]->type);
    }
    return newOperands;
  };

  // A trapping arm becomes an unreachable; a known arm becomes a direct call
  // with |curr|'s type and return-call flag. Debug info is copied from |curr|.
  auto makeCall = [&](const IndirectCallInfo& info) -> Expression* {
    Expression* ret;
    if (std::holds_alternative<Trap>(info)) {
      ret = builder.makeUnreachable();
    } else {
      ret = builder.makeCall(std::get<Known>(info).target,
                             getOperands(),
                             curr->type,
                             curr->isReturn);
    }
    debuginfo::copyOriginalToReplacement(curr, ret, &func);
    return ret;
  };
  auto* ifTrueCall = makeCall(ifTrueCallInfo);
  auto* ifFalseCall = makeCall(ifFalseCallInfo);

  // Create the if to pick the calls, and emit the final block.
  auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall);
  blockContents.push_back(iff);
  return builder.makeBlock(blockContents);
}
} // namespace wasm::CallUtils
#endif // wasm_ir_function_h