fix fold binary (#1182)

pull/1194/head
huochenghai 2024-04-12 12:26:50 +08:00 committed by GitHub
parent a13d43d0f4
commit 91ea4df975
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 2 deletions

View File

@ -23,13 +23,14 @@ public sealed partial class FoldNopBinary : IRewriteRule
/// <inheritdoc/>
public IPattern Pattern { get; } = IsBinary(
"binary",
"call",
x => x.BinaryOp is BinaryOp.Add or BinaryOp.Sub or BinaryOp.Mul or BinaryOp.Div or BinaryOp.Mod or BinaryOp.Pow,
IsWildcard("lhs"),
IsTensorConst("rhs"));
private Expr? GetReplace(Binary binary, Expr lhs, TensorConst rhs)
private Expr? GetReplace(Binary binary, Call call, Expr lhs, TensorConst rhs)
{
if (lhs.CheckedType is Nncase.IR.AnyType || lhs.CheckedShape == rhs.CheckedShape)
if ((lhs.CheckedType is Nncase.IR.AnyType && rhs.CheckedShape.IsScalar) || (lhs.CheckedShape == call.CheckedShape))
{
return binary.BinaryOp switch
{