/*
* Copyright 2016 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Convert the AST to a CFG, while traversing it.
//
// Note that this is not the same as the relooper CFG. The relooper is
// designed for compilation to an AST, this is for processing. There is
// no built-in support for transforming this CFG back into the AST;
// it is just metadata on the side for computation purposes.
//
// Usage: As the traversal proceeds, you can note information and add it to
// the current basic block using currBasicBlock, on the contents
// property, whose type is user-defined.
//
#ifndef cfg_traversal_h
#define cfg_traversal_h
#include "wasm.h"
#include "wasm-traversal.h"
namespace wasm {
template<typename SubType, typename VisitorType, typename Contents>
struct CFGWalker : public ControlFlowWalker<SubType, VisitorType> {
// public interface
// A node of the control-flow graph built during the walk.
struct BasicBlock {
  // User-defined payload attached to this block.
  Contents contents;
  // Successors: link(from, to) appends `to` to from->out.
  std::vector<BasicBlock*> out;
  // Predecessors: link(from, to) appends `from` to to->in.
  std::vector<BasicBlock*> in;
};
// The entry block of the CFG. doWalkFunction() points it at the first basic
// block it creates; it is null until a function has been walked, so it never
// holds an indeterminate value.
BasicBlock* entry = nullptr;
// Allocates a new basic block. A subclass may provide its own
// makeBasicBlock() to customise creation: startBasicBlock() invokes it through
// SubType and hands the returned pointer to a std::unique_ptr stored in
// basicBlocks, so the caller of this method owns the result.
BasicBlock* makeBasicBlock() {
  auto* block = new BasicBlock();
  return block;
}
// internal details
// Owns every basic block created by startBasicBlock().
std::vector<std::unique_ptr<BasicBlock>> basicBlocks;
// Blocks that are the tops of loops, i.e., have backedges to them.
std::vector<BasicBlock*> loopTops;

// traversal state

// The current block in play during traversal. It can be nullptr if the code
// being traversed is unreachable, but note that we don't do a deep
// unreachability analysis - just enough to avoid constructing
// obviously-unreachable blocks (a full reachability analysis can be run on the
// finished CFG with findLiveBlocks()). Defaults to nullptr so that it never
// holds an indeterminate value before a walk starts.
BasicBlock* currBasicBlock = nullptr;
// A block or loop => the basic blocks that branch to it.
std::map<Expression*, std::vector<BasicBlock*>> branches;
// Per enclosing if: the block before the ifTrue arm and, once an ifFalse arm
// has started, the block that ended the ifTrue arm.
std::vector<BasicBlock*> ifStack;
// The top basic block of each loop currently being traversed.
std::vector<BasicBlock*> loopStack;
void startBasicBlock() {
currBasicBlock = ((SubType*)this)->makeBasicBlock();
basicBlocks.push_back(std::unique_ptr<BasicBlock>(currBasicBlock));
}
// Marks the code that follows as unreachable: there is no current basic block
// until startBasicBlock() is called again. link() ignores null endpoints, so
// no edges are created from a null current block.
void startUnreachableBlock() {
currBasicBlock = nullptr;
}
// Task form of startUnreachableBlock(); scan() schedules it for Return and
// Unreachable expressions.
static void doStartUnreachableBlock(SubType* self, Expression** currp) {
  (void)currp; // not needed here; present only to match the task signature
  self->startUnreachableBlock();
}
// Adds a CFG edge from -> to, recording it on both endpoints. A null endpoint
// means that side is unreachable, in which case no edge is added.
void link(BasicBlock* from, BasicBlock* to) {
  const bool bothReachable = from != nullptr && to != nullptr;
  if (bothReachable) {
    from->out.push_back(to);
    to->in.push_back(from);
  }
}
// Runs at the end of a Block. If the block is named and branches to it were
// recorded, a new basic block is started that is reached both by falling
// through and from every recorded branch origin; otherwise nothing changes.
static void doEndBlock(SubType* self, Expression** currp) {
  auto* block = (*currp)->cast<Block>();
  // Only a named block can be the target of a branch.
  if (!block->name.is()) {
    return;
  }
  auto found = self->branches.find(block);
  if (found == self->branches.end() || found->second.empty()) {
    return;
  }
  // We have branches to here, so control flow merges: start a new block.
  BasicBlock* fallthroughFrom = self->currBasicBlock;
  self->startBasicBlock();
  self->link(fallthroughFrom, self->currBasicBlock);
  for (BasicBlock* origin : found->second) {
    self->link(origin, self->currBasicBlock);
  }
  // The recorded branches have been resolved.
  self->branches.erase(found);
}
// Runs after an If's condition and before its ifTrue arm: starts the block for
// the ifTrue arm, links the pre-if block to it, and remembers the pre-if block
// on ifStack for later use by doStartIfFalse()/doEndIf().
static void doStartIfTrue(SubType* self, Expression** currp) {
  BasicBlock* beforeIf = self->currBasicBlock;
  self->startBasicBlock();
  BasicBlock* ifTrueStart = self->currBasicBlock;
  self->link(beforeIf, ifTrueStart);
  self->ifStack.push_back(beforeIf);
}
// Runs between the ifTrue and ifFalse arms: saves the block that ended the
// ifTrue arm on ifStack, starts the block for the ifFalse arm, and links the
// pre-if block (pushed by doStartIfTrue()) to it.
static void doStartIfFalse(SubType* self, Expression** currp) {
  // The entry currently on top is the block before the if.
  BasicBlock* beforeIf = self->ifStack.back();
  BasicBlock* ifTrueEnd = self->currBasicBlock;
  self->ifStack.push_back(ifTrueEnd);
  self->startBasicBlock();
  self->link(beforeIf, self->currBasicBlock);
}
// Runs at the end of an If: starts the block that follows the if and links
// into it every way of leaving the if, then pops this if's ifStack entries.
static void doEndIf(SubType* self, Expression** currp) {
  BasicBlock* armEnd = self->currBasicBlock;
  self->startBasicBlock();
  BasicBlock* afterIf = self->currBasicBlock;
  // armEnd is the end of the ifFalse arm if there is one, otherwise it is the
  // end of the ifTrue arm.
  self->link(armEnd, afterIf);
  // With an ifFalse arm, the top of ifStack is the end of the ifTrue arm,
  // which still has to reach the join. Without one, the top is the block
  // before the if, which falls through when the if is not taken.
  self->link(self->ifStack.back(), afterIf);
  self->ifStack.pop_back();
  const bool hasIfFalse = (*currp)->cast<If>()->ifFalse != nullptr;
  if (hasIfFalse) {
    // Also drop the pre-if block pushed by doStartIfTrue().
    self->ifStack.pop_back();
  }
}
// Runs at the start of a Loop: starts the loop's top block, records it in
// loopTops and on loopStack, and links the preceding block into it.
static void doStartLoop(SubType* self, Expression** currp) {
  BasicBlock* beforeLoop = self->currBasicBlock;
  self->startBasicBlock();
  BasicBlock* top = self->currBasicBlock;
  // A loop with no backedges is still counted as a loop top here.
  self->loopTops.push_back(top);
  self->link(beforeLoop, top);
  self->loopStack.push_back(top);
}
// Runs at the end of a Loop: starts the block after the loop (reached by
// falling out of the body), turns every recorded branch to the loop into a
// backedge to the loop's top block, and pops the loop from loopStack.
static void doEndLoop(SubType* self, Expression** currp) {
  BasicBlock* bodyEnd = self->currBasicBlock;
  self->startBasicBlock();
  self->link(bodyEnd, self->currBasicBlock); // fallthrough
  auto* loop = (*currp)->cast<Loop>();
  if (loop->name.is()) {
    // Branches to a loop go to its top.
    BasicBlock* top = self->loopStack.back();
    auto found = self->branches.find(loop);
    if (found != self->branches.end()) {
      for (BasicBlock* origin : found->second) {
        self->link(origin, top);
      }
      self->branches.erase(found);
    }
  }
  self->loopStack.pop_back();
}
// Runs after a Break: records the current block as a branch origin for the
// break's target. A conditional break may fall through, so a new linked block
// is started; an unconditional one makes the following code unreachable.
static void doEndBreak(SubType* self, Expression** currp) {
  auto* br = (*currp)->cast<Break>();
  Expression* target = self->findBreakTarget(br->name);
  BasicBlock* origin = self->currBasicBlock;
  // The origin is recorded even when it is null; link() skips null endpoints
  // when the branch is resolved.
  self->branches[target].push_back(origin);
  if (!br->condition) {
    self->startUnreachableBlock();
    return;
  }
  self->startBasicBlock();
  self->link(origin, self->currBasicBlock); // we might fall through
}
// Runs after a Switch: records the current block as a branch origin for each
// distinct target label and for the default label, then marks the following
// code as unreachable.
static void doEndSwitch(SubType* self, Expression** currp) {
  auto* sw = (*currp)->cast<Switch>();
  // The same label may appear more than once; record one branch per label.
  std::set<Name> seen;
  auto noteBranchTo = [&](Name label) {
    if (seen.insert(label).second) {
      self->branches[self->findBreakTarget(label)].push_back(self->currBasicBlock);
    }
  };
  for (Name label : sw->targets) {
    noteBranchTo(label);
  }
  noteBranchTo(sw->default_);
  self->startUnreachableBlock();
}
// Schedules the traversal tasks for the expression at *currp, adding the
// CFG-building hooks (doEnd*/doStart*) around the tasks scheduled by
// ControlFlowWalker::scan.
// NOTE(review): the push order below only makes sense if pushTask() is
// last-in-first-out (tasks pushed later run earlier) - confirm against the
// Walker implementation. Under that assumption the doEnd* hooks pushed before
// the base scan run after the expression's own tasks, and doStartLoop, pushed
// after the base scan, runs before them.
static void scan(SubType* self, Expression** currp) {
Expression* curr = *currp;
// First pass: hooks pushed before anything the base class schedules.
switch (curr->_id) {
case Expression::Id::BlockId: {
self->pushTask(SubType::doEndBlock, currp);
break;
}
case Expression::Id::IfId: {
// If is handled entirely here: the hooks are interleaved with the scans of
// the condition and the two arms, and the base class scan is skipped.
self->pushTask(SubType::doEndIf, currp);
auto* ifFalse = curr->cast<If>()->ifFalse;
if (ifFalse) {
self->pushTask(SubType::scan, &curr->cast<If>()->ifFalse);
self->pushTask(SubType::doStartIfFalse, currp);
}
self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue);
self->pushTask(SubType::doStartIfTrue, currp);
self->pushTask(SubType::scan, &curr->cast<If>()->condition);
return; // don't do anything else
}
case Expression::Id::LoopId: {
self->pushTask(SubType::doEndLoop, currp);
break;
}
case Expression::Id::BreakId: {
self->pushTask(SubType::doEndBreak, currp);
break;
}
case Expression::Id::SwitchId: {
self->pushTask(SubType::doEndSwitch, currp);
break;
}
// Return and Unreachable both end control flow: what follows is unreachable.
case Expression::Id::ReturnId: {
self->pushTask(SubType::doStartUnreachableBlock, currp);
break;
}
case Expression::Id::UnreachableId: {
self->pushTask(SubType::doStartUnreachableBlock, currp);
break;
}
default: {}
}
// Let the base class schedule the expression itself and its children.
ControlFlowWalker<SubType, VisitorType>::scan(self, currp);
// Second pass: hooks pushed after the base class tasks.
switch (curr->_id) {
case Expression::Id::LoopId: {
self->pushTask(SubType::doStartLoop, currp);
break;
}
default: {}
}
}
// Builds the CFG for func: resets the per-function state, creates the entry
// block, runs the ControlFlowWalker traversal, and checks that all branch/if/
// loop bookkeeping was consumed.
void doWalkFunction(Function* func) {
  // Clearing basicBlocks destroys the blocks of any previously walked
  // function, so everything that holds raw pointers to them must be reset as
  // well; otherwise a reused walker keeps dangling pointers in loopTops and
  // stale keys in debugIds (which would also stop generateDebugIds() from
  // numbering the new blocks).
  basicBlocks.clear();
  loopTops.clear();
  debugIds.clear();
  startBasicBlock();
  entry = currBasicBlock;
  ControlFlowWalker<SubType, VisitorType>::doWalkFunction(func);
  // Every recorded branch must have been resolved and every if/loop closed.
  assert(branches.size() == 0);
  assert(ifStack.size() == 0);
  assert(loopStack.size() == 0);
}
std::unordered_set<BasicBlock*> findLiveBlocks() {
std::unordered_set<BasicBlock*> alive;
std::unordered_set<BasicBlock*> queue;
queue.insert(entry);
while (queue.size() > 0) {
auto iter = queue.begin();
auto* curr = *iter;
queue.erase(iter);
alive.insert(curr);
for (auto* out : curr->out) {
if (!alive.count(out)) queue.insert(out);
}
}
return alive;
}
void unlinkDeadBlocks(std::unordered_set<BasicBlock*> alive) {
for (auto& block : basicBlocks) {
if (!alive.count(block.get())) {
block->in.clear();
block->out.clear();
continue;
}
block->in.erase(std::remove_if(block->in.begin(), block->in.end(), [&alive](BasicBlock* other) {
return !alive.count(other);
}), block->in.end());
block->out.erase(std::remove_if(block->out.begin(), block->out.end(), [&alive](BasicBlock* other) {
return !alive.count(other);
}), block->out.end());
}
}
// TODO: utility method for optimizing cfg, removing empty blocks depending on their .content
std::map<BasicBlock*, size_t> debugIds;
void generateDebugIds() {
if (debugIds.size() > 0) return;
for (auto& block : basicBlocks) {
debugIds[block.get()] = debugIds.size();
}
}
void dumpCFG(std::string message) {
std::cout << "<==\nCFG [" << message << "]:\n";
generateDebugIds();
for (auto& block : basicBlocks) {
assert(debugIds.count(block.get()) > 0);
std::cout << " block " << debugIds[block.get()] << ":\n";
block->contents.dump(static_cast<SubType*>(this)->getFunction());
for (auto& in : block->in) {
assert(debugIds.count(in) > 0);
assert(std::find(in->out.begin(), in->out.end(), block.get()) != in->out.end()); // must be a parallel link back
}
for (auto& out : block->out) {
assert(debugIds.count(out) > 0);
std::cout << " out: " << debugIds[out] << "\n";
assert(std::find(out->in.begin(), out->in.end(), block.get()) != out->in.end()); // must be a parallel link back
}
checkDuplicates(block->in);
checkDuplicates(block->out);
}
std::cout << "==>\n";
}
private:
// Asserts that an edge list (a block's `in` or `out`) has no repeated entry.
void checkDuplicates(std::vector<BasicBlock*>& list) {
  std::unordered_set<BasicBlock*> seen;
  for (BasicBlock* entryInList : list) {
    const bool firstTime = seen.insert(entryInList).second;
    assert(firstTime);
    (void)firstTime; // only read by the assert
  }
}
// Removes toRemove from an edge list by overwriting it with the last element
// and shrinking the list, so the order of the remaining entries is not
// preserved. toRemove must be present; reaching the end without finding it is
// a bug and hits WASM_UNREACHABLE().
//
// There is deliberately no shortcut for single-element lists: clearing such a
// list without looking at its element would silently drop an unrelated link
// when toRemove is not actually in the list.
void removeLink(std::vector<BasicBlock*>& list, BasicBlock* toRemove) {
  for (size_t i = 0; i < list.size(); i++) {
    if (list[i] == toRemove) {
      list[i] = list.back();
      list.pop_back();
      return;
    }
  }
  WASM_UNREACHABLE();
}
};
} // namespace wasm
#endif // cfg_traversal_h