You've already forked llvm-project
mirror of
https://github.com/AdaCore/llvm-project.git
synced 2026-02-12 13:52:35 -08:00
This is similar to the fix to the greedy driver in #153957, except that instead of removing unreachable code, we just ignore it. Operations like: ``` %add = arith.addi %add, %add : i64 ``` are legal in unreachable code. Unfortunately, many patterns would be unsafe to apply to such IR and can lead to crashes or infinite loops.
231 lines
8.2 KiB
C++
231 lines
8.2 KiB
C++
//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Implements mlir::walkAndApplyPatterns.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
|
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/Verifier.h"
|
|
#include "mlir/IR/Visitors.h"
|
|
#include "mlir/Rewrite/PatternApplicator.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
#define DEBUG_TYPE "walk-rewriter"
|
|
|
|
namespace mlir {
|
|
|
|
// Find all reachable blocks in the region and add them to the visitedBlocks
|
|
// set.
|
|
static void findReachableBlocks(Region ®ion,
|
|
DenseSet<Block *> &reachableBlocks) {
|
|
Block *entryBlock = ®ion.front();
|
|
reachableBlocks.insert(entryBlock);
|
|
// Traverse the CFG and add all reachable blocks to the blockList.
|
|
SmallVector<Block *> worklist({entryBlock});
|
|
while (!worklist.empty()) {
|
|
Block *block = worklist.pop_back_val();
|
|
Operation *terminator = &block->back();
|
|
for (Block *successor : terminator->getSuccessors()) {
|
|
if (reachableBlocks.contains(successor))
|
|
continue;
|
|
worklist.push_back(successor);
|
|
reachableBlocks.insert(successor);
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
struct WalkAndApplyPatternsAction final
|
|
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
|
|
using ActionImpl::ActionImpl;
|
|
static constexpr StringLiteral tag = "walk-and-apply-patterns";
|
|
void print(raw_ostream &os) const override { os << tag; }
|
|
};
|
|
|
|
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Forwarding listener that rejects erasures the walk-based driver cannot
/// support. Since patterns are applied while walking the IR, erasing an op or
/// block that belongs to the *next* iteration (e.g., a user of the visited op)
/// is not valid: only the op currently being visited and its descendants may
/// be erased. This listener is only compiled in when expensive pattern API
/// checks are enabled.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
  using RewriterBase::ForwardingListener::ForwardingListener;

  /// Validates the erasure of `op`, then forwards the notification.
  void notifyOperationErased(Operation *op) override {
    checkErasure(op);
    ForwardingListener::notifyOperationErased(op);
  }

  /// Validates the erasure of `block` through the op that owns it, then
  /// forwards the notification.
  void notifyBlockErased(Block *block) override {
    checkErasure(block->getParentOp());
    ForwardingListener::notifyBlockErased(block);
  }

  /// Aborts with a fatal error unless `visitedOp` is `op` itself or one of its
  /// ancestors.
  void checkErasure(Operation *op) const {
    // Climb the chain of parent ops until we either hit `visitedOp` or run
    // out of ancestors.
    Operation *cursor = op;
    for (; cursor != nullptr; cursor = cursor->getParentOp()) {
      if (cursor == visitedOp)
        break;
    }

    // The walk stopped on `visitedOp` (this also holds when both are null),
    // so the erasure is within the supported scope.
    if (cursor == visitedOp)
      return;

    llvm::report_fatal_error(
        "unsupported erasure in WalkPatternRewriter; "
        "erasure is only supported for matched ops and their descendants");
  }

  /// The op the driver is currently applying patterns to.
  Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
|
} // namespace
|
|
|
|
/// Applies `patterns` to the operations nested under `op` in a single
/// post-order walk: the regions of an operation are fully processed before
/// patterns are tried on the operation itself. `op` itself is not matched.
///
/// In regions with more than one block, only blocks reachable from the entry
/// block are visited; operations in unreachable blocks are ignored.
///
/// \param op        Root operation whose regions are walked.
/// \param patterns  Frozen pattern set to apply, using the default cost model.
/// \param listener  Optional listener installed on the rewriter (may be null).
///
/// When MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS is set, the IR is verified
/// before and after the walk, and erasures outside of the visited op and its
/// descendants are reported as fatal errors.
void walkAndApplyPatterns(Operation *op,
                          const FrozenRewritePatternSet &patterns,
                          RewriterBase::Listener *listener) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  if (failed(verify(op)))
    llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

  MLIRContext *ctx = op->getContext();
  PatternRewriter rewriter(ctx);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  ErasedOpsListener erasedListener(listener);
  rewriter.setListener(&erasedListener);
#else
  rewriter.setListener(listener);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

  PatternApplicator applicator(patterns);
  applicator.applyDefaultCostModel();

  // Iterator on all reachable operations in the region.
  // Also keep track if we visited the nested regions of the current op
  // already to drive the post-order traversal.
  struct RegionReachableOpIterator {
    explicit RegionReachableOpIterator(Region *region) : region(region) {
      regionIt = region->begin();
      // Nothing to iterate over (and no entry block to compute reachability
      // from) in an empty region.
      if (regionIt == region->end())
        return;
      blockIt = regionIt->begin();
      // A single-block region needs no reachability analysis: its only block
      // is the entry block, and `advance` never looks up another block.
      if (!llvm::hasSingleElement(*region))
        findReachableBlocks(*region, reachableBlocks);
    }
    // Advance the iterator to the next reachable operation.
    void advance() {
      assert(regionIt != region->end());
      hasVisitedRegions = false;
      if (blockIt == regionIt->end()) {
        // Current block is exhausted: move on to the next reachable block.
        ++regionIt;
        while (regionIt != region->end() &&
               !reachableBlocks.contains(&*regionIt))
          ++regionIt;
        if (regionIt != region->end())
          blockIt = regionIt->begin();
        return;
      }
      ++blockIt;
      if (blockIt != regionIt->end()) {
        LLVM_DEBUG({
          llvm::dbgs() << "Incrementing block iterator, next op: "
                       << OpWithFlags(&*blockIt,
                                      OpPrintingFlags().skipRegions())
                       << "\n";
        });
      }
    }
    // The region we're iterating over.
    Region *region;
    // The Block currently being iterated over.
    Region::iterator regionIt;
    // The Operation currently being iterated over.
    Block::iterator blockIt;
    // The set of blocks that are reachable in the current region.
    DenseSet<Block *> reachableBlocks;
    // Whether we've visited the nested regions of the current op already.
    bool hasVisitedRegions = false;
  };

  // Worklist of regions to visit to drive the post-order traversal.
  SmallVector<RegionReachableOpIterator> worklist;

  LLVM_DEBUG(
      { llvm::dbgs() << "Starting walk-based pattern rewrite driver\n"; });
  ctx->executeAction<WalkAndApplyPatternsAction>(
      [&] {
        // Perform a post-order traversal of the regions, visiting each
        // reachable operation.
        for (Region &region : op->getRegions()) {
          assert(worklist.empty());
          if (region.empty())
            continue;

          // Prime the worklist with the entry block of this region.
          worklist.emplace_back(&region);
          while (!worklist.empty()) {
            RegionReachableOpIterator &it = worklist.back();
            if (it.regionIt == it.region->end()) {
              // We're done with this region.
              worklist.pop_back();
              continue;
            }
            if (it.blockIt == it.regionIt->end()) {
              // We're done with this block.
              it.advance();
              continue;
            }
            Operation *currentOp = &*it.blockIt;
            // If we haven't visited the nested regions of this op yet,
            // enqueue them.
            if (!it.hasVisitedRegions) {
              it.hasVisitedRegions = true;
              // Growing the worklist may reallocate its storage and leave
              // `it` dangling, so detect newly enqueued regions by size
              // rather than by comparing element addresses.
              const size_t sizeBeforeEnqueue = worklist.size();
              for (Region &nestedRegion :
                   llvm::reverse(currentOp->getRegions())) {
                if (nestedRegion.empty())
                  continue;
                worklist.emplace_back(&nestedRegion);
              }
              // If some nested region was enqueued, process it first. We'll
              // come back to this op later (post-order).
              if (worklist.size() != sizeBeforeEnqueue)
                continue;
            }

            // Preemptively increment the iterator, in case the current op
            // would be erased.
            it.advance();

            LLVM_DEBUG({
              llvm::dbgs() << "Visiting op: "
                           << OpWithFlags(currentOp,
                                          OpPrintingFlags().skipRegions())
                           << "\n";
            });
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
            erasedListener.visitedOp = currentOp;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
            if (succeeded(applicator.matchAndRewrite(currentOp, rewriter)))
              LLVM_DEBUG({ llvm::dbgs() << "\tOp matched and rewritten\n"; });
          }
        }
      },
      {op});

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  if (failed(verify(op)))
    llvm::report_fatal_error(
        "walk pattern rewriter result IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
|
|
|
|
} // namespace mlir
|