mirror of https://github.com/kendryte/nncase.git
fix fold binary (#1182)
parent
a13d43d0f4
commit
91ea4df975
|
@ -23,13 +23,14 @@ public sealed partial class FoldNopBinary : IRewriteRule
|
|||
/// <inheritdoc/>
|
||||
public IPattern Pattern { get; } = IsBinary(
|
||||
"binary",
|
||||
"call",
|
||||
x => x.BinaryOp is BinaryOp.Add or BinaryOp.Sub or BinaryOp.Mul or BinaryOp.Div or BinaryOp.Mod or BinaryOp.Pow,
|
||||
IsWildcard("lhs"),
|
||||
IsTensorConst("rhs"));
|
||||
|
||||
private Expr? GetReplace(Binary binary, Expr lhs, TensorConst rhs)
|
||||
private Expr? GetReplace(Binary binary, Call call, Expr lhs, TensorConst rhs)
|
||||
{
|
||||
if (lhs.CheckedType is Nncase.IR.AnyType || lhs.CheckedShape == rhs.CheckedShape)
|
||||
if ((lhs.CheckedType is Nncase.IR.AnyType && rhs.CheckedShape.IsScalar) || (lhs.CheckedShape == call.CheckedShape))
|
||||
{
|
||||
return binary.BinaryOp switch
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue