| /* |
| * Copyright 2022 WebAssembly Community Group participants |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #ifndef wasm_ir_function_h |
| #define wasm_ir_function_h |
| |
| #include <variant> |
| |
| #include "ir/debuginfo.h" |
| #include "ir/type-updating.h" |
| #include "wasm.h" |
| |
| namespace wasm::CallUtils { |
| |
| // Define a variant to describe the information we know about an indirect call, |
| // which is one of three things: |
| // * Unknown: Nothing is known this call. |
| // * Trap: This call target is invalid and will trap at runtime. |
| // * Known: This call goes to a known static call target, which is provided. |
| struct Unknown : public std::monostate {}; |
| struct Trap : public std::monostate {}; |
| struct Known { |
| Name target; |
| }; |
| using IndirectCallInfo = std::variant<Unknown, Trap, Known>; |
| |
| // Converts indirect calls that target selects between values into ifs over |
| // direct calls. For example, consider this input: |
| // |
| // (call_ref |
| // (select |
| // (ref.func A) |
| // (ref.func B) |
| // (..condition..) |
| // ) |
| // ) |
| // |
| // We'll check if the input falls into such a pattern, and if so, return the new |
| // form: |
| // |
| // (if |
| // (..condition..) |
| // (call $A) |
| // (call $B) |
| // ) |
| // |
| // If we fail to find the expected pattern, or we decide it is not worth |
| // optimizing it for some reason, we return nullptr. |
| // |
| // If this returns the new form, it will modify the function as necessary, |
| // adding new locals etc., which later passes should optimize. |
| // |
| // |getCallInfo| is given one of the arms of the select and should return an |
| // IndirectCallInfo that says what we know about it. We may know nothing, or |
| // that it will trap, or that it will go to a known static target. |
| template<typename T> |
| inline Expression* |
| convertToDirectCalls(T* curr, |
| std::function<IndirectCallInfo(Expression*)> getCallInfo, |
| Function& func, |
| Module& wasm) { |
| auto* select = curr->target->template dynCast<Select>(); |
| if (!select) { |
| return nullptr; |
| } |
| |
| if (select->type == Type::unreachable) { |
| // Leave this for DCE. |
| return nullptr; |
| } |
| |
| // Check if we can find useful info for both arms: either known call targets, |
| // or traps. |
| // TODO: support more than 2 targets (with nested selects) |
| auto ifTrueCallInfo = getCallInfo(select->ifTrue); |
| auto ifFalseCallInfo = getCallInfo(select->ifFalse); |
| if (std::get_if<Unknown>(&ifTrueCallInfo) || |
| std::get_if<Unknown>(&ifFalseCallInfo)) { |
| // We know nothing about at least one arm, so give up. |
| // TODO: Perhaps emitting a direct call for one arm is enough even if the |
| // other remains indirect? |
| return nullptr; |
| } |
| |
| auto& operands = curr->operands; |
| |
| // We must use the operands twice, and also must move the condition to |
| // execute first, so we'll use locals for them all. First, see if any are |
| // unreachable, and if so stop trying to optimize and leave this for DCE. |
| for (auto* operand : operands) { |
| if (operand->type == Type::unreachable || |
| !TypeUpdating::canHandleAsLocal(operand->type)) { |
| return nullptr; |
| } |
| } |
| |
| Builder builder(wasm); |
| std::vector<Expression*> blockContents; |
| |
| // None of the types are a problem, so we can proceed to add new vars as |
| // needed and perform this optimization. |
| std::vector<Index> operandLocals; |
| for (auto* operand : operands) { |
| auto currLocal = builder.addVar(&func, operand->type); |
| operandLocals.push_back(currLocal); |
| blockContents.push_back(builder.makeLocalSet(currLocal, operand)); |
| } |
| |
| // Build the calls. |
| auto numOperands = operands.size(); |
| auto getOperands = [&]() { |
| std::vector<Expression*> newOperands(numOperands); |
| for (Index i = 0; i < numOperands; i++) { |
| newOperands[i] = |
| builder.makeLocalGet(operandLocals[i], operands[i]->type); |
| } |
| return newOperands; |
| }; |
| |
| auto makeCall = [&](IndirectCallInfo info) -> Expression* { |
| Expression* ret; |
| if (std::get_if<Trap>(&info)) { |
| ret = builder.makeUnreachable(); |
| } else { |
| ret = builder.makeCall(std::get<Known>(info).target, |
| getOperands(), |
| curr->type, |
| curr->isReturn); |
| } |
| debuginfo::copyOriginalToReplacement(curr, ret, &func); |
| return ret; |
| }; |
| auto* ifTrueCall = makeCall(ifTrueCallInfo); |
| auto* ifFalseCall = makeCall(ifFalseCallInfo); |
| |
| // Create the if to pick the calls, and emit the final block. |
| auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall); |
| blockContents.push_back(iff); |
| return builder.makeBlock(blockContents); |
| } |
| |
| } // namespace wasm::CallUtils |
| |
| #endif // wasm_ir_function_h |