深入解析 MLIR Toy Tutorial(Chapter 3
![](https://img.haomeiwen.com/i13575947/0786a18a71d8c6fd.png)
概述
MLIR Toy Tutorial 的目标是通过构建一门编程语言编译器的完整过程(包括前端和后端技术),教授如何使用 MLIR 的各个组件来实现语言的解析、转换和代码生成等功能。
前一章 Chapter2 生成的 IR 最终是要转换成目标代码的,而在此之前,编译器一般都会对 IR 做一些转换和优化来提高目标代码的质量(性能好、编译时间短和内存占用小等)。
而 Canonicalization(规范化)就是其中一个重要的优化步骤,它将输入的 MLIR IR 转换为其标准规范形式。这个过程对 IR 执行规范化的重写和优化,以消除冗余、简化代码,并确保所有等效的表达都被转换成标准形式。Chapter3 将要介绍如何使用 MLIR Canonicalizer pass 来重写 IR,消除冗余。
Pattern-match and rewrite
Canonicalizer pass 通过 match-and-rewrite 的方式来重写 IR:pass 会遍历 IR 中的所有 op,在它内部维护的一个 pattern 集合 RewritePatternSet
里查找是否有匹配的 RewritePattern
,匹配上后就会通过它来重写 op。
例如,对一个张量在相同的dims重复转置两次,那这些转置就是多余的,应该优化掉,transpose(transpose(x)) == x
。原来的 IR:
toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
%1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
toy.return %1 : tensor<*xf64>
}
ToyCombine.cpp 为 TransposeOp 定义了一个名为 SimplifyRedundantTranspose
的 RewritePattern,它一旦匹配到有连续的两个 transpose op,会用第一个 transpose 的输入张量替换掉第二个 transpose 的输出张量,换句话说,它会删掉第二个 transpose op:
toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
%0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
toy.return %arg0 : tensor<*xf64>
}
这个 RewritePattern 只删掉了第二个 transpose,第一个 transpose instruction 还留着呢。Canonicalizer 不会主动做 DCE(Dead Code Elimination),除非 TransposeOp 添加了相应的 trait,如 Pure
:
def TransposeOp : Toy_Op<"transpose", [Pure]> {...}
Pure(纯函数)意味着函数的输出只由输入决定,输入不变输出就不变,这样有了明确的约束,Canonicalizer 就可以放心大胆滴给 op 上各种优化,其中就包括 DCE,最终就能得到我们想要的结果:
toy.func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
toy.return %arg0 : tensor<*xf64>
}
DRR: Declarative Rewrite Rules
和定义 op 一样,同样可以通过声明的方式,在 ToyCombine.td
定义 RewritePattern:
// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
(ReshapeOp $arg)>;
-
ReshapeReshapeOptPattern: Pat<...>
,表示要定义一个名为 ReshapeReshapeOptPattern 的 RewritePattern。 -
(ReshapeOp(ReshapeOp $arg))
是match
pattern,表示要匹配连续两个 reshape op 的 IR 片段:Reshape(Reshape(x))。 -
(ReshapeOp $arg)
是rewrite
pattern,当 match 成功后,就会根据 rewrite pattern 来重写 IR,这里是指将其中一个 reshape 删除掉:Reshape(Reshape(x)) -> Reshape(x)。
另外,还可以对参数添加一些约束条件,只有条件满足后才会重写 IR:
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
def RedundantReshapeOptPattern : Pat<
(ReshapeOp:$res $arg), (replaceWithValue $arg),
[(TypesAreIdentical $res, $arg)]>;
RedundantReshapeOptPattern 用于匹配所有 reshape,当约束 (TypesAreIdentical $res, $arg) == true
后,也就是 reshape op 的输入张量和输出张量的 shape 和 dtype 都一致的时候,就删除掉这个多余的reshape。
总结
Chapter3 介绍了如何使用 MLIR 的 Canonicalizer pass,通过匹配和重写的方式来将 MLIR IR 转换为其标准规范形式,消除冗余、简化代码。在 Ops.td 里为 op 设置 let hasCanonicalizer = 1;
后,就可以自定义 RewritePattern 来重写 IR。另外,也可以通过 DRR(Declarative Rewrite Rules)来声明式重写规则,让 tablegen 来成帮助创建 RewritePattern。