blob: a600e1928c7d7c24452b66a5dbbdb9e951cc625c [file] [log] [blame] [edit]
/*
* Copyright 2022 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "passes/param-utils.h"
#include "cfg/liveness-traversal.h"
#include "ir/eh-utils.h"
#include "ir/function-utils.h"
#include "ir/localize.h"
#include "ir/possible-constant.h"
#include "ir/type-updating.h"
#include "pass.h"
#include "support/sorted_vector.h"
#include "wasm-traversal.h"
#include "wasm.h"
namespace wasm::ParamUtils {
// Returns the indexes of the params of |func| that are live at the function
// entry. An empty set is returned when the walk produces no entry block.
std::unordered_set<Index> getUsedParams(Function* func, Module* module) {
  // A param is considered used when it is live at the entry, so compute
  // liveness there.
  // TODO: LivenessWalker is reused for simplicity, although it computes
  //       liveness of the param indexes across the entire function while we
  //       only need the entry. Bespoke code could do less work, but there are
  //       usually very few params (compared to locals, which are ignored
  //       here), so this may be fast enough.
  struct ParamLiveness
    : public LivenessWalker<ParamLiveness, Visitor<ParamLiveness>> {
    using Super = LivenessWalker<ParamLiveness, Visitor<ParamLiveness>>;

    // Locals vanish when we leave the function, so branches that go outside
    // of it can be ignored.
    bool ignoreBranchesOutsideOfFunc = true;

    // Only gets in reachable code that refer to a param are forwarded to the
    // liveness logic; everything else is skipped.
    static void doVisitLocalGet(ParamLiveness* self, Expression** currp) {
      auto* localGet = (*currp)->cast<LocalGet>();
      if (!self->currBasicBlock) {
        return;
      }
      if (!self->getFunction()->isParam(localGet->index)) {
        return;
      }
      Super::doVisitLocalGet(self, currp);
    }

    // Same filtering as for gets: unreachable code and non-params are ignored.
    static void doVisitLocalSet(ParamLiveness* self, Expression** currp) {
      auto* localSet = (*currp)->cast<LocalSet>();
      if (!self->currBasicBlock) {
        return;
      }
      if (!self->getFunction()->isParam(localSet->index)) {
        return;
      }
      Super::doVisitLocalSet(self, currp);
    }
  } liveness;

  liveness.setModule(module);
  liveness.walkFunction(func);

  if (!liveness.entry) {
    // Empty function: nothing is used.
    return {};
  }

  // The params live at the start of the entry block are stored as a sorted
  // vector; hand them back as a set.
  const auto& liveAtEntry = liveness.entry->contents.start;
  return std::unordered_set<Index>(liveAtEntry.begin(), liveAtEntry.end());
}
// Tries to remove the parameter at |index| from every function in |funcs|
// (which must all share one type) and the matching operand from every call in
// |calls| and |callRefs|.
//
// Returns Failure, without modifying anything, when some call's operand at
// |index| is unreachable or has side effects that cannot be removed, or when
// the parameter's type cannot be handled as a local. Otherwise performs the
// removal and returns Success.
RemovalOutcome removeParameter(const std::vector<Function*>& funcs,
                               Index index,
                               const std::vector<Call*>& calls,
                               const std::vector<CallRef*>& callRefs,
                               Module* module,
                               PassRunner* runner) {
  assert(!funcs.empty());
  auto* first = funcs[0];
#ifndef NDEBUG
  for (auto* func : funcs) {
    assert(func->type == first->type);
  }
#endif

  // Verify that no call passes an operand whose side effects we cannot remove
  // (removable ones simply disappear together with the operand). Note:
  // flattening the IR beforehand can help here.
  //
  // Removing an operand must also not change the type of the call, as it
  // would here:
  //
  //   (call $foo
  //     (unreachable))
  //
  // Without the operand the call's type would turn from unreachable into
  // something concrete. That could be handled by updating the type and
  // propagating it outwards, or by appending an unreachable after the call,
  // but for simplicity such cases are just skipped; if DCE runs before we are
  // called again then we could optimize at that point.
  auto checkOperand = [&](auto* call) -> RemovalOutcome {
    auto* operand = call->operands[index];
    if (operand->type == Type::unreachable) {
      return Failure;
    }
    EffectAnalyzer effects(runner->options, *module, operand);
    if (effects.hasUnremovableSideEffects()) {
      return Failure;
    }
    return Success;
  };

  // Reports the first non-Success outcome among the given calls, if any.
  auto checkAll = [&](const auto& callList) -> RemovalOutcome {
    for (auto* call : callList) {
      auto outcome = checkOperand(call);
      if (outcome != Success) {
        return outcome;
      }
    }
    return Success;
  };

  if (auto outcome = checkAll(calls); outcome != Success) {
    return outcome;
  }
  if (auto outcome = checkAll(callRefs); outcome != Success) {
    return outcome;
  }

  // The parameter is replaced by a local, so its type must be one we can
  // handle as a local.
  // TODO: if there are no references at all, we can avoid creating a local
  if (!TypeUpdating::canHandleAsLocal(first->getLocalType(index))) {
    return Failure;
  }

  // We can do it!

  // Compute the parameter list without the removed entry, remembering the
  // type that was removed for the replacement local.
  auto oldParams = first->getParams();
  std::vector<Type> newParams(oldParams.begin(), oldParams.end());
  auto removedType = newParams[index];
  newParams.erase(newParams.begin() + index);

  // TODO: parallelize some of these loops?
  const Type newParamsType(newParams);
  for (auto* func : funcs) {
    func->setParams(newParamsType);
    // It's cumbersome to adjust local names - TODO don't clear them?
    Builder::clearLocalNames(func);
  }

  // Uses of the removed parameter need a new local, which in general cannot
  // keep the same index. Record the new index for each function.
  std::vector<Index> replacementLocals;
  replacementLocals.reserve(funcs.size());
  for (auto* func : funcs) {
    replacementLocals.push_back(Builder::addVar(func, removedType));
  }

  // Rewrites local indexes in a function body: the removed parameter is
  // redirected to its replacement local, and every higher index shifts down
  // by one.
  struct LocalUpdater : public PostWalker<LocalUpdater> {
    Index removedIndex;
    Index newIndex;

    LocalUpdater(Function* func, Index removedIndex, Index newIndex)
      : removedIndex(removedIndex), newIndex(newIndex) {
      walk(func->body);
    }

    void visitLocalGet(LocalGet* curr) { remap(curr->index); }
    void visitLocalSet(LocalSet* curr) { remap(curr->index); }

    void remap(Index& localIndex) {
      if (localIndex == removedIndex) {
        localIndex = newIndex;
        return;
      }
      if (localIndex > removedIndex) {
        localIndex--;
      }
    }
  };

  for (Index i = 0; i < funcs.size(); i++) {
    auto* func = funcs[i];
    if (func->imported()) {
      // Imported functions are skipped here.
      continue;
    }
    LocalUpdater(func, index, replacementLocals[i]);
    TypeUpdating::handleNonDefaultableLocals(func, *module);
  }

  // Finally, drop the corresponding operand from every call.
  auto dropOperand = [&](auto* call) {
    call->operands.erase(call->operands.begin() + index);
  };
  for (auto* call : calls) {
    dropOperand(call);
  }
  for (auto* call : callRefs) {
    dropOperand(call);
  }

  return Success;
}
// Tries to remove each parameter listed in |indexes| from all of |funcs|
// (which must share one type), together with the matching operands of |calls|
// and |callRefs|, by calling removeParameter() once per index.
//
// Returns the set of indexes that were actually removed, plus Success if every
// requested index was removed and Failure otherwise. An empty |indexes| yields
// an empty set and Success without inspecting |funcs|.
std::pair<SortedVector, RemovalOutcome>
removeParameters(const std::vector<Function*>& funcs,
                 SortedVector indexes,
                 const std::vector<Call*>& calls,
                 const std::vector<CallRef*>& callRefs,
                 Module* module,
                 PassRunner* runner) {
  if (indexes.empty()) {
    return {{}, Success};
  }

  assert(funcs.size() > 0);
  auto* first = funcs[0];
#ifndef NDEBUG
  for (auto* func : funcs) {
    assert(func->type == first->type);
  }
#endif

  // Iterate downwards, as we may remove more than one, and going forwards would
  // alter the indexes after us.
  //
  // The counter starts at the number of params and is decremented as part of
  // the loop condition, so a function type with zero params runs no iterations
  // at all. Starting at "numParams - 1" would wrap around in that case and
  // walk indexes that do not exist.
  SortedVector removed;
  for (Index i = first->getNumParams(); i-- > 0;) {
    if (!indexes.has(i)) {
      continue;
    }
    auto outcome = removeParameter(funcs, i, calls, callRefs, module, runner);
    if (outcome == Success) {
      removed.insert(i);
    }
  }

  // Anything requested but not removed makes the overall outcome a failure.
  RemovalOutcome finalOutcome =
    removed.size() < indexes.size() ? Failure : Success;
  return {removed, finalOutcome};
}
// For each parameter of |funcs| (which must all share one type), checks
// whether every call in |calls| and |callRefs| passes the same constant value.
// When that is so, a local.set of that constant to the parameter is prepended
// to the body of every function, so the incoming value is no longer read.
//
// Returns the indexes of the parameters that were optimized this way.
SortedVector applyConstantValues(const std::vector<Function*>& funcs,
                                 const std::vector<Call*>& calls,
                                 const std::vector<CallRef*>& callRefs,
                                 Module* module) {
  assert(!funcs.empty());
  auto* first = funcs[0];
#ifndef NDEBUG
  for (auto* func : funcs) {
    assert(func->type == first->type);
  }
#endif

  SortedVector optimized;
  const auto numParams = first->getNumParams();

  for (Index param = 0; param < numParams; param++) {
    PossibleConstantValues value;

    // Notes the operand for this parameter from each call in the list, and
    // stops early once the value is known not to be constant.
    auto noteOperands = [&](const auto& callList) {
      for (auto* call : callList) {
        value.note(call->operands[param], *module);
        if (!value.isConstant()) {
          return;
        }
      }
    };
    noteOperands(calls);
    noteOperands(callRefs);

    if (value.isConstant()) {
      // Optimize: write the constant value in the function bodies, making
      // them ignore the parameter's value.
      Builder builder(*module);
      for (auto* func : funcs) {
        auto* setToConstant =
          builder.makeLocalSet(param, value.makeExpression(*module));
        func->body = builder.makeSequence(setToConstant, func->body);
      }
      optimized.insert(param);
    }
  }

  return optimized;
}
// Runs ChildLocalizer on every direct call whose target name is in
// |callTargets|, across the functions of |wasm|. Whenever a call is actually
// replaced, |onChange| is invoked with the function containing it.
// NOTE(review): the pass reports itself as function-parallel, so |onChange| is
// presumably expected to tolerate concurrent invocation - confirm with callers.
void localizeCallsTo(const std::unordered_set<Name>& callTargets,
                     Module& wasm,
                     PassRunner* runner,
                     std::function<void(Function*)> onChange) {
  struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
    const std::unordered_set<Name>& callTargets;
    std::function<void(Function*)> onChange;

    // Whether anything was replaced in the function currently being walked.
    bool optimized = false;

    LocalizerPass(const std::unordered_set<Name>& callTargets,
                  std::function<void(Function*)> onChange)
      : callTargets(callTargets), onChange(onChange) {}

    bool isFunctionParallel() override { return true; }

    std::unique_ptr<Pass> create() override {
      return std::make_unique<LocalizerPass>(callTargets, onChange);
    }

    void visitCall(Call* curr) {
      if (callTargets.count(curr->target) == 0) {
        // Not one of the calls we were asked to handle.
        return;
      }

      ChildLocalizer localizer(
        curr, getFunction(), *getModule(), getPassOptions());
      auto* replacement = localizer.getReplacement();
      if (replacement == curr) {
        // Nothing needed to be localized.
        return;
      }

      replaceCurrent(replacement);
      optimized = true;
      onChange(getFunction());
    }

    void visitFunction(Function* curr) {
      if (!optimized) {
        return;
      }
      // Localization can add blocks, which might move pops.
      EHUtils::handleBlockNestedPops(curr, *getModule());
    }
  };

  LocalizerPass(callTargets, onChange).run(runner, &wasm);
}
// Runs ChildLocalizer on every call in |wasm| whose callee type is in
// |callTargets|. Direct calls are matched by the type of the called function;
// call_refs are matched by the heap type of their target, when the target has
// a reference type.
void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
                     Module& wasm,
                     PassRunner* runner) {
  struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
    const std::unordered_set<HeapType>& callTargets;

    // Whether anything was replaced in the function currently being walked.
    bool optimized = false;

    LocalizerPass(const std::unordered_set<HeapType>& callTargets)
      : callTargets(callTargets) {}

    bool isFunctionParallel() override { return true; }

    std::unique_ptr<Pass> create() override {
      return std::make_unique<LocalizerPass>(callTargets);
    }

    void visitCall(Call* curr) {
      auto calleeType = getModule()->getFunction(curr->target)->type;
      handleCall(curr, calleeType);
    }

    void visitCallRef(CallRef* curr) {
      auto targetType = curr->target->type;
      if (!targetType.isRef()) {
        // Without a reference type there is no heap type to look up.
        return;
      }
      handleCall(curr, targetType.getHeapType());
    }

    // Shared logic for both kinds of call: localize the children if the
    // callee type is one we were asked to handle.
    void handleCall(Expression* call, HeapType type) {
      if (callTargets.count(type) == 0) {
        return;
      }

      ChildLocalizer localizer(
        call, getFunction(), *getModule(), getPassOptions());
      auto* replacement = localizer.getReplacement();
      if (replacement == call) {
        // Nothing needed to be localized.
        return;
      }

      replaceCurrent(replacement);
      optimized = true;
    }

    void visitFunction(Function* curr) {
      if (!optimized) {
        return;
      }
      // Fix up pops after the changes made in this function.
      EHUtils::handleBlockNestedPops(curr, *getModule());
    }
  };

  LocalizerPass(callTargets).run(runner, &wasm);
}
} // namespace wasm::ParamUtils