// blob: 0251a0c589f863754eb0a7d5981aab24ea36a246
/*
* Copyright 2020 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Instrument the wasm to convert NaN values at runtime into 0s. That is, every
// operation that might produce a NaN will go through a helper function which
// filters out NaNs (replacing them with 0). This ensures that NaNs are never
// consumed by any instructions, which is useful when fuzzing between VMs that
// differ on wasm's nondeterminism around NaNs.
//
#include "ir/names.h"
#include "ir/properties.h"
#include "pass.h"
#include "wasm-builder.h"
#include "wasm.h"
namespace wasm {
// A pass that instruments the module so that floating-point NaNs are replaced
// with zeros:
//
//  * NaN constants (f32, f64, or v128 with a NaN f32 lane) are rewritten to
//    zero constants at compile time.
//  * Every other expression of type f32/f64/v128 is wrapped in a call to a
//    helper function that returns its argument, or zero if it is a NaN.
//  * Parameters of those types are passed through the same helpers on
//    function entry.
//
// The helper functions themselves are added to the module after the walk, so
// that they are not instrumented.
struct DeNaN : public WalkerPass<
                 ControlFlowWalker<DeNaN, UnifiedExpressionVisitor<DeNaN>>> {
  // Adds calls.
  bool addsEffects() override { return true; }

  // Names of the helper functions for f32, f64 and v128 values. They are
  // chosen in doWalkModule() before the walk starts.
  Name deNan32, deNan64, deNan128;

  // Returns the name of the helper that filters values of |type|, or nullptr
  // if |type| is not one of the types this pass instruments (f32, f64, v128).
  const Name* helperFor(Type type) const {
    if (type == Type::f32) {
      return &deNan32;
    }
    if (type == Type::f64) {
      return &deNan64;
    }
    if (type == Type::v128) {
      return &deNan128;
    }
    return nullptr;
  }

  // Visits every expression and, if it produces an f32/f64/v128 value, makes
  // sure that value is not a NaN.
  void visitExpression(Expression* expr) {
    // If the expression returns a floating-point value, ensure it is not a
    // NaN. If we can do this at compile time, do it now, which is useful for
    // initializations of globals (in which we can't do a function call). Note
    // that we don't instrument local.gets, which would cause problems if we
    // ran this pass more than once (the added functions use gets, and we don't
    // want to instrument them).
    if (expr->is<LocalGet>()) {
      return;
    }
    // If the result just falls through without being modified, then we've
    // already fixed it up earlier.
    if (Properties::isResultFallthrough(expr)) {
      return;
    }

    Builder builder(*getModule());
    Expression* replacement = nullptr;
    if (auto* c = expr->dynCast<Const>()) {
      // Constants are resolved at compile time: a NaN constant becomes a zero
      // constant of the same type, and any other constant is left untouched.
      // The type is checked before the NaN test so that only floating-point
      // literals are ever asked whether they are NaNs.
      if (expr->type == Type::f32) {
        if (c->value.isNaN()) {
          replacement = builder.makeConst(float(0));
        }
      } else if (expr->type == Type::f64) {
        if (c->value.isNaN()) {
          replacement = builder.makeConst(double(0));
        }
      } else if (expr->type == Type::v128) {
        if (hasNaNLane(c)) {
          uint8_t zero[16] = {};
          replacement = builder.makeConst(Literal(zero));
        }
      }
    } else if (const Name* helper = helperFor(expr->type)) {
      // Anything else must be filtered at runtime by the matching helper.
      replacement = builder.makeCall(*helper, {expr}, expr->type);
    }

    if (replacement) {
      // We can't do this outside of a function, like in a global initializer,
      // where a call would be illegal.
      if (replacement->is<Const>() || getFunction()) {
        replaceCurrent(replacement);
      } else {
        std::cerr << "warning: cannot de-nan outside of function context\n";
      }
    }
  }

  // After a function's body has been walked, filters its f32/f64/v128
  // parameters through the helpers at the very start of the body. Imported
  // functions are skipped.
  void visitFunction(Function* func) {
    if (func->imported()) {
      return;
    }
    // Instrument all locals as they enter the function.
    Builder builder(*getModule());
    std::vector<Expression*> fixes;
    auto num = func->getNumParams();
    for (Index i = 0; i < num; i++) {
      Type type = func->getLocalType(i);
      if (const Name* helper = helperFor(type)) {
        // (local.set $i (call $helper (local.get $i)))
        fixes.push_back(builder.makeLocalSet(
          i,
          builder.makeCall(*helper, {builder.makeLocalGet(i, type)}, type)));
      }
    }
    if (!fixes.empty()) {
      fixes.push_back(func->body);
      func->body = builder.makeBlock(fixes);
      // Merge blocks so we don't add an unnecessary one. Only this function
      // received a new block, so run the nested pass on it alone; re-running
      // merge-blocks over the whole module for every instrumented function
      // would repeat work on (and modify) functions unrelated to this change.
      PassRunner runner(getModule(), getPassOptions());
      runner.setIsNested(true);
      runner.add("merge-blocks");
      runner.runOnFunction(func);
    }
  }

  // Entry point of the walk: picks the helper names, instruments the module,
  // and then adds the helper functions.
  void doWalkModule(Module* module) {
    // Pick names for the helper functions.
    deNan32 = Names::getValidFunctionName(*module, "deNan32");
    deNan64 = Names::getValidFunctionName(*module, "deNan64");
    deNan128 = Names::getValidFunctionName(*module, "deNan128");
    ControlFlowWalker<DeNaN, UnifiedExpressionVisitor<DeNaN>>::doWalkModule(
      module);
    // Add helper functions after the walk, so they are not instrumented.
    addFunc(module, deNan32, Type::f32, Literal(float(0)), EqFloat32);
    addFunc(module, deNan64, Type::f64, Literal(double(0)), EqFloat64);
    // The v128 helper uses SIMD instructions, so it is only added when the
    // SIMD feature is enabled.
    if (module->features.hasSIMD()) {
      uint8_t zero128[16] = {};
      addFunc(module, deNan128, Type::v128, Literal(zero128));
    }
  }

  // Add a de-NaN-ing helper function.
  //
  //  module:  the module the helper is added to.
  //  name:    the name of the new function.
  //  type:    the single parameter type, which is also the result type.
  //  literal: the value returned in place of a NaN.
  //  op:      the equality operation used to compare the value to itself.
  //           It must be provided for every type except v128, and must not
  //           be provided for v128 (both are asserted).
  void addFunc(Module* module,
               Name name,
               Type type,
               Literal literal,
               std::optional<BinaryOp> op = {}) {
    Builder builder(*module);
    auto func = Builder::makeFunction(name, Signature(type, type), {});
    // Compare the value to itself to check if it is a NaN, and return 0 if
    // so:
    //
    //   (if (result f*)
    //     (f*.eq
    //       (local.get $0)
    //       (local.get $0)
    //     )
    //     (local.get $0)
    //     (f*.const 0)
    //   )
    Expression* condition;
    if (type != Type::v128) {
      // Generate a simple condition.
      assert(op);
      condition = builder.makeBinary(
        *op, builder.makeLocalGet(0, type), builder.makeLocalGet(0, type));
    } else {
      assert(!op);
      // v128 is trickier as the 128 bits may contain f32s or f64s, and we
      // need to check for nans both ways in principle. However, the f32 NaN
      // pattern is a superset of f64, since it checks less bits (8 bit
      // exponent vs 11), and it is checked in more places (4 32-bit values vs
      // 2 64-bit ones), so we can just check that. That is, this reduces to 4
      // checks of f32s, but is otherwise the same as a check of a single f32.
      //
      // However there is additional complexity, which is that if we do
      // EqVecF32x4 then we get all-1s for each case where we compare equal.
      // That itself is a NaN pattern, which means that running this pass
      // twice would interfere with itself. To avoid that we'd need a way to
      // detect our previous instrumentation and not instrument it, but that
      // is tricky (we can't depend on function names etc. while fuzzing).
      // Instead, extract the lanes and use f32 checks.
      auto getLane = [&](Index index) {
        return builder.makeSIMDExtract(
          ExtractLaneVecF32x4, builder.makeLocalGet(0, type), index);
      };
      auto getLaneCheck = [&](Index index) {
        return builder.makeBinary(EqFloat32, getLane(index), getLane(index));
      };
      auto* firstTwo =
        builder.makeBinary(AndInt32, getLaneCheck(0), getLaneCheck(1));
      auto* lastTwo =
        builder.makeBinary(AndInt32, getLaneCheck(2), getLaneCheck(3));
      condition = builder.makeBinary(AndInt32, firstTwo, lastTwo);
    }
    func->body = builder.makeIf(
      condition, builder.makeLocalGet(0, type), builder.makeConst(literal));
    module->addFunction(std::move(func));
  }

  // Check if a constant v128 may contain f32 or f64 NaNs. Only the four f32
  // lanes are tested; see the comment in addFunc() for why that is enough to
  // cover f64 NaNs as well.
  bool hasNaNLane(Const* c) {
    assert(c->type == Type::v128);
    const Literal& value = c->value;
    // Compare each f32 lane to itself; a NaN lane is not equal to itself.
    auto lanesEqual = value.eqF32x4(value);
    // If not all lanes compared equal, at least one of them is a NaN.
    return !lanesEqual.allTrueI32x4().getInteger();
  }
};
// Factory for the DeNaN pass. Returns a newly allocated pass instance; the
// caller receives the raw pointer.
Pass* createDeNaNPass() {
  auto* pass = new DeNaN();
  return pass;
}
} // namespace wasm