Pattern Rewriting : Generic DAG-to-DAG Rewriting
- Introduction
- Defining Patterns
- Benefit
- 例子:将加法和乘法优化成乘法
- 匹配谓词是什么?
- 举例解释
- Root Operation Name(根操作名称)
- match and rewrite implementation(匹配和重写实现)
- Restrictions
- 递归应用(Recursion):
- 剥离迭代
- 调试名称和标签
- 初始化
- 构造
- Pattern Rewriter (模式重写器)
- 模式应用
- Dialect Conversion Driver(方言转换驱动器):
- Greedy Pattern Rewrite Driver(贪婪模式重写驱动器):
Introduction
模式重写框架主要可以分解为两个部分:模式定义和模式应用。
Defining Patterns
模式是通过继承 RewritePattern 类来定义的。该类代表了 MLIR 中所有重写模式的基类,包括以下组成部分:
Benefit
-
预期好处:应用一个模式(pattern)可以带来一定的优化效果,这个优化效果在模式创建时是固定的,但也可以在模式初始化时根据具体情况(例如目标架构)动态计算。
-
静态 vs 动态:静态的好处是预先确定的,而动态的好处则可以在运行时根据具体情况计算。
-
优化模式匹配:通过限制动态计算,可以让模式匹配更高效。研究表明,使用“匹配谓词”(简单条件判断)可以避免大部分情况下的动态计算。也就是说,我们可以为每种可能的情况预先创建一个模式,然后用简单的条件判断来选择合适的模式。
例子:将加法和乘法优化成乘法
假设我们有一个简单的表达式:a + a,我们希望将其优化成2 * a,因为乘法运算通常比加法运算更高效。
-
创建模式:我们创建一个模式,识别出a + a的形式,并将其转换为2 * a。这个优化的好处是显而易见的,因为乘法比加法更高效。
-
静态好处:在创建模式时,我们预先知道这个转换会带来优化,所以这是一个静态的好处。
-
动态计算:如果我们针对不同的硬件架构进行优化,比如某些架构上的加法比乘法更快,我们可以在模式初始化时根据架构信息动态决定是否应用这个优化。
-
匹配谓词:为了避免复杂的动态计算,我们可以创建多个版本的模式。例如,一个版本适用于某些架构,另一个版本适用于其他架构。使用简单的条件判断(匹配谓词)来选择哪个版本的模式。
// 初始代码
%result = add %a, %a// 应用模式后的优化代码
%result = mul constant(2), %a
匹配谓词是什么?
“匹配谓词”(match predicate)是一种条件判断,用来决定某个模式是否应该被应用。在MLIR中,模式匹配是将一个特定的代码模式转换为更优化的形式,而匹配谓词就是用来判断这个代码模式是否符合某些条件,从而决定是否进行转换。
举例解释
假设我们有一个模式,用来优化某种数学表达式,比如将x * 1优化为x,因为乘以1不会改变值。
没有匹配谓词的情况
在最简单的情况下,我们可以直接定义一个模式:
pattern {match: "mul %x, 1"rewrite: "%x"
}
使用匹配谓词的情况
但是,有时候我们需要一些额外的条件来决定是否应用这个模式。比如,我们只有在某些特定情况下(比如x是一个特定类型的变量)才希望进行这个优化。这个时候就需要用到匹配谓词。
pattern {match: "mul %x, 1"predicate: "isSpecialType(%x)"rewrite: "%x"
}
Root Operation Name(根操作名称)
这是一个可选的参数,用于指明这个模式(pattern)要匹配的根操作的名称。如果指定了根操作名称,那么只有具有该名称的操作才会被提供给匹配和重写的实现代码。如果没有指定,那么任何类型的操作都可能被提供。提供根操作名称有助于在应用成本模型时简化模式分析。如果要匹配任何类型的操作,需要提供一个特殊的标签(MatchAnyOpTypeTag)来明确意图。
match and rewrite implementation(匹配和重写实现)
这是指匹配给定的根操作并重写IR(中间表示)的代码块。一个RewritePattern可以通过独立的match和rewrite方法,或通过一个结合的matchAndRewrite方法来指定其实现。当使用结合的matchAndRewrite方法时,在匹配成功之前不应进行任何IR的变动。结合的matchAndRewrite方法在匹配和重写阶段需要非平凡的可重新计算信息时特别有用。
class MyPattern : public RewritePattern {
public:// 构造一个只匹配名称为`MyOp`的操作的模式MyPattern(PatternBenefit benefit, MLIRContext *context): RewritePattern(MyOp::getOperationName(), benefit, context) {}// 构造一个匹配任何类型操作的模式MyPattern(PatternBenefit benefit): RewritePattern(benefit, MatchAnyOpTypeTag()) {}// 使用独立的match和rewrite方法来实现匹配和重写LogicalResult match(Operation *op) const override {// 如果模式匹配,返回`success()`,否则返回`failure`// ... (具体匹配逻辑)}void rewrite(Operation *op, PatternRewriter &rewriter) {// 使用提供的rewriter对IR进行变动// ... (具体重写逻辑)}// 使用结合的matchAndRewrite方法来实现匹配和重写LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) {// 这个方法同时进行匹配和变动// 注意在匹配成功之前不应进行IR变动// ... (具体逻辑)}
};
Restrictions
- 匹配阶段:在这个阶段,不能对IR进行任何修改。也就是说,只能读数据,不能改数据。
- 重写阶段:在这个阶段,可以对IR进行修改,但必须通过指定的PatternRewriter来操作。PatternRewriter类提供了执行各种可能的修改操作的接口。例如,如果要删除一个操作(operation),不能直接调用这个操作的删除方法,而是应该使用PatternRewriter提供的删除方法eraseOp。此外,根操作必须被就地更新、替换或删除。
struct MyPattern : public mlir::RewritePattern {MyPattern(mlir::MLIRContext *context): mlir::RewritePattern("my_op", 1, context) {}// 匹配阶段mlir::LogicalResult match(mlir::Operation *op) const override {// 只能读取数据,不能修改opif (auto myOp = llvm::dyn_cast<MyOp>(op)) {return mlir::success();}return mlir::failure();}// 重写阶段void rewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {// 使用PatternRewriter进行IR的修改rewriter.setInsertionPoint(op);auto newOp = rewriter.create<NewOp>(op->getLoc(), op->getOperands());// 使用PatternRewriter来删除操作rewriter.eraseOp(op);}
};
递归应用(Recursion):
- 递归在编程中是指一个函数调用自身。在模式重写中,一个模式可以应用在它自己产生的结果上。
- 想象一下,你有一个模式,它每次运行都会从一个循环中去掉一层迭代。如果这个循环可以剥掉多层迭代,那么这个模式可能会被反复应用多次。
- 问题是,这种反复应用可能会引起无限循环,导致程序无法停止运行。因此,系统默认假设所有模式都不能安全地递归,如果检测到递归就会停止。
- 如果你确定某个模式可以安全地递归,你需要显式告诉系统,这样系统就不会阻止它。这可以通过调用 setHasBoundedRewriteRecursion 来完成。
剥离迭代
一种优化技术,主要用于从循环中提取出一个或几个单次迭代,使其单独处理。这样做可以帮助更好地进行代码优化,例如更好地并行化或者处理特殊情况。我们通过一个具体的例子来说明这个过程。
假设我们有一个简单的MLIR循环,如下所示:
func @example(%N: index) {%c0 = constant 0 : index%c1 = constant 1 : indexscf.for %i = %c0 to %N step %c1 {// 循环体}return
}
剥离单次迭代的具体步骤
- 确定循环的范围和步长:首先,我们需要知道循环的下界、上界和步长。
- 生成剥离的单次迭代:在原始循环之前创建一个新的循环,范围是从下界到下界加上步长。
- 更新原始循环的范围:将原始循环的下界更新为新的下界,即下界加上步长。
下面是用C++和MLIR的具体实现,展示如何进行这一步骤。
C++/MLIR实现
首先,我们定义一个模式来进行剥离操作:
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Builders.h"
#include "mlir/Dialect/SCF/SCFOps.h"using namespace mlir;struct PeelLoopPattern : public RewritePattern {explicit PeelLoopPattern(MLIRContext *context): RewritePattern(scf::ForOp::getOperationName(), 1, context) {}LogicalResult matchAndRewrite(Operation *op,PatternRewriter &rewriter) const override {auto forOp = cast<scf::ForOp>(op);// 假设我们只处理步长为1的情况if (!matchPattern(forOp.getStep(), m_One())) {return failure();}// 提取循环的范围Value lowerBound = forOp.getLowerBound();Value upperBound = forOp.getUpperBound();Value step = forOp.getStep();// 生成剥离的单次迭代rewriter.setInsertionPoint(forOp);Value peeledIter = rewriter.create<scf::ForOp>(forOp.getLoc(), lowerBound, rewriter.create<AddIOp>(forOp.getLoc(), lowerBound, step), step,forOp.getIterOperands());// 将循环体移动到新的单次迭代中rewriter.inlineRegionBefore(forOp.getRegion(), peeledIter.getRegion(),peeledIter.getRegion().begin());// 更新原始循环的范围rewriter.setInsertionPointAfter(peeledIter);Value newLowerBound = rewriter.create<AddIOp>(forOp.getLoc(), lowerBound, step);rewriter.updateRootInPlace(forOp, [&]() {forOp.setLowerBound(newLowerBound);});return success();}
};void registerPeelLoopPattern(RewritePatternSet &patterns) {patterns.add<PeelLoopPattern>(patterns.getContext());
}
然后,我们需要将这个模式注册到MLIR Pass中,并在Pass中应用它:
struct PeelLoopPass : public PassWrapper<PeelLoopPass, OperationPass<FuncOp>> {void runOnOperation() override {FuncOp func = getOperation();RewritePatternSet patterns(&getContext());registerPeelLoopPattern(patterns);if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {signalPassFailure();}}
};std::unique_ptr<Pass> createPeelLoopPass() {return std::make_unique<PeelLoopPass>();
}
经过上述过程,原始的MLIR循环:
scf.for %i = %c0 to %N step %c1 {// 循环体
}
将被转换为:
scf.for %i = %c0 to %c1 step %c1 {// 剥离的单次迭代的循环体
}
scf.for %i = %c1 to %N step %c1 {// 剩余迭代的循环体
}
这样,我们就完成了剥离迭代的操作。剥离后的单次迭代可以单独优化或并行化处理。
调试名称和标签
在调试代码时,我们有时需要追踪特定的模式(相当于代码中的模板或者规则)。为了方便,我们可以给这些模式起一个调试名称(类似于给每个模式贴一个标签),这样在查看调试信息时,就能很容易地知道是哪个模式在起作用。此外,我们还可以给一组模式起一个共同的标签,这样可以方便地对这组模式进行过滤和分类。
假设我们有一个模式叫做 MyPattern,它是我们定义的一种重写规则。我们可以给它设置一个调试名称和标签:
class MyPattern : public RewritePattern {
public:using RewritePattern::RewritePattern;void initialize() {setDebugName("MyPattern");addDebugLabels("MyRewritePass");}
};// 在某个地方,我们要把这些模式添加到一个集合中,并给它们设置一个公共标签:
void populateMyPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {patterns.addWithLabel<MyPattern>("MyRewritePatterns", ctx);
}
初始化
有些模式在使用前需要进行特殊的初始化,比如如果一个模式会递归调用自身,那么我们需要明确地标记它可以处理这种递归。这种初始化可以在模式的构造函数中完成,也可以通过一个专门的初始化方法来完成。
仍然以 MyPattern 为例,如果这个模式需要处理递归调用,我们可以这样做:
class MyPattern : public RewritePattern {
public:using RewritePattern::RewritePattern;void initialize() {setHasBoundedRewriteRecursion();}
};
构造
为了确保模式在创建后被正确初始化并且可以正常使用,我们建议使用一种标准的创建方法。这种方法确保所有需要的初始化都已经完成。
假设我们需要创建一个 MyPattern 的实例并添加到模式集合中,我们可以这样做:
void populateMyPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {// 使用 create<T> 方法来创建并初始化模式auto myPattern = RewritePattern::create<MyPattern>(ctx);patterns.add(std::move(myPattern));
}
Pattern Rewriter (模式重写器)
PatternRewriter 是一个特殊的类,允许模式(pattern)与模式应用的驱动程序进行通信。所有对中间表示(IR)的更改,包括创建,必须通过PatternRewriter类进行。这是因为底层的模式驱动程序可能有状态,如果直接进行更改会使这些状态无效。
下面是一些常见的PatternRewriter API示例,请参考类文档以获取最新的API列表:
- 擦除操作:eraseOp
这个方法用来删除没有结果或者其结果没有被使用的操作。
- 通知匹配失败的原因:notifyMatchFailure
这个方法允许在matchAndRewrite方法中提供一个诊断消息,说明为什么一个模式匹配失败。如何显示这个消息取决于具体的模式驱动程序。
- 替换操作:replaceOp/replaceOpWithNewOp
这个方法用提供的一组值替换一个操作的结果,并擦除该操作。
- 原地更新操作:(start|cancel|finalize)OpModification
这是一组方法,提供类似事务的API,用于在模式中原地更新操作的属性、位置、操作数或后继者。更新事务通过startOpModification开始,可以用cancelOpModification取消或用finalizeOpModification完成。一个方便的封装modifyOpInPlace可以在回调周围自动包裹开始和完成
模式应用
我们定义了一些优化或转换模式,然后将这些模式应用到某个程序或数据结构上,以优化其性能或改变其结构。
-
RewritePatternSet:这是一个用来存储所有模式的集合,就像一个模式的清单。
-
PatternRewriter:这是一个工具,用来实际执行模式中的变更。为了确保在执行变更时不会破坏系统的状态,我们需要定制这个工具。
-
PatternApplicator:这是一个负责实际应用模式的类。它使用一个成本模型来决定哪些模式最值得应用,并按照这个模型来应用模式。
-
成本模型:这是一个用来评估每个模式收益的算法,帮助我们决定应该优先应用哪个模式。
以下是一个简单的MLIR示例,展示如何定义和应用一个模式来优化一个操作:
Step 1: 定义一个简单的MLIR操作
module {func @simple_op(%arg0: i32) -> i32 {%0 = "my_dialect.my_op"(%arg0) : (i32) -> i32return %0 : i32}
}
Step 2: 定义一个优化模式
我们将定义一个模式来将这个操作优化为另一个操作,例如将 my_dialect.my_op 转换为 my_dialect.optimized_op。
class MyPattern : public mlir::RewritePattern {
public:MyPattern(mlir::MLIRContext *context): RewritePattern("my_dialect.my_op", 1, context) {}mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {// 检查操作是否为目标操作if (op->getName().getStringRef() != "my_dialect.my_op")return mlir::failure();// 创建一个新的操作来替换旧的操作rewriter.replaceOpWithNewOp<mlir::Operation>(op, "my_dialect.optimized_op",op->getResultTypes(), op->getOperands());return mlir::success();}
};
Step 3: 收集模式并应用
void applyMyPatternDriver(mlir::Operation *op, mlir::MLIRContext *context) {mlir::RewritePatternSet patterns(context);patterns.add<MyPattern>(context);mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));mlir::PatternApplicator applicator(frozenPatterns);// 应用默认的成本模型applicator.applyDefaultCostModel();mlir::PatternRewriter rewriter(context);// 匹配并应用模式mlir::LogicalResult result = applicator.matchAndRewrite(op, rewriter);if (failed(result)) {// 没有应用任何模式llvm::errs() << "No patterns were applied.\\n";} else {// 成功应用了一个模式llvm::errs() << "A pattern was successfully applied.\\n";}
}
Dialect Conversion Driver(方言转换驱动器):
- 该驱动器提供了一个框架,用于在方言之间及方言内部进行操作转换。使用“合法性”的概念,将不合法的操作转换为目标方言支持的操作。
- 还支持类型转换。
Greedy Pattern Rewrite Driver(贪婪模式重写驱动器):
- 该驱动器以工作列表的方式处理操作,并贪婪地应用在本地最有益的模式。
模式的益处由模式自身的益处和模式列表中的相对顺序决定。
该驱动器有两种形式: - Region-based driver(基于区域的驱动器):应用模式到指定区域内的所有操作。
- Op-based driver(基于操作的驱动器):应用模式到指定的操作列表。
驱动器通过GreedyRewriteConfig进行配置,可以选择自顶向下或自底向上的遍历方式。