| /* |
| * Copyright 2022 WebAssembly Community Group participants |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "passes/param-utils.h" |
| #include "cfg/liveness-traversal.h" |
| #include "ir/eh-utils.h" |
| #include "ir/function-utils.h" |
| #include "ir/localize.h" |
| #include "ir/possible-constant.h" |
| #include "ir/type-updating.h" |
| #include "pass.h" |
| #include "support/sorted_vector.h" |
| #include "wasm-traversal.h" |
| #include "wasm.h" |
| |
| namespace wasm::ParamUtils { |
| |
| std::unordered_set<Index> getUsedParams(Function* func, Module* module) { |
| // To find which params are used, compute liveness at the entry. |
| // TODO: We could write bespoke code here rather than reuse LivenessWalker, as |
| // we only need liveness at the entry. The code below computes it for |
| // the param indexes in the entire function. However, there are usually |
| // very few params (compared to locals, which we ignore here), so this |
| // may be fast enough, and is very simple. |
| struct ParamLiveness |
| : public LivenessWalker<ParamLiveness, Visitor<ParamLiveness>> { |
| using Super = LivenessWalker<ParamLiveness, Visitor<ParamLiveness>>; |
| |
| // Branches outside of the function can be ignored, as we only look at |
| // locals, which vanish when we leave. |
| bool ignoreBranchesOutsideOfFunc = true; |
| |
| // Ignore unreachable code and non-params. |
| static void doVisitLocalGet(ParamLiveness* self, Expression** currp) { |
| auto* get = (*currp)->cast<LocalGet>(); |
| if (self->currBasicBlock && self->getFunction()->isParam(get->index)) { |
| Super::doVisitLocalGet(self, currp); |
| } |
| } |
| static void doVisitLocalSet(ParamLiveness* self, Expression** currp) { |
| auto* set = (*currp)->cast<LocalSet>(); |
| if (self->currBasicBlock && self->getFunction()->isParam(set->index)) { |
| Super::doVisitLocalSet(self, currp); |
| } |
| } |
| } walker; |
| walker.setModule(module); |
| walker.walkFunction(func); |
| |
| if (!walker.entry) { |
| // Empty function: nothing is used. |
| return {}; |
| } |
| |
| // We now have a sorted vector of the live params at the entry. Convert that |
| // to a set. |
| auto& sortedLiveness = walker.entry->contents.start; |
| std::unordered_set<Index> usedParams; |
| for (auto live : sortedLiveness) { |
| usedParams.insert(live); |
| } |
| return usedParams; |
| } |
| |
| RemovalOutcome removeParameter(const std::vector<Function*>& funcs, |
| Index index, |
| const std::vector<Call*>& calls, |
| const std::vector<CallRef*>& callRefs, |
| Module* module, |
| PassRunner* runner) { |
| assert(funcs.size() > 0); |
| auto* first = funcs[0]; |
| #ifndef NDEBUG |
| for (auto* func : funcs) { |
| assert(func->type == first->type); |
| } |
| #endif |
| |
| // Check if none of the calls has a param with side effects that we cannot |
| // remove (as if we can remove them, we will simply do that when we remove the |
| // parameter). Note: flattening the IR beforehand can help here. |
| // |
| // It would also be bad if we remove a parameter that causes the type of the |
| // call to change, like this: |
| // |
| // (call $foo |
| // (unreachable)) |
| // |
| // After removing the parameter the type should change from unreachable to |
| // something concrete. We could handle this by updating the type and then |
| // propagating that out, or by appending an unreachable after the call, but |
| // for simplicity just ignore such cases; if we are called again later then |
| // if DCE ran meanwhile then we could optimize. |
| auto checkEffects = [&](auto* call) { |
| auto* operand = call->operands[index]; |
| |
| if (operand->type == Type::unreachable) { |
| return Failure; |
| } |
| |
| bool hasUnremovable = EffectAnalyzer(runner->options, *module, operand) |
| .hasUnremovableSideEffects(); |
| |
| return hasUnremovable ? Failure : Success; |
| }; |
| |
| for (auto* call : calls) { |
| auto result = checkEffects(call); |
| if (result != Success) { |
| return result; |
| } |
| } |
| |
| for (auto* call : callRefs) { |
| auto result = checkEffects(call); |
| if (result != Success) { |
| return result; |
| } |
| } |
| |
| // The type must be valid for us to handle as a local (since we |
| // replace the parameter with a local). |
| // TODO: if there are no references at all, we can avoid creating a |
| // local |
| bool typeIsValid = TypeUpdating::canHandleAsLocal(first->getLocalType(index)); |
| if (!typeIsValid) { |
| return Failure; |
| } |
| |
| // We can do it! |
| |
| // Remove the parameter from the function. We must add a new local |
| // for uses of the parameter, but cannot make it use the same index |
| // (in general). |
| auto paramsType = first->getParams(); |
| std::vector<Type> params(paramsType.begin(), paramsType.end()); |
| auto type = params[index]; |
| params.erase(params.begin() + index); |
| // TODO: parallelize some of these loops? |
| for (auto* func : funcs) { |
| func->setParams(Type(params)); |
| |
| // It's cumbersome to adjust local names - TODO don't clear them? |
| Builder::clearLocalNames(func); |
| } |
| std::vector<Index> newIndexes; |
| for (auto* func : funcs) { |
| newIndexes.push_back(Builder::addVar(func, type)); |
| } |
| // Update local operations. |
| struct LocalUpdater : public PostWalker<LocalUpdater> { |
| Index removedIndex; |
| Index newIndex; |
| LocalUpdater(Function* func, Index removedIndex, Index newIndex) |
| : removedIndex(removedIndex), newIndex(newIndex) { |
| walk(func->body); |
| } |
| void visitLocalGet(LocalGet* curr) { updateIndex(curr->index); } |
| void visitLocalSet(LocalSet* curr) { updateIndex(curr->index); } |
| void updateIndex(Index& index) { |
| if (index == removedIndex) { |
| index = newIndex; |
| } else if (index > removedIndex) { |
| index--; |
| } |
| } |
| }; |
| for (Index i = 0; i < funcs.size(); i++) { |
| auto* func = funcs[i]; |
| if (!func->imported()) { |
| LocalUpdater(funcs[i], index, newIndexes[i]); |
| TypeUpdating::handleNonDefaultableLocals(func, *module); |
| } |
| } |
| |
| // Remove the arguments from the calls. |
| for (auto* call : calls) { |
| call->operands.erase(call->operands.begin() + index); |
| } |
| for (auto* call : callRefs) { |
| call->operands.erase(call->operands.begin() + index); |
| } |
| |
| return Success; |
| } |
| |
| std::pair<SortedVector, RemovalOutcome> |
| removeParameters(const std::vector<Function*>& funcs, |
| SortedVector indexes, |
| const std::vector<Call*>& calls, |
| const std::vector<CallRef*>& callRefs, |
| Module* module, |
| PassRunner* runner) { |
| if (indexes.empty()) { |
| return {{}, Success}; |
| } |
| |
| assert(funcs.size() > 0); |
| auto* first = funcs[0]; |
| #ifndef NDEBUG |
| for (auto* func : funcs) { |
| assert(func->type == first->type); |
| } |
| #endif |
| |
| // Iterate downwards, as we may remove more than one, and going forwards would |
| // alter the indexes after us. |
| Index i = first->getNumParams() - 1; |
| SortedVector removed; |
| while (1) { |
| if (indexes.has(i)) { |
| auto outcome = removeParameter(funcs, i, calls, callRefs, module, runner); |
| if (outcome == Success) { |
| removed.insert(i); |
| } |
| } |
| if (i == 0) { |
| break; |
| } |
| i--; |
| } |
| RemovalOutcome finalOutcome = Success; |
| if (removed.size() < indexes.size()) { |
| finalOutcome = Failure; |
| } |
| return {removed, finalOutcome}; |
| } |
| |
| SortedVector applyConstantValues(const std::vector<Function*>& funcs, |
| const std::vector<Call*>& calls, |
| const std::vector<CallRef*>& callRefs, |
| Module* module) { |
| assert(funcs.size() > 0); |
| auto* first = funcs[0]; |
| #ifndef NDEBUG |
| for (auto* func : funcs) { |
| assert(func->type == first->type); |
| } |
| #endif |
| |
| SortedVector optimized; |
| auto numParams = first->getNumParams(); |
| for (Index i = 0; i < numParams; i++) { |
| PossibleConstantValues value; |
| for (auto* call : calls) { |
| value.note(call->operands[i], *module); |
| if (!value.isConstant()) { |
| break; |
| } |
| } |
| for (auto* call : callRefs) { |
| value.note(call->operands[i], *module); |
| if (!value.isConstant()) { |
| break; |
| } |
| } |
| if (!value.isConstant()) { |
| continue; |
| } |
| |
| // Optimize: write the constant value in the function bodies, making them |
| // ignore the parameter's value. |
| Builder builder(*module); |
| for (auto* func : funcs) { |
| func->body = builder.makeSequence( |
| builder.makeLocalSet(i, value.makeExpression(*module)), func->body); |
| } |
| optimized.insert(i); |
| } |
| |
| return optimized; |
| } |
| |
| void localizeCallsTo(const std::unordered_set<Name>& callTargets, |
| Module& wasm, |
| PassRunner* runner, |
| std::function<void(Function*)> onChange) { |
| struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> { |
| bool isFunctionParallel() override { return true; } |
| |
| std::unique_ptr<Pass> create() override { |
| return std::make_unique<LocalizerPass>(callTargets, onChange); |
| } |
| |
| const std::unordered_set<Name>& callTargets; |
| std::function<void(Function*)> onChange; |
| |
| LocalizerPass(const std::unordered_set<Name>& callTargets, |
| std::function<void(Function*)> onChange) |
| : callTargets(callTargets), onChange(onChange) {} |
| |
| void visitCall(Call* curr) { |
| if (!callTargets.count(curr->target)) { |
| return; |
| } |
| |
| ChildLocalizer localizer( |
| curr, getFunction(), *getModule(), getPassOptions()); |
| auto* replacement = localizer.getReplacement(); |
| if (replacement != curr) { |
| replaceCurrent(replacement); |
| optimized = true; |
| onChange(getFunction()); |
| } |
| } |
| |
| bool optimized = false; |
| |
| void visitFunction(Function* curr) { |
| if (optimized) { |
| // Localization can add blocks, which might move pops. |
| EHUtils::handleBlockNestedPops(curr, *getModule()); |
| } |
| } |
| }; |
| |
| LocalizerPass(callTargets, onChange).run(runner, &wasm); |
| } |
| |
| void localizeCallsTo(const std::unordered_set<HeapType>& callTargets, |
| Module& wasm, |
| PassRunner* runner) { |
| struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> { |
| bool isFunctionParallel() override { return true; } |
| |
| std::unique_ptr<Pass> create() override { |
| return std::make_unique<LocalizerPass>(callTargets); |
| } |
| |
| const std::unordered_set<HeapType>& callTargets; |
| |
| LocalizerPass(const std::unordered_set<HeapType>& callTargets) |
| : callTargets(callTargets) {} |
| |
| void visitCall(Call* curr) { |
| handleCall(curr, getModule()->getFunction(curr->target)->type); |
| } |
| |
| void visitCallRef(CallRef* curr) { |
| auto type = curr->target->type; |
| if (type.isRef()) { |
| handleCall(curr, type.getHeapType()); |
| } |
| } |
| |
| void handleCall(Expression* call, HeapType type) { |
| if (!callTargets.count(type)) { |
| return; |
| } |
| |
| ChildLocalizer localizer( |
| call, getFunction(), *getModule(), getPassOptions()); |
| auto* replacement = localizer.getReplacement(); |
| if (replacement != call) { |
| replaceCurrent(replacement); |
| optimized = true; |
| } |
| } |
| |
| bool optimized = false; |
| |
| void visitFunction(Function* curr) { |
| if (optimized) { |
| EHUtils::handleBlockNestedPops(curr, *getModule()); |
| } |
| } |
| }; |
| |
| LocalizerPass(callTargets).run(runner, &wasm); |
| } |
| |
| } // namespace wasm::ParamUtils |