blob: 44bc16e95e6119b139164b2c4a6a32873b1b31e3 [file] [log] [blame] [edit]
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Instruments the build with code to intercept selected function calls.
// This can be e.g. used to trace allocations (malloc, free, calloc, realloc)
// and build tools for memory usage analysis.
// The pass supports SIMD but the multi-value feature is not supported yet.
//
// Example:
// Instrumenting function `void* malloc(int32_t)` with a user-defined
// name of the tracer `trace_alloc` and function `void free(void*)`
// with the default name of the tracer `trace_free` (`trace_` prefix
// is added by default):
// wasm-opt --trace-calls=malloc:trace_alloc,free -o test-opt.wasm test.wasm
//
// Before:
// (call $malloc
// (i32.const 32))
// (call $free (i32.const 64))
//
// After (the tracer receives the result first, followed by the operands):
//   (local $0 i32)
//   (local $1 i32)
//   (local $2 i32)
//   (block (result i32)
//     (call $trace_alloc
//       (local.tee $1
//         (call $malloc
//           (local.tee $0 (i32.const 32))
//         )
//       )
//       (local.get $0)
//     )
//     (local.get $1)
//   )
//   (block
//     (call $free
//       (local.tee $2
//         (i32.const 64)
//       )
//     )
//     (call $trace_free
//       (local.get $2)
//     )
//   )
#include <map>
#include "asmjs/shared-constants.h"
#include "ir/import-utils.h"
#include "pass.h"
#include "support/string.h"
#include "wasm-builder.h"
namespace wasm {
// Maps the name of each function to trace (key) to the name of the tracer
// function that is called alongside it (value).
using TracedFunctions = std::map<Name /* originName */, Name /* tracerName */>;
// Walker that rewrites every direct call to a traced function so that the
// call's operands and its result (if any) are additionally passed to the
// associated tracer function.
struct AddTraceWrappers : public WalkerPass<PostWalker<AddTraceWrappers>> {
  // tracedFunctions maps each traced function name to the name of the tracer
  // function to call for it.
  AddTraceWrappers(TracedFunctions tracedFunctions)
    : tracedFunctions(std::move(tracedFunctions)) {}

  // Instruments `curr` when its target is one of the traced functions.
  void visitCall(Call* curr) {
    auto* target = getModule()->getFunction(curr->target);
    auto iter = tracedFunctions.find(target->name);
    if (iter == tracedFunctions.end()) {
      return;
    }
    // A call with an unreachable operand never executes, and an unreachable
    // operand cannot be stashed in a local, so there is nothing to trace.
    for (auto* operand : curr->operands) {
      if (operand->type == Type::unreachable) {
        return;
      }
    }
    addInstrumentation(curr, target, iter->second);
  }

private:
  // Replaces `curr` with a block that performs the original call and then
  // calls `wrapperName` with the call's result (when the target has one)
  // followed by copies of the call's operands.
  //
  // Each operand is saved with a local.tee so that it is evaluated exactly
  // once and can be re-read for the tracer call.
  void addInstrumentation(Call* curr,
                          const wasm::Function* target,
                          const Name& wrapperName) {
    Builder builder(*getModule());
    std::vector<wasm::Expression*> realCallParams, trackerCallParams;
    realCallParams.reserve(curr->operands.size());
    // One extra slot for the result, which is inserted in front below.
    trackerCallParams.reserve(curr->operands.size() + 1);

    for (auto* operand : curr->operands) {
      auto operandLocal = builder.addVar(getFunction(), operand->type);
      realCallParams.push_back(
        builder.makeLocalTee(operandLocal, operand, operand->type));
      trackerCallParams.push_back(
        builder.makeLocalGet(operandLocal, operand->type));
    }

    auto resultType = target->type.getSignature().results;
    // The tracer has to run after the callee, so the rebuilt call is always
    // a plain (non-tail) call.
    auto* realCall =
      builder.makeCall(target->name, realCallParams, resultType);

    wasm::Expression* replacement = nullptr;
    if (resultType.isConcrete()) {
      // Save the result so it can be both traced and yielded by the block.
      auto resultLocal = builder.addVar(getFunction(), resultType);
      trackerCallParams.insert(
        trackerCallParams.begin(),
        builder.makeLocalTee(resultLocal, realCall, resultType));
      replacement = builder.makeBlock(
        {builder.makeCall(wrapperName, trackerCallParams, Type::none),
         builder.makeLocalGet(resultLocal, resultType)});
      if (curr->isReturn) {
        // The original was a return_call: keep returning from the enclosing
        // function, now with the traced result.
        replacement = builder.makeReturn(replacement);
      }
    } else {
      auto* tracerCall =
        builder.makeCall(wrapperName, trackerCallParams, Type::none);
      if (curr->isReturn) {
        // The original was a return_call: return explicitly once the tracer
        // has been invoked.
        replacement =
          builder.makeBlock({realCall, tracerCall, builder.makeReturn()});
      } else {
        replacement = builder.makeBlock({realCall, tracerCall});
      }
    }
    replaceCurrent(replacement);
  }

  TracedFunctions tracedFunctions;
};
struct TraceCalls : public Pass {
// Adds calls to new imports.
bool addsEffects() override { return true; }
void run(Module* module) override {
auto functionsDefinitions =
getArgument("trace-calls",
"TraceCalls usage: wasm-opt "
"--trace-calls=FUNCTION_TO_TRACE[:TRACER_NAME][,...]");
auto tracedFunctions = parseArgument(functionsDefinitions);
for (const auto& tracedFunction : tracedFunctions) {
auto func = module->getFunctionOrNull(tracedFunction.first);
if (!func) {
std::cerr << "[TraceCalls] Function '" << tracedFunction.first
<< "' not found" << std::endl;
} else {
addImport(module, *func, tracedFunction.second);
}
}
AddTraceWrappers(std::move(tracedFunctions)).run(getPassRunner(), module);
}
private:
Type getTracerParamsType(ImportInfo& info, const Function& func) {
auto resultsType = func.type.getSignature().results;
if (resultsType.isTuple()) {
Fatal() << "Failed to instrument function '" << func.name
<< "': Multi-value result type is not supported";
}
std::vector<Type> tracerParamTypes;
if (resultsType.isConcrete()) {
tracerParamTypes.push_back(resultsType);
}
for (auto& op : func.type.getSignature().params) {
tracerParamTypes.push_back(op);
}
return Type(tracerParamTypes);
}
TracedFunctions parseArgument(const std::string& arg) {
TracedFunctions tracedFunctions;
for (const auto& definition : String::Split(arg, ",")) {
if (definition.empty()) {
// Empty definition, ignore.
continue;
}
std::string originName, traceName;
parseFunctionName(definition, originName, traceName);
tracedFunctions[Name(originName)] = Name(traceName);
}
return tracedFunctions;
}
void parseFunctionName(const std::string& str,
std::string& originName,
std::string& traceName) {
auto parts = String::Split(str, ":");
switch (parts.size()) {
case 1:
originName = parts[0];
traceName = "trace_" + originName;
break;
case 2:
originName = parts[0];
traceName = parts[1];
break;
default:
Fatal() << "Failed to parse function name ('" << str
<< "'): expected format FUNCTION_TO_TRACE[:TRACER_NAME]";
}
}
void addImport(Module* wasm, const Function& f, const Name& tracerName) {
ImportInfo info(*wasm);
if (!info.getImportedFunction(ENV, tracerName)) {
auto import = Builder::makeFunction(
tracerName, Signature(getTracerParamsType(info, f), Type::none), {});
import->module = ENV;
import->base = tracerName;
wasm->addFunction(std::move(import));
}
}
};
// Factory that creates a new TraceCalls pass instance.
Pass* createTraceCallsPass() {
  auto* tracePass = new TraceCalls();
  return tracePass;
}
} // namespace wasm