mirror of https://github.com/kendryte/nncase.git
Compare commits
3 Commits
e39d73eb1c
...
8de9964054
Author | SHA1 | Date |
---|---|---|
Curio Yang | 8de9964054 | |
sunnycase | 413aff7386 | |
郑启航 | 952f89bb72 |
|
@ -308,3 +308,4 @@ cmake-build-*
|
|||
*.ipynb_checkpoints*
|
||||
# Auto generated files
|
||||
# generated/
|
||||
.history/
|
|
@ -13,7 +13,6 @@ using System.Reactive;
|
|||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using DryIoc;
|
||||
using Google.OrTools.Sat;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.IR;
|
||||
using Nncase.Runtime;
|
||||
|
@ -176,7 +175,7 @@ internal sealed class DeviceCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
|
|||
TensorType { Shape: { IsRanked: true } } x => x.Shape.IsFixed switch
|
||||
{
|
||||
true => $"tensor_view<{x.DType.ToC()}, fixed_shape<{x.Shape.ToString()[1..^1]}>>",
|
||||
false => $"tensor_view<{x.DType.ToC()}, ranked_shape<{x.Shape.Rank}>>",
|
||||
false => "auto",
|
||||
},
|
||||
_ => throw new NotSupportedException(),
|
||||
};
|
||||
|
@ -217,6 +216,26 @@ internal sealed class DeviceCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
|
|||
break;
|
||||
case IR.Buffers.Allocate op:
|
||||
str = $"({type})runtime_util->malloc({arguments[0].Name})";
|
||||
break;
|
||||
case IR.Buffers.BufferSubview op:
|
||||
{
|
||||
var arg0 = expr.Arguments[1] switch
|
||||
{
|
||||
TupleConst => $"fixed_shape<{arguments[1].Name}>{{}}",
|
||||
IR.Tuple tc => $"ranked_shape<{tc.Count}>{{{arguments[1].Name}}}",
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(expr)),
|
||||
};
|
||||
|
||||
var arg1 = expr.Arguments[2] switch
|
||||
{
|
||||
TupleConst => $"fixed_shape<{arguments[2].Name}>{{}}",
|
||||
IR.Tuple tc => $"ranked_shape<{tc.Count}>{{{arguments[2].Name}}}",
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(expr)),
|
||||
};
|
||||
|
||||
str = $"{arguments[0].Name}.view({arg0}, {arg1})";
|
||||
}
|
||||
|
||||
break;
|
||||
case IR.Buffers.AllocateBufferView op:
|
||||
{
|
||||
|
@ -252,6 +271,14 @@ internal sealed class DeviceCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
|
|||
BinaryOp = op.BinaryOp,
|
||||
}).Result);
|
||||
break;
|
||||
case TIR.CPU.Swish swish:
|
||||
if (swish.Beta != 1.0f)
|
||||
{
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
|
||||
IndentScope.Writer.IndWrite($"unary<ops::swish>({arguments[0].Name}, {arguments[1].Name});\n");
|
||||
break;
|
||||
default:
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
|
@ -298,6 +325,34 @@ internal sealed class DeviceCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
|
|||
return symbol;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
protected override CSymbol VisitTupleConst(TupleConst tp)
{
    // Reuse the memoized symbol when this tuple constant was already converted.
    if (!_exprMemo.TryGetValue(tp, out var symbol))
    {
        // Each element is wrapped as a Const and visited; the resulting names are
        // joined with "," and stored as the symbol name. The symbol type stays empty.
        var elementNames = tp.Value.Select(x => Visit(Const.FromValue(x)).Name);
        symbol = new(string.Empty, string.Join(",", elementNames));
        _exprMemo.Add(tp, symbol);
    }

    return symbol;
}
|
||||
|
||||
/// <inheritdoc/>
protected override CSymbol VisitTuple(IR.Tuple tp)
{
    // Reuse the memoized symbol when this tuple was already converted.
    if (!_exprMemo.TryGetValue(tp, out var symbol))
    {
        // Visit every field and join the resulting symbol names with ",".
        // The symbol type stays empty.
        var fieldNames = tp.Fields.AsValueEnumerable().Select(x => Visit(x).Name).ToArray();
        symbol = new(string.Empty, string.Join(",", fieldNames));
        _exprMemo.Add(tp, symbol);
    }

    return symbol;
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override CSymbol VisitSequential(Sequential expr)
|
||||
{
|
||||
|
|
|
@ -62,12 +62,11 @@ internal class FunctionBuilder
|
|||
|
||||
return new LinkableKernelFunction(_id, function, functionCSource, _sectionManager.GetContent(WellknownSectionNames.Text)!, new LinkedSection(_sectionManager.GetContent(KernelHeaderSectionName), KernelHeaderSectionName, 0, 8, (uint)sizeof(DescHeader)));
|
||||
}
|
||||
else if (function.Name.EndsWith("device"))
|
||||
else
|
||||
{
|
||||
var visitor = new DeviceCSourceConvertVisitor();
|
||||
visitor.Visit(function);
|
||||
var header = visitor.GetHeader();
|
||||
|
||||
return new LinkableDeviceFunction(_id, function, header, _sectionManager.GetContent(WellknownSectionNames.Text)!);
|
||||
}
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ internal sealed class KernelCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>,
|
|||
|
||||
public KernelCSource GetCSource()
|
||||
{
|
||||
var ctype = $"void {VisitEntry.Name}({string.Join(", ", VisitEntry.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray().Concat(_exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location == MemoryLocation.Rdata).Select(Visit).Select(s => $" {s.Type} {s.Name}").ToArray()))}, uint8_t* l1_data)";
|
||||
var ctype = $"void {VisitEntry.Name}({string.Join(", ", VisitEntry.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray().Concat(_exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location == MemoryLocation.Rdata).Select(Visit).Select(s => $" {s.Type} {s.Name}").ToArray()))}, uint8_t* data)";
|
||||
return new(
|
||||
CSourceBuiltn.MakeMain(VisitEntry, _exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location == MemoryLocation.Rdata)),
|
||||
CSourceBuiltn.MakeKernel(ctype, _kernelBuilder.ToString()));
|
||||
|
|
|
@ -32,18 +32,34 @@ public sealed class BoxingEvaluator : ITypeInferencer<Boxing>, ICostEvaluator<Bo
|
|||
switch (inType, returnType)
|
||||
{
|
||||
case (TensorType tensorType, DistributedType distTensorType):
|
||||
cost = new Cost()
|
||||
switch (context.CompileOptions.TargetCompileOptions)
|
||||
{
|
||||
[CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(tensorType),
|
||||
[CostFactorNames.MemoryStore] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)),
|
||||
};
|
||||
case Targets.CpuTargetOptions { UnifiedMemoryArchitecture: true }:
|
||||
break;
|
||||
default:
|
||||
cost = new Cost()
|
||||
{
|
||||
[CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(tensorType),
|
||||
[CostFactorNames.MemoryStore] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)),
|
||||
};
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
case (DistributedType distTensorType, TensorType tensorType):
|
||||
cost = new Cost()
|
||||
switch (context.CompileOptions.TargetCompileOptions)
|
||||
{
|
||||
[CostFactorNames.MemoryLoad] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)),
|
||||
[CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(tensorType),
|
||||
};
|
||||
case Targets.CpuTargetOptions { UnifiedMemoryArchitecture: true }:
|
||||
break;
|
||||
default:
|
||||
cost = new Cost()
|
||||
{
|
||||
[CostFactorNames.MemoryLoad] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)),
|
||||
[CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(tensorType),
|
||||
};
|
||||
break;
|
||||
}
|
||||
|
||||
break;
|
||||
|
||||
case (DistributedType a, DistributedType b) when a.Placement == b.Placement && a.NdSBP != b.NdSBP:
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Linq;
|
||||
using Nncase.CostModel;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.CPU;
|
||||
|
||||
namespace Nncase.Evaluator.IR.CPU;
|
||||
|
||||
/// <summary>
/// Evaluator for <see cref="CPUKernelOp"/>. Every operation is forwarded to the
/// op stored in <see cref="CPUKernelOp.Target"/>.
/// </summary>
public class CPUKernelOpEvaluator : IEvaluator<CPUKernelOp>, ITypeInferencer<CPUKernelOp>, ICostEvaluator<CPUKernelOp>
{
    /// <inheritdoc/>
    public IValue Visit(IEvaluateContext context, CPUKernelOp target) =>
        CompilerServices.EvaluateOp(target.Target, context);

    /// <inheritdoc/>
    public IRType Visit(ITypeInferenceContext context, CPUKernelOp target) =>
        CompilerServices.InferenceOp(target.Target, context, new());

    /// <inheritdoc/>
    public Cost Visit(ICostEvaluateContext context, CPUKernelOp target) =>
        CompilerServices.EvaluateOpCost(target.Target, context);
}
|
|
@ -14,7 +14,6 @@ internal class CPUModule : IApplicationPart
|
|||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
registrator.RegisterManyInterface<BoxingEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<CPUKernelOpEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<LoadEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<StoreEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<PackEvaluator>(reuse: Reuse.Singleton);
|
||||
|
|
|
@ -38,10 +38,6 @@ public sealed class PackedMatMulEvaluator : IEvaluator<PackedMatMul>, ITypeInfer
|
|||
rhs = rhs.Unpack(axis);
|
||||
}
|
||||
|
||||
// lhs = OrtKI.Unsqueeze(lhs, new long[] { -4, -1 }); // [x,m/32,k/32, 1 , m' ,k', 1 ]
|
||||
// rhs = OrtKI.Unsqueeze(rhs, new long[] { -6, -3 }); // [x, 1 ,k/32,n/32, 1 ,k', n']
|
||||
// var matmul = OrtKI.Mul(lhs, rhs); // [x, m/32,k/32,n/32,m',k',n']
|
||||
// matmul = OrtKI.ReduceSum(matmul, new long[] { -2, -5 }, 0, 1);
|
||||
var matmul = OrtKI.MatMul(lhs, rhs);
|
||||
if (target.LhsPackedAxes.Count == 2)
|
||||
{
|
||||
|
@ -112,7 +108,7 @@ public sealed class PackedMatMulEvaluator : IEvaluator<PackedMatMul>, ITypeInfer
|
|||
rType = Math.MatMulEvaluator.VisitTensorType(a, b);
|
||||
break;
|
||||
default:
|
||||
ERROR: rType = new InvalidType($"{lhs} {rhs} not support");
|
||||
ERROR: rType = new InvalidType($"lhs: {lhs}, rhs: {rhs} not support");
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR.Math;
|
||||
using Nncase.PatternMatch;
|
||||
|
||||
namespace Nncase.IR.CPU;
|
||||
|
||||
/// <summary>
/// An op that wraps another <see cref="Op"/> and exposes the wrapped op's parameters as its own.
/// </summary>
public sealed class CPUKernelOp : Op
{
    // NOTE(review): presumably keeps the wrapped op pinned for the lifetime of this op — confirm against ExprPinner.
    private readonly ExprPinner _exprPinner;

    /// <summary>
    /// Initializes a new instance of the <see cref="CPUKernelOp"/> class.
    /// </summary>
    /// <param name="target">The op to wrap.</param>
    public CPUKernelOp(Op target)
    {
        Target = target;
        _exprPinner = new ExprPinner(target);
    }

    /// <summary>
    /// Gets the wrapped op.
    /// </summary>
    public Op Target { get; }

    /// <inheritdoc/>
    public override IEnumerable<ParameterInfo> Parameters => Target.Parameters;

    /// <summary>
    /// Returns the runtime type name of the wrapped op.
    /// </summary>
    /// <returns>The simple type name of <see cref="Target"/>.</returns>
    public override string DisplayProperty()
    {
        return Target.GetType().Name;
    }
}
|
|
@ -12,17 +12,6 @@ namespace Nncase.IR.F;
|
|||
|
||||
public partial class CPU
|
||||
{
|
||||
/// <summary>
|
||||
/// Call cpu kernel.
|
||||
/// </summary>
|
||||
/// <param name="target">Unary operator.</param>
|
||||
/// <param name="inputs">Source inputs.</param>
|
||||
/// <returns>Result expression.</returns>
|
||||
public static Call CPUKernel(Op target, params Expr[] inputs)
|
||||
{
|
||||
return new Call(new CPUKernelOp(target), inputs);
|
||||
}
|
||||
|
||||
public static Call Boxing(Expr input, IRType type)
|
||||
{
|
||||
return new Call(new Boxing(type), input);
|
||||
|
|
|
@ -14,49 +14,42 @@ using static Nncase.PatternMatch.Utility;
|
|||
|
||||
[assembly: InternalsVisibleTo("Nncase.Tests")]
|
||||
|
||||
namespace Nncase.Passes.Rules;
|
||||
namespace Nncase.Passes;
|
||||
|
||||
/// <summary>
|
||||
/// auto distributed the xpu fusion.
|
||||
/// </summary>
|
||||
[RuleGenerator]
|
||||
public sealed partial class AutoDistributed : IRewriteRule
|
||||
public sealed partial class AutoDistributedPass : FunctionPass
|
||||
{
|
||||
private readonly CompileOptions _compileOptions;
|
||||
|
||||
public AutoDistributed(CompileOptions compileOptions)
|
||||
public AutoDistributedPass(CompileOptions compileOptions)
|
||||
{
|
||||
_compileOptions = compileOptions;
|
||||
}
|
||||
|
||||
public IPattern Pattern { get; } = IsCallWildcard("call", IsFusion("fusion", CPUTarget.Kind, IsWildcard("body"), IsVArgsRepeat("parameters", () => IsVar())));
|
||||
|
||||
private Expr? GetReplace(Call call, Fusion fusion, Expr body, IReadOnlyList<Expr> parameters, IReadOnlyList<Expr> callParams)
|
||||
protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassContext context)
|
||||
{
|
||||
// 1. convert to distribute graph
|
||||
if (body is Call { Target: Boxing } || (body is IR.Tuple tp && tp.Fields.AsValueEnumerable().Any(e => e is Call { Target: Boxing })))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var distConverter = new AutoDistributedConvertVisitor(_compileOptions.TargetCompileOptions is CpuTargetOptions options ? options : new CpuTargetOptions());
|
||||
var newbody = distConverter.Convert(body);
|
||||
var newFusion = fusion.With(moduleKind: CPUTarget.Kind, body: newbody, parameters: parameters.Cast<Var>().ToArray());
|
||||
return new Call(newFusion, callParams.ToArray());
|
||||
var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetCompileOptions is CpuTargetOptions options ? options : new CpuTargetOptions());
|
||||
return Task.FromResult(rewriter.Rewirte(input));
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRType, List<Expr>>, Unit>
|
||||
internal sealed class AutoDistributedRewriter : ExprVisitor<Dictionary<IRType, List<Expr>>, Unit>
|
||||
{
|
||||
public AutoDistributedConvertVisitor(CpuTargetOptions compileOptions)
|
||||
public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions)
|
||||
{
|
||||
Placement = new Placement(compileOptions.Hierarchy, compileOptions.HierarchyNames);
|
||||
Placement = new Placement(targetOptions.Hierarchy, targetOptions.HierarchyNames);
|
||||
CompileOptions = compileOptions;
|
||||
TargetOptions = targetOptions;
|
||||
}
|
||||
|
||||
public Placement Placement { get; }
|
||||
|
||||
public CpuTargetOptions CompileOptions { get; }
|
||||
public CompileOptions CompileOptions { get; }
|
||||
|
||||
public CpuTargetOptions TargetOptions { get; }
|
||||
|
||||
public static IReadOnlyList<Expr> GetLeafCandidateBoxings(Expr expr, Placement placement)
|
||||
{
|
||||
|
@ -65,41 +58,31 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
|
|||
ToArray();
|
||||
}
|
||||
|
||||
public Expr Convert(Expr body)
|
||||
public BaseFunction Rewirte(BaseFunction input)
|
||||
{
|
||||
var createFinalBoxing = (Expr e, TensorType type) =>
|
||||
if (input is Function function)
|
||||
{
|
||||
var d = (DistributedType)e.CheckedType;
|
||||
if (d.NdSBP.Any(s => s is SBPPartialSum))
|
||||
var equivalents = Visit(function.Body).Select(g => InstertTerminator(g.Value[0])).ToArray();
|
||||
using (new ExprPinner(equivalents))
|
||||
{
|
||||
var boxingP2B = IR.F.CPU.Boxing(e, new DistributedType(type, d.NdSBP.Select(s => s is SBPPartialSum ? SBP.B : s).ToArray(), Placement));
|
||||
return IR.F.CPU.Boxing(boxingP2B, type);
|
||||
BranchCut();
|
||||
}
|
||||
|
||||
return IR.F.CPU.Boxing(e, type);
|
||||
};
|
||||
|
||||
var equivalents = Visit(body).Select(g => g.Value[0] switch
|
||||
{
|
||||
IR.Tuple tp => new IR.Tuple(tp.Fields.ToArray().Select((f, i) => createFinalBoxing(f, (TensorType)((IR.Tuple)body).Fields[i].CheckedType)).ToArray()),
|
||||
Expr e => (Expr)createFinalBoxing(e, (TensorType)body.CheckedType),
|
||||
}).ToArray();
|
||||
using (new ExprPinner(equivalents))
|
||||
{
|
||||
BranchCut();
|
||||
}
|
||||
|
||||
var graph = new EGraph();
|
||||
foreach (var (exprKey, buckets) in ExprMemo.Where(kv => kv.Key is not Op))
|
||||
{
|
||||
foreach (var (typeKey, bucket) in buckets.Where(kv => kv.Value.Any()))
|
||||
var graph = new EGraph();
|
||||
foreach (var (exprKey, buckets) in ExprMemo.Where(kv => kv.Key is not Op))
|
||||
{
|
||||
Unions(graph, bucket);
|
||||
foreach (var (typeKey, bucket) in buckets.Where(kv => kv.Value.Any()))
|
||||
{
|
||||
Unions(graph, bucket);
|
||||
}
|
||||
}
|
||||
|
||||
var root = Unions(graph, equivalents);
|
||||
var post = graph.Extract(root, CompileOptions, null, Array.Empty<EGraphExtractConstrains>());
|
||||
return function.With(body: post);
|
||||
}
|
||||
|
||||
var root = Unions(graph, equivalents);
|
||||
return graph.Extract(root, null, Array.Empty<EGraphExtractConstrains>());
|
||||
return input;
|
||||
}
|
||||
|
||||
protected override Dictionary<IRType, List<Expr>> DefaultVisitLeaf(Expr expr)
|
||||
|
@ -124,16 +107,18 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
|
|||
throw new NotSupportedException("not support auto distributed call function");
|
||||
}
|
||||
|
||||
var isSupported = PassUtility.IsCpuSupported(op, expr.Arguments.ToArray());
|
||||
foreach (var param in op.Parameters)
|
||||
{
|
||||
VisitLeafArgument(param.ParameterKind, expr.Arguments[param.Index]);
|
||||
VisitLeafArgument(param.ParameterKind, expr.Arguments[param.Index], isSupported);
|
||||
}
|
||||
|
||||
var results = expr.Arguments.ToArray().
|
||||
Select(Visit).
|
||||
CartesianProduct().
|
||||
Select(args => args.ToArray()).
|
||||
Select(args => BuildEquivalCalls(op, args.Select(kv => kv.Value[0]).ToArray()).ToArray()).
|
||||
Select(args => isSupported ? BuildEquivalCalls(op, args.Select(kv => kv.Value[0]).ToArray()).ToArray() :
|
||||
BuildNotSupportedCalls(op, args.Select(kv => kv.Value[0]).ToArray())).
|
||||
SelectMany(i => i).
|
||||
GroupBy(c => c.CheckedType).
|
||||
ToDictionary(g => g.Key, g => new List<Expr>(g.ToList<Expr>()));
|
||||
|
@ -157,7 +142,7 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
|
|||
return results;
|
||||
}
|
||||
|
||||
private Dictionary<IRType, List<Expr>> VisitLeafArgument(ParameterKind parameterKind, Expr expr)
|
||||
private Dictionary<IRType, List<Expr>> VisitLeafArgument(ParameterKind parameterKind, Expr expr, bool isSupported)
|
||||
{
|
||||
var updateBuckets = (Dictionary<IRType, List<Expr>> buckets, IEnumerable<Expr> equivalents) =>
|
||||
{
|
||||
|
@ -179,12 +164,12 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
|
|||
switch (parameterKind, expr)
|
||||
{
|
||||
case (ParameterKind.Input, Expr e) when e is Const or Var:
|
||||
updateBuckets(buckets, GetLeafCandidateBoxings(e, Placement));
|
||||
updateBuckets(buckets, isSupported ? GetLeafCandidateBoxings(e, Placement) : new[] { e });
|
||||
break;
|
||||
case (ParameterKind.Input, Expr e) when e is IR.Tuple tp:
|
||||
foreach (var f in tp.Fields)
|
||||
{
|
||||
VisitLeafArgument(parameterKind, f);
|
||||
VisitLeafArgument(parameterKind, f, isSupported);
|
||||
}
|
||||
|
||||
foreach (var (k, v) in VisitLeafTuple(tp))
|
||||
|
@ -206,15 +191,44 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
|
|||
throw new InvalidOperationException();
|
||||
}
|
||||
}
|
||||
else if (parameterKind == ParameterKind.Input)
|
||||
{
|
||||
if (isSupported)
|
||||
{
|
||||
if (!buckets.Keys.Any(IsDistributed))
|
||||
{
|
||||
var results = buckets.Select(kv => GetLeafCandidateBoxings(kv.Value[0], Placement)).SelectMany(i => i).ToArray();
|
||||
updateBuckets(buckets, results);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (buckets.Keys.All(IsDistributed))
|
||||
{
|
||||
var results = buckets.Select(kv => InstertTerminator(kv.Value[0])).ToArray();
|
||||
updateBuckets(buckets, results);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buckets;
|
||||
}
|
||||
|
||||
/// <summary>
/// Builds the single plain call for an op that is not supported on the CPU target.
/// </summary>
/// <param name="target">The op being called.</param>
/// <param name="args">Candidate arguments, indexed by parameter index.</param>
/// <returns>
/// An empty array when any input-kind argument has a distributed type; otherwise a
/// one-element array holding <c>new Call(target, args)</c>.
/// </returns>
private Call[] BuildNotSupportedCalls(Op target, Expr[] args)
{
    var hasDistributedInput = target.Parameters.Any(
        p => p.ParameterKind == ParameterKind.Input && IsDistributed(args[p.Index].CheckedType));

    return hasDistributedInput
        ? Array.Empty<Call>()
        : new[] { new Call(target, args) };
}
|
||||
|
||||
private IEnumerable<Call> BuildEquivalCalls(Op target, Expr[] args)
|
||||
{
|
||||
if (!target.Parameters.Where(p => p.ParameterKind == ParameterKind.Input).All(p => IsDistributed(args[p.Index].CheckedType)))
|
||||
{
|
||||
throw new ArgumentException("the some arg have no distributed type.", nameof(args));
|
||||
return Array.Empty<Call>();
|
||||
}
|
||||
|
||||
var calls = new List<Call>();
|
||||
|
@ -226,7 +240,7 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
|
|||
using var pinner = new ExprPinner(args);
|
||||
call.Dispose();
|
||||
|
||||
if (target is CPUKernelOp { Target: Reshape } || target is Reshape)
|
||||
if (target is Reshape)
|
||||
{
|
||||
// the reshape need force boxing.
|
||||
var newShape = ((TensorConst)args[1]).Value.ToArray<int>();
|
||||
|
@ -306,6 +320,28 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
|
|||
_ => false,
|
||||
};
|
||||
|
||||
/// <summary>
/// Wraps a distributed expression in boxing calls that bring it back to its plain tensor type.
/// </summary>
/// <param name="expr">Expression to terminate; tuples are processed field by field.</param>
/// <returns>
/// A new tuple of terminated fields for a tuple with a tuple type, a boxing call for a
/// distributed-typed expression, or <paramref name="expr"/> itself when it is already tensor-typed.
/// </returns>
/// <exception cref="NotSupportedException">The checked type is none of the handled kinds.</exception>
private Expr InstertTerminator(Expr expr)
{
    // Boxes `e` to its tensor type. When any SBP entry is a partial sum, an intermediate
    // boxing first replaces every partial-sum entry with SBP.B.
    Expr BoxToTensor(Expr e, DistributedType type)
    {
        var source = e;
        if (type.NdSBP.Any(s => s is SBPPartialSum))
        {
            source = IR.F.CPU.Boxing(e, new DistributedType(type.TensorType, type.NdSBP.Select(s => s is SBPPartialSum ? SBP.B : s).ToArray(), Placement));
        }

        return IR.F.CPU.Boxing(source, type.TensorType);
    }

    var checkedType = expr.CheckedType;

    // Tuple expressions with a tuple type recurse into each field.
    if (expr is IR.Tuple tp && checkedType is TupleType)
    {
        return new IR.Tuple(tp.Fields.ToArray().Select(InstertTerminator).ToArray());
    }

    return checkedType switch
    {
        DistributedType type => BoxToTensor(expr, type),
        TensorType => expr,
        _ => throw new NotSupportedException(),
    };
}
|
||||
|
||||
private EClass Unions(EGraph graph, IEnumerable<Expr> equivalents)
|
||||
{
|
||||
var eids = equivalents.Select(graph.Add).ToArray();
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.CommandLine;
|
||||
using System.CommandLine.Invocation;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Passes;
|
||||
|
||||
/// <summary>
/// Module pass that repeatedly applies <see cref="PartitionRewriter"/> to every
/// <see cref="Function"/> in the module until a rewrite reports no mutation.
/// </summary>
public sealed class CPUFunctionPartitionPass : ModulePass
{
    /// <inheritdoc/>
    protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext context)
    {
        // The function count is read once, before any function is replaced.
        var funcs = module.Functions.Count;
        for (int i = 0; i < funcs; i++)
        {
            if (module.Functions[i] is not Function function)
            {
                continue;
            }

            // Each iteration uses a fresh rewriter; stop once a run leaves the function unmutated.
            Function current = function;
            bool mutated;
            do
            {
                var rewriter = new PartitionRewriter();
                current = (Function)rewriter.Rewrite(current, default);
                mutated = rewriter.IsMutated;
            }
            while (mutated);

            module.Replace(i, current);
        }

        return Task.FromResult(module);
    }
}
|
||||
|
||||
/// <summary>
/// Collects the entry-point expressions found while scanning a partition.
/// Entry points are compared by reference.
/// </summary>
internal sealed class PartitionContext
{
    private readonly HashSet<Expr> _entryPoints = new(ReferenceEqualityComparer.Instance);

    /// <summary>
    /// Records <paramref name="expr"/> as an entry point; adding the same instance again has no effect.
    /// </summary>
    /// <param name="expr">The expression to record.</param>
    public void AddEntryPoint(Expr expr) => _entryPoints.Add(expr);

    /// <summary>
    /// Creates a reference-keyed map from every recorded entry point to a new
    /// <see cref="Var"/> constructed from that entry point's checked type.
    /// </summary>
    /// <returns>The entry-point-to-variable map.</returns>
    public Dictionary<Expr, Expr> CreateVarMaps()
    {
        var map = new Dictionary<Expr, Expr>((IEqualityComparer<Expr>)ReferenceEqualityComparer.Instance);
        foreach (var entry in _entryPoints)
        {
            map.Add(entry, new Var(entry.CheckedType));
        }

        return map;
    }
}
|
||||
|
||||
/// <summary>
/// Visitor that records, in a <see cref="PartitionContext"/>, the first argument of every
/// call whose target is a boxing op with a distributed new type.
/// </summary>
internal sealed class PartitionVisitor : ExprVisitor<Unit, Unit, PartitionContext>
{
    /// <inheritdoc/>
    protected override Unit DefaultVisitLeaf(Expr expr, PartitionContext context) => default;

    /// <inheritdoc/>
    protected override Unit VisitLeafCall(Call expr, PartitionContext context)
    {
        var boxesToDistributed = expr is { Target: IR.CPU.Boxing { NewType: DistributedType } };
        if (boxesToDistributed)
        {
            // The value being boxed becomes an entry point of the partition.
            context.AddEntryPoint(expr.Arguments[0]);
        }

        return default;
    }
}
|
||||
|
||||
/// <summary>
/// Rewriter that replaces a call boxing to a tensor type with a call to a new CPU-module
/// function built from that call. It rewrites nothing once <c>IsMutated</c> is set.
/// </summary>
internal sealed class PartitionRewriter : ExprRewriter<Unit>
{
    /// <inheritdoc/>
    protected override Expr RewriteLeafCall(Call expr, Unit context)
    {
        // Only act on a boxing-to-tensor call, and only while nothing has been mutated yet.
        if (IsMutated || expr is not { Target: IR.CPU.Boxing { NewType: TensorType } })
        {
            return expr;
        }

        // Collect the entry points reachable from this call.
        var partition = new PartitionContext();
        new PartitionVisitor().Visit(expr, partition);
        var replacements = partition.CreateVarMaps();

        // Clone the call with each entry point swapped for its variable.
        var body = new ReplacingExprCloner(replacements).Clone(expr, default);
        var parameters = replacements.Values.OfType<Var>().ToArray();
        var callee = new Function(body, parameters).With(moduleKind: Targets.CPUTarget.Kind);

        // The original entry-point expressions become the arguments of the new call.
        return new Call(callee, replacements.Keys.ToArray());
    }
}
|
|
@ -19,8 +19,6 @@ namespace Nncase.Passes;
|
|||
|
||||
internal sealed class CPUFusionToTirPass : ModulePass
|
||||
{
|
||||
private IAnalyzerManager AnalyzerManager => CompileSession.GetRequiredService<IAnalyzerManager>();
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext options)
|
||||
{
|
||||
|
@ -29,7 +27,7 @@ internal sealed class CPUFusionToTirPass : ModulePass
|
|||
|
||||
for (int i = 0; i < module.Functions.Count; i++)
|
||||
{
|
||||
if (module.Functions[i] is Fusion { ModuleKind: CPUTarget.Kind } fusion && fusion.Name.EndsWith("kernel"))
|
||||
if (module.Functions[i] is Fusion { ModuleKind: CPUTarget.Kind } fusion)
|
||||
{
|
||||
// var analysis = new Dictionary<Type, IAnalysisResult>
|
||||
// {
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Passes;
|
||||
|
||||
/// <summary>
/// Helpers that decide whether an op, optionally with concrete arguments, is supported by the CPU passes.
/// </summary>
public static class PassUtility
{
    /// <summary>
    /// Checks whether the op kind alone is supported.
    /// </summary>
    /// <param name="op">The op to check.</param>
    /// <returns>
    /// <c>true</c> for any op whose type is declared in the <c>Nncase.IR.CPU</c> namespace, or for
    /// one of the explicitly listed op types (some with attribute constraints); otherwise <c>false</c>.
    /// </returns>
    public static bool IsCpuSupported(Op op)
    {
        if (op.GetType().Namespace == "Nncase.IR.CPU")
        {
            return true;
        }

        return op is IR.Math.Unary or IR.Math.MatMul or IR.NN.Conv2D { PadMode: PadMode.Constant } or IR.NN.Softmax or IR.NN.LayerNorm or IR.NN.InstanceNormalization or IR.Imaging.ResizeImage { IsTFResize: false } or IR.Tensors.Unsqueeze or IR.Tensors.Reshape or IR.Tensors.Slice or IR.Tensors.Concat or IR.Tensors.Transpose or IR.NN.Swish or IR.Tensors.Gather or IR.NN.Pad { PadMode: PadMode.Constant };
    }

    /// <summary>
    /// Checks whether the op is supported when called with the given arguments.
    /// </summary>
    /// <param name="op">The op to check.</param>
    /// <param name="arguments">Call arguments in parameter order. The sequence is enumerated more than once.</param>
    /// <returns><c>true</c> when the op kind and all argument constraints are satisfied.</returns>
    public static bool IsCpuSupported(Op op, IEnumerable<Expr> arguments)
    {
        if (!IsCpuSupported(op))
        {
            return false;
        }

        // Every parameter must be an input, or an attribute bound to a tensor constant.
        if (!op.Parameters.Zip(arguments).All(p => p.First.ParameterKind == ParameterKind.Input || (p.First.ParameterKind == ParameterKind.Attribute && p.Second is TensorConst)))
        {
            return false;
        }

        switch (op)
        {
            case IR.Imaging.ResizeImage:
                // A resize with an explicit ROI argument is rejected.
                if (arguments.Skip(IR.Imaging.ResizeImage.Roi.Index).First() is not IR.None)
                {
                    return false;
                }

                break;
            case IR.Tensors.Slice:
                // Negative strides, begins or ends are rejected.
                if (((TensorConst)arguments.Skip(IR.Tensors.Slice.Strides.Index).First()).Value.ToArray<int>().Any(s => s < 0))
                {
                    return false;
                }

                if (((TensorConst)arguments.Skip(IR.Tensors.Slice.Begins.Index).First()).Value.ToArray<int>().Any(s => s < 0))
                {
                    return false;
                }

                if (((TensorConst)arguments.Skip(IR.Tensors.Slice.Ends.Index).First()).Value.ToArray<int>().Any(s => s < 0))
                {
                    return false;
                }

                break;
            case IR.NN.Conv2D:
                {
                    // Only the unbounded fused clamp (-inf, +inf) is accepted. The values are compared
                    // element-wise: `==` on two arrays is reference equality, which is always false
                    // against a freshly allocated array and would reject every Conv2D.
                    var clamp = ((TensorConst)arguments.Skip(IR.NN.Conv2D.FusedClamp.Index).First()).Value.ToArray<float>();
                    return clamp.SequenceEqual(new[] { float.NegativeInfinity, float.PositiveInfinity });
                }

            default:
                break;
        }

        return true;
    }
}
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.Affine;
|
||||
using Nncase.PatternMatch;
|
||||
using Nncase.Targets;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
|
||||
namespace Nncase.Passes.Rules.CPU.Affine;
|
||||
|
||||
/// <summary>
/// Rewrite rule that lowers a swish call with a fixed-shape input and a constant float beta
/// into an affine grid whose body is <c>TIR.F.CPU.Swish</c>.
/// </summary>
[RuleGenerator]
public partial class LowerSwish : RewriteRule<Pattern>
{
    // Running counter used to give each created output buffer a distinct "swish_N" name.
    private int _count;

    /// <inheritdoc/>
    public override Pattern Pattern { get; } = PatternMatch.F.NN.IsSwish(
        "swish",
        "call",
        IsWildcard("input") with { TypePattern = HasFixedShape() },
        IsTensorConst("beta") with { TypePattern = IsFloat() & (IsScalar() | HasShape(s => s.Rank == 1 && s[0].FixedValue == 1, "scalar")) });

    private Expr GetReplace(Call call, Expr input, float beta)
    {
        var rank = input.CheckedShape.Rank;

        // The output buffer uses the input tensor type; for a distributed input it uses
        // the type returned by GetDividedTensorType.
        var outputType = input.CheckedType switch
        {
            TensorType t => t,
            DistributedType dt => Utilities.DistributedUtility.GetDividedTensorType(dt),
            _ => throw new ArgumentOutOfRangeException(nameof(input)),
        };

        // Read the input and write the output through identity affine maps of the input rank.
        var reading = IR.F.Affine.Grid(CPUTarget.Kind)
            .Read(input, AffineMap.Identity(rank), out var inTile);
        var writing = reading
            .Write(TIR.T.CreateBuffer(outputType, TIR.MemoryLocation.Data, out _, $"swish_{_count++}"), AffineMap.Identity(rank), out var outTile);

        return writing
            .Body(TIR.F.CPU.Swish(inTile, outTile, beta))
            .Build();
    }
}
|
|
@ -13,7 +13,6 @@ using Nncase.PatternMatch;
|
|||
using Nncase.Targets;
|
||||
using static Nncase.IR.F.CPU;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
using static Nncase.PatternMatch.F.Math;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
|
||||
namespace Nncase.Passes.Rules.CPU.Affine;
|
||||
|
@ -21,18 +20,27 @@ namespace Nncase.Passes.Rules.CPU.Affine;
|
|||
[RuleGenerator]
|
||||
public partial class LowerUnary : RewriteRule<Pattern>
|
||||
{
|
||||
private int _count;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Pattern Pattern { get; } = IsUnary(
|
||||
target_name: "unary",
|
||||
public override Pattern Pattern { get; } = PatternMatch.F.Math.IsUnary(
|
||||
"unary",
|
||||
"call",
|
||||
_ => true,
|
||||
IsWildcard("input") with { TypePattern = HasFixedShape() });
|
||||
IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") });
|
||||
|
||||
private Expr GetReplace(Unary unary, Expr input)
|
||||
{
|
||||
var bufferType = input.CheckedType switch
|
||||
{
|
||||
TensorType t => t,
|
||||
DistributedType dt => Utilities.DistributedUtility.GetDividedTensorType(dt),
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(input)),
|
||||
};
|
||||
var rank = input.CheckedShape.Rank;
|
||||
return IR.F.Affine.Grid(CPUTarget.Kind)
|
||||
.Read(input, AffineMap.Identity(rank), out var inTile)
|
||||
.Write(TIR.T.CreateBuffer(input.CheckedTensorType, TIR.MemoryLocation.Data, out _), AffineMap.Identity(rank), out var outTile)
|
||||
.Write(TIR.T.CreateBuffer(bufferType, TIR.MemoryLocation.Data, out _, $"unary_{_count++}"), AffineMap.Identity(rank), out var outTile)
|
||||
.Body(TIR.F.CPU.Unary(unary.UnaryOp, inTile, outTile))
|
||||
.Build();
|
||||
}
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System.Reactive;
|
||||
using System.Runtime.CompilerServices;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.CodeGen;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.CPU;
|
||||
using Nncase.IR.Tensors;
|
||||
using Nncase.PatternMatch;
|
||||
using Nncase.Targets;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
|
||||
[assembly: InternalsVisibleTo("Nncase.Tests")]
|
||||
|
||||
namespace Nncase.Passes.Rules;
|
||||
|
||||
/// <summary>
/// Rewrites the body of a CPU-kind fusion with the pack rules and tags the resulting fusion
/// so that it is not rewritten a second time.
/// </summary>
[RuleGenerator]
public sealed partial class AutoPacking : IRewriteRule
{
    /// <summary>
    /// Gets the pattern: a call whose target is a fusion of kind <c>CPUTarget.Kind</c> with variable parameters.
    /// </summary>
    public IPattern Pattern { get; } = IsCallWildcard("call", IsFusion("fusion", CPUTarget.Kind, IsWildcard("body"), IsVArgsRepeat("parameters", () => IsVar())));

    /// <summary>
    /// Creates the rule set applied to a fusion body: the pack rules followed by the folding rules.
    /// </summary>
    /// <param name="rank">Value assigned to the <c>Rank</c> property of every pack rule.</param>
    /// <param name="lane">Value assigned to the <c>Lane</c> property of every pack rule.</param>
    private static IRewriteRule[] CreatePackRules(int rank, int lane)
    {
        return new IRewriteRule[]
        {
            new Passes.Rules.CPU.PackSoftmax() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackSwish() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackLayerNorm() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackResizeImage() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackMatMul() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackConv2D() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackUnary() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackBinary() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackTranspose() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackUnsqueeze() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackReshape() { Rank = rank, Lane = lane },
            new Passes.Rules.CPU.PackSlice() { Rank = rank, Lane = lane },
            new Passes.Rules.Neutral.FoldConstCall(),
            new Passes.Rules.CPU.FoldPackUnpack(),
            new Passes.Rules.CPU.FoldPackConcatUnpack(),
        };
    }

    /// <summary>
    /// Produces a call to a copy of <paramref name="fusion"/> whose body went through the pack rules.
    /// </summary>
    /// <returns>The new call, or <c>null</c> when the fusion already carries the pack marker.</returns>
    private Expr? GetReplace(Call call, Fusion fusion, Expr body, IReadOnlyList<Expr> parameters, IReadOnlyList<Expr> callParams)
    {
        // The marker is set below on every fusion this rule produces.
        if (fusion.Metadata is PackMetaData)
        {
            return null;
        }

        const int packRank = 1;

        // 8 lanes when Vector256 is hardware accelerated, 4 otherwise.
        int packLane = System.Runtime.Intrinsics.Vector256.IsHardwareAccelerated ? 8 : 4;

        var packedBody = CompilerServices.ERewrite(body, CreatePackRules(packRank, packLane), new());

        var packedFusion = fusion.With(moduleKind: CPUTarget.Kind, body: packedBody, parameters: parameters.Cast<Var>().ToArray());
        packedFusion.Metadata = new PackMetaData();
        return new Call(packedFusion, callParams.ToArray());
    }

    /// <summary>
    /// Marker metadata attached to fusions produced by this rule.
    /// </summary>
    private sealed class PackMetaData : IR.IRMetadata
    {
    }
}
|
|
@ -0,0 +1,687 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.CostModel;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.CPU;
|
||||
using Nncase.IR.Math;
|
||||
using Nncase.IR.NN;
|
||||
using Nncase.IR.Tensors;
|
||||
using Nncase.Passes.Analysis;
|
||||
using Nncase.Passes.Rules.Neutral;
|
||||
using Nncase.PatternMatch;
|
||||
using Nncase.Targets;
|
||||
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
using static Nncase.PatternMatch.F.CPU;
|
||||
using static Nncase.PatternMatch.F.Math;
|
||||
using static Nncase.PatternMatch.F.Tensors;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
using static Nncase.Utilities.ReplaceUtility;
|
||||
|
||||
namespace Nncase.Passes.Rules.CPU;
|
||||
|
||||
/// <summary>
/// Wraps a CPU-supported op call, together with the boxing applied to its result, into one fusion.
/// </summary>
[RuleGenerator]
internal sealed partial class CPUOutputBoxingFusion : FusionMaker
{
    public override string ModuleKind { get; } = CPUTarget.Kind;

    /// <summary>
    /// Gets the pattern: a boxing whose new type is a <see cref="TensorType"/>, applied to a call of a
    /// CPU-supported op, with a fixed-shape result.
    /// </summary>
    public override Pattern Pattern { get; } = IsBoxing(
        target_name: "boxing",
        op => op.NewType is TensorType,
        IsCallWildcard("call", IsOp<Op>("op", PassUtility.IsCpuSupported))) with
    { TypePattern = HasFixedShape() };

    /// <summary>
    /// Builds the fusion call that replaces the matched boxing.
    /// </summary>
    /// <returns>The call to the new fusion, or <c>null</c> when the op with these arguments is not CPU-supported.</returns>
    private Call? GetReplace(Call call, Op op, Boxing boxing, IReadOnlyList<Expr> callParams)
    {
        if (!PassUtility.IsCpuSupported(op, callParams))
        {
            return null;
        }

        // innerArgs: arguments of the op inside the fusion.
        // fusionParams / outerArgs: the fresh variables and the outer expressions they stand for.
        var innerArgs = new List<Expr>();
        var fusionParams = new List<Var>();
        var outerArgs = new List<Expr>();
        foreach (var arg in callParams)
        {
            switch (arg)
            {
                case Call or Var:
                    {
                        // Calls and variables stay outside and enter the fusion through a new parameter.
                        var placeholder = new Var(arg.CheckedType!);
                        innerArgs.Add(placeholder);
                        fusionParams.Add(placeholder);
                        outerArgs.Add(arg);
                    }

                    break;
                case TensorConst { Value: Tensor { Shape.IsScalar: true } } scalar:
                    // Scalar constants are rebuilt from the same bytes with shape [1].
                    innerArgs.Add(Const.FromTensor(Tensor.FromBytes(scalar.CheckedDataType, scalar.Value.BytesBuffer.ToArray(), new[] { 1 })));
                    break;
                default:
                    innerArgs.Add(arg);
                    break;
            }
        }

        var boxedCall = new Call(boxing, new Call(op, innerArgs.ToArray()));
        var fusion = new Fusion($"{op.GetType().Name}_{Count++}_kernel", ModuleKind, boxedCall, fusionParams.ToArray());
        return new Call(fusion, outerArgs.ToArray());
    }
}
|
||||
|
||||
/// <summary>
/// Wraps a single call of a CPU-supported op with a fixed-shape result into its own fusion.
/// </summary>
[RuleGenerator]
internal sealed partial class CPUSingleFusion : FusionMaker
{
    public override string ModuleKind { get; } = CPUTarget.Kind;

    public override Pattern Pattern { get; } = IsCallWildcard(
        "call",
        IsOp<Op>("op", PassUtility.IsCpuSupported)) with
    { TypePattern = HasFixedShape() };

    /// <summary>
    /// Builds the fusion call that replaces the matched call.
    /// </summary>
    /// <returns>
    /// The call to the new fusion, or <c>null</c> when the op with these arguments is not CPU-supported
    /// or when a concat has a tuple field that is not a <see cref="Var"/>.
    /// </returns>
    private Call? GetReplace(Call call, Op op, IReadOnlyList<Expr> callParams)
    {
        if (!PassUtility.IsCpuSupported(op, callParams))
        {
            return null;
        }

        if (op is Concat concat)
        {
            var fields = ((IR.Tuple)call.Arguments[0]).Fields.ToArray();

            // Check every field before creating any variable.
            foreach (var field in fields)
            {
                if (field is not Var)
                {
                    return null;
                }
            }

            var tupleItems = new List<Expr>();
            var tupleParams = new List<Var>();
            var tupleArgs = new List<Expr>();
            foreach (var field in fields)
            {
                var placeholder = new Var(field.CheckedType!);
                tupleItems.Add(placeholder);
                tupleParams.Add(placeholder);
                tupleArgs.Add(field);
            }

            var concatCall = new Call(new IR.Tensors.Concat(concat.Axis), new IR.Tuple(tupleItems.ToArray()));
            var concatFusion = new Fusion($"{op.GetType().Name}_{Count++}_kernel", ModuleKind, concatCall, tupleParams.ToArray());
            return new Call(concatFusion, tupleArgs.ToArray());
        }

        // innerArgs: arguments of the op inside the fusion.
        // fusionParams / outerArgs: the fresh variables and the outer expressions they stand for.
        var innerArgs = new List<Expr>();
        var fusionParams = new List<Var>();
        var outerArgs = new List<Expr>();
        foreach (var arg in callParams)
        {
            switch (arg)
            {
                case Call or Var:
                    {
                        // Calls and variables stay outside and enter the fusion through a new parameter.
                        var placeholder = new Var(arg.CheckedType!);
                        innerArgs.Add(placeholder);
                        fusionParams.Add(placeholder);
                        outerArgs.Add(arg);
                    }

                    break;
                case TensorConst { Value: Tensor { Shape.IsScalar: true } } scalar:
                    // Scalar constants are rebuilt from the same bytes with shape [1].
                    innerArgs.Add(Const.FromTensor(Tensor.FromBytes(scalar.CheckedDataType, scalar.Value.BytesBuffer.ToArray(), new[] { 1 })));
                    break;
                default:
                    innerArgs.Add(arg);
                    break;
            }
        }

        var opCall = new Call(op, innerArgs.ToArray());
        var fusion = new Fusion($"{op.GetType().Name}_{Count++}_kernel", ModuleKind, opCall, fusionParams.ToArray());
        return new Call(fusion, outerArgs.ToArray());
    }
}
|
||||
|
||||
/// <summary>
/// Computes the cost of a <see cref="Fusion"/> by visiting it and adding up every cost the visitor recorded.
/// </summary>
internal sealed class FusionCostEvaluator : Evaluator.IBaseFuncCostEvaluator
{
    private readonly CompileOptions _compileOptions;

    public FusionCostEvaluator(CompileOptions compileOptions)
    {
        _compileOptions = compileOptions;
    }

    /// <summary>
    /// Evaluates the cost of <paramref name="target"/>.
    /// </summary>
    /// <param name="target">The function to evaluate; only <see cref="Fusion"/> is supported.</param>
    /// <returns>The sum of all values in the visitor's expression memo.</returns>
    /// <exception cref="NotSupportedException"><paramref name="target"/> is not a <see cref="Fusion"/>.</exception>
    public Cost VisitLeaf(IR.BaseFunction target)
    {
        if (target is not Fusion fusion)
        {
            throw new NotSupportedException();
        }

        var visitor = new FusionGraphCostVisitor(_compileOptions);
        visitor.Visit(fusion);
        return visitor.ExprMemo.Values.Aggregate(Cost.Zero, (a, b) => a + b);
    }

    /// <summary>
    /// Cost evaluation context backed by the checked types and argument expressions of one call.
    /// </summary>
    private sealed class GraphOpCostEvaluateContext : Evaluator.ICostEvaluateContext
    {
        private readonly IRType? _returnType;
        private readonly IRType?[] _argumentTypes;
        private readonly Expr[] _arguments;

        public GraphOpCostEvaluateContext(IRType? returnType, IRType?[] argumentTypes, ReadOnlySpan<Expr> arguments, CompileOptions compileOptions)
        {
            _returnType = returnType;
            _argumentTypes = argumentTypes;
            CompileOptions = compileOptions;

            // A span cannot be stored in a field, so the arguments are copied.
            _arguments = arguments.ToArray();
        }

        public CompileOptions CompileOptions { get; }

        /// <summary>
        /// Gets the argument expression at the index of <paramref name="parameter"/>, cast to <typeparamref name="T"/>.
        /// </summary>
        public T GetArgument<T>(Op op, ParameterInfo parameter)
            where T : IR.BaseFunction
        {
            return (T)_arguments[parameter.Index];
        }

        /// <summary>
        /// Gets the checked type of the argument bound to <paramref name="parameter"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">The argument has no checked type.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="parameter"/> is not owned by the type of <paramref name="op"/>.</exception>
        public T GetArgumentType<T>(Op op, ParameterInfo parameter)
            where T : IRType
        {
            if (op.GetType() != parameter.OwnerType)
            {
                // The two-argument constructor is required: the single-string overload treats
                // its argument as the parameter name, not as the message.
                throw new ArgumentOutOfRangeException(nameof(parameter), $"Operator {op} doesn't have parameter: {parameter.Name}.");
            }

            return (T?)_argumentTypes[parameter.Index] ?? throw new InvalidOperationException("Run type infer first.");
        }

        /// <summary>
        /// Gets the checked type of the call result.
        /// </summary>
        /// <exception cref="InvalidOperationException">The call has no checked type.</exception>
        public T GetReturnType<T>()
            where T : IRType
        {
            return (T?)_returnType ?? throw new InvalidOperationException("Run type infer first.");
        }
    }

    /// <summary>
    /// Visitor that assigns a cost to the variables, op calls and fusions it visits.
    /// </summary>
    private sealed class FusionGraphCostVisitor : ExprVisitor<Cost, IRType>
    {
        public FusionGraphCostVisitor(CompileOptions compileOptions)
        {
            CompileOptions = compileOptions;
        }

        public CompileOptions CompileOptions { get; }

        /// <summary>A variable costs one memory load of its checked type.</summary>
        protected override Cost VisitLeafVar(Var var)
        {
            return new Cost()
            {
                [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(var.CheckedType!),
            };
        }

        /// <summary>Expressions without a dedicated override cost nothing.</summary>
        protected override Cost DefaultVisitLeaf(Expr expr)
        {
            return Cost.Zero;
        }

        /// <summary>
        /// An op call costs whatever the op cost evaluator reports, or zero when it reports nothing.
        /// </summary>
        /// <exception cref="NotSupportedException">The call target is not an <see cref="Op"/>.</exception>
        protected override Cost VisitLeafCall(Call call)
        {
            if (call.Target is not Op op)
            {
                throw new NotSupportedException();
            }

            var context = new GraphOpCostEvaluateContext(call.CheckedType, call.Arguments.AsValueEnumerable().Select(p => p.CheckedType).ToArray(), call.Arguments, CompileOptions);
            return CompilerServices.EvaluateOpCost(op, context) ?? Cost.Zero;
        }

        /// <summary>
        /// A fusion costs one memory store of its body's checked type plus the cost of visiting each parameter.
        /// </summary>
        protected override Cost VisitLeafFusion(Fusion fusion)
        {
            var cost = new Cost()
            {
                [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(fusion.Body.CheckedType!),
            };
            cost += fusion.Parameters.AsValueEnumerable().Select(Visit).Sum() ?? Cost.Zero;
            return cost;
        }
    }
}
|
||||
|
||||
/// <summary>
/// Cloner that substitutes mapped variables: a variable found in the map is replaced by the result of
/// visiting its mapped expression; any other variable is returned as is.
/// </summary>
internal sealed class FusionMerger : ExprCloner<Unit>
{
    private readonly Dictionary<Var, Expr> _varMap;

    /// <summary>
    /// Initializes a new instance of the <see cref="FusionMerger"/> class.
    /// </summary>
    /// <param name="varMap">Variables to substitute and the expressions that take their place.</param>
    public FusionMerger(Dictionary<Var, Expr> varMap)
    {
        _varMap = varMap;
    }

    protected override Expr VisitLeafVar(Var v, Unit context)
        => _varMap.TryGetValue(v, out var replacement) ? Visit(replacement, context) : v;
}
|
||||
|
||||
/// <summary>
/// Merges a fusion call with the fusion calls feeding it (directly or through a <c>GetItem</c>)
/// when they share the caller's module kind.
/// </summary>
internal sealed class GeneralFusionMergeRule : IRewriteRule
{
    // Keyed by the caller fusion followed by the callee fusions, compared by reference.
    // A combined int hash is not a safe key: two different fusion sets can share a hash.
    private readonly Dictionary<object[], Call> _mergedCache = new(new ReferenceSequenceComparer());
    private int _count;

    public IPattern Pattern { get; } =
        IsCall(
            "caller",
            IsFusion("caller_fusion", _ => true, IsWildcard(), IsVArgsRepeat("inputs", exprs =>
            {
                var patterns = new Pattern[exprs.Length];
                for (var i = 0; i < patterns.Length; i++)
                {
                    patterns[i] = IsWildcard($"input_{i}");
                }

                return patterns;
            })),
            IsVArgsRepeat("callerInputs", exprs =>
            {
                var patterns = new Pattern[exprs.Length];
                for (var i = 0; i < patterns.Length; i++)
                {
                    patterns[i] = IsWildcard($"callee_{i}");
                }

                return patterns;
            }));

    /// <summary>
    /// Builds, or returns from the cache, the call to the merged fusion.
    /// </summary>
    /// <param name="result">Match result holding "caller", "caller_fusion", "callerInputs" and one "callee_i" per input.</param>
    /// <param name="options">Pass context; not read by this method.</param>
    /// <returns>The merged call, or <c>null</c> when no input is a mergeable fusion call.</returns>
    public Expr? GetReplace(IMatchResult result, RunPassContext options)
    {
        var caller = (Call)result["caller"];
        var caller_fusion = (Fusion)result["caller_fusion"];
        var callerInputs = (IReadOnlyList<Expr>)result["callerInputs"];
        var callees = new List<Call>();
        var callee_fusions = new List<Fusion>();
        var fusion_index = new List<int>();

        // Collect inputs that are fusion calls, or GetItem of a fusion call, of the caller's module kind.
        for (var i = 0; i < callerInputs.Count; i++)
        {
            if (result[$"callee_{i}"] is Call { Target: Fusion })
            {
                var callee = (Call)result[$"callee_{i}"];
                var callee_fusion = callee.Target as Fusion;
                if (callee_fusion!.ModuleKind == caller_fusion.ModuleKind)
                {
                    callees.Add(callee);
                    callee_fusions.Add(callee_fusion);
                    fusion_index.Add(i);
                }
            }
            else if (result[$"callee_{i}"] is Call { Target: GetItem })
            {
                var expr = ((Call)result[$"callee_{i}"]).Arguments[0];
                if (expr is Call { Target: Fusion } callee && ((Fusion)callee.Target)!.ModuleKind == caller_fusion.ModuleKind)
                {
                    callees.Add(callee);
                    callee_fusions.Add((Fusion)callee.Target);
                    fusion_index.Add(i);
                }
            }
        }

        // Drop a callee that is itself an argument of another collected callee.
        for (var i = callees.Count - 1; i >= 0; i--)
        {
            var callee = callees[i];
            if (callees.Except(new[] { callee }).Any(c => c.Arguments.ToArray().Any(a => a == callee)))
            {
                callees.RemoveAt(i);
                callee_fusions.RemoveAt(i);
                fusion_index.RemoveAt(i);
            }
        }

        if (callees.Count == 0)
        {
            return null;
        }

        // Build the key before the lists are de-duplicated below.
        var cacheKey = new object[callee_fusions.Count + 1];
        cacheKey[0] = caller_fusion;
        for (var i = 0; i < callee_fusions.Count; i++)
        {
            cacheKey[i + 1] = callee_fusions[i];
        }

        if (_mergedCache.TryGetValue(cacheKey, out var cached))
        {
            return cached;
        }

        // Map each caller parameter fed by a callee to that callee's body (or GetItem of the body).
        var multiVarMap = new Dictionary<Var, Expr>(ReferenceEqualityComparer.Instance);
        for (var index = 0; index < fusion_index.Count; index++)
        {
            var callee = (Call)caller.Arguments[fusion_index[index]];
            if (callee is Call { Target: Fusion })
            {
                multiVarMap.Add(caller_fusion.Parameters[fusion_index[index]], callee_fusions[index].Body);
            }
            else
            {
                var newCallee = IR.F.Tensors.GetItem(callee_fusions[index].Body, callee.Arguments[1]);
                multiVarMap.Add(caller_fusion.Parameters[fusion_index[index]], newCallee);
            }
        }

        var new_fusion_body = new FusionMerger(multiVarMap).Clone(caller_fusion.Body, default);

        // remove duplicate callees
        var seen = new HashSet<Expr>();
        var remindIndex = Enumerable.Range(0, callerInputs.Count).ToList();
        for (var i = callees.Count - 1; i >= 0; i--)
        {
            if (!seen.Add(callees[i]))
            {
                callees.RemoveAt(i);
                callee_fusions.RemoveAt(i);

                // Removals happen in descending position order, so the position still equals the value.
                remindIndex.RemoveAt(fusion_index[i]);
                fusion_index.RemoveAt(i);
            }
        }

        var parameters = remindIndex.Select(i => fusion_index.Contains(i) ? callee_fusions[fusion_index.IndexOf(i)].Parameters.ToArray() : new[] { caller_fusion.Parameters[i] }).SelectMany(e => e).ToArray();
        var calleeInputs = remindIndex.Select(i => fusion_index.Contains(i) ? callees[fusion_index.IndexOf(i)].Arguments.ToArray() : new[] { callerInputs[i] }).SelectMany(a => a).ToList();

        // Keep the first occurrence of each parameter and drop the inputs at the other positions.
        var indexedParameters = parameters.Select((value, index) => new { value, index }).ToList();
        parameters = parameters.Distinct().ToArray();
        var distinctedList = parameters.Select(x => new { value = x, index = indexedParameters.First(i => i.value == x).index }).ToList();
        var removedIndexes = indexedParameters.Where(x => !distinctedList.Any(d => d.index == x.index)).Select(x => x.index).ToList();
        removedIndexes.Sort((a, b) => b.CompareTo(a));
        foreach (var index in removedIndexes)
        {
            calleeInputs.RemoveAt(index);
        }

        using (new Diagnostics.DumpScope(new Diagnostics.NullDumpper()))
        {
            new_fusion_body = CompilerServices.ERewrite(new_fusion_body, Array.Empty<IRewriteRule>(), new(), new());
        }

        var merged_fusion = new Fusion($"mfusion_{_count++}_kernel", caller_fusion.ModuleKind, new_fusion_body, parameters);

        var new_call = new Call(merged_fusion, calleeInputs.ToArray());
        _mergedCache.Add(cacheKey, new_call);
        return new_call;
    }

    /// <summary>
    /// Compares arrays element by element using reference equality.
    /// </summary>
    private sealed class ReferenceSequenceComparer : IEqualityComparer<object[]>
    {
        public bool Equals(object[]? x, object[]? y)
        {
            if (ReferenceEquals(x, y))
            {
                return true;
            }

            if (x is null || y is null || x.Length != y.Length)
            {
                return false;
            }

            for (var i = 0; i < x.Length; i++)
            {
                if (!ReferenceEquals(x[i], y[i]))
                {
                    return false;
                }
            }

            return true;
        }

        public int GetHashCode(object[] obj)
        {
            var hash = default(HashCode);
            foreach (var item in obj)
            {
                hash.Add(ReferenceEqualityComparer.Instance.GetHashCode(item));
            }

            return hash.ToHashCode();
        }
    }
}
|
||||
|
||||
/// <summary>
/// Merges a tuple whose fields are all fusion calls of one module kind into a single fusion
/// whose body is the tuple of the callee bodies.
/// </summary>
internal sealed class TupleFusionMergeRule : IRewriteRule
{
    // Keyed by the tuple followed by the callee fusions, compared by reference.
    // A combined int hash is not a safe key: two different fusion sets can share a hash.
    private readonly Dictionary<object[], Call> _mergedCache = new(new ReferenceSequenceComparer());

    public IPattern Pattern { get; } =
        IsTuple(
            "tuple",
            IsVArgsRepeat("tupleInputs", exprs =>
            {
                var patterns = new Pattern[exprs.Length];
                for (var i = 0; i < patterns.Length; i++)
                {
                    patterns[i] = IsCallWildcard($"call_{i}", IsWildcard());
                }

                return patterns;
            }));

    /// <summary>
    /// Builds, or returns from the cache, the call to the merged fusion.
    /// </summary>
    /// <param name="result">Match result holding "tuple", "tupleInputs" and one "call_i" per field.</param>
    /// <param name="options">Pass context; not read by this method.</param>
    /// <returns>
    /// The merged call, or <c>null</c> when the tuple is empty, a field is not a fusion call,
    /// or the fusions have different module kinds.
    /// </returns>
    public Expr? GetReplace(IMatchResult result, RunPassContext options)
    {
        var tuple = (IR.Tuple)result["tuple"];
        var tupleInputs = (IReadOnlyList<Expr>)result["tupleInputs"];
        var callees = new List<Call>();
        var callee_fusions = new List<Fusion>();
        for (var i = 0; i < tupleInputs.Count; i++)
        {
            if (result[$"call_{i}"] is Call { Target: Fusion } callee)
            {
                callees.Add(callee);
                callee_fusions.Add((Fusion)callee.Target);
            }
            else
            {
                return null;
            }
        }

        // An empty tuple has nothing to merge; callee_fusions[0] below would throw.
        if (callee_fusions.Count == 0)
        {
            return null;
        }

        if (callee_fusions.Select(f => f.ModuleKind).Distinct().Count() > 1)
        {
            return null;
        }

        var cacheKey = new object[callee_fusions.Count + 1];
        cacheKey[0] = tuple;
        for (var i = 0; i < callee_fusions.Count; i++)
        {
            cacheKey[i + 1] = callee_fusions[i];
        }

        if (_mergedCache.TryGetValue(cacheKey, out var cached))
        {
            return cached;
        }

        var new_fusion_body = new IR.Tuple(callee_fusions.Select(f => f.Body).ToArray());
        var name = "tuple_" + string.Join("_", callee_fusions.Select(f => f.Name).ToArray());

        // Parameters and call arguments are concatenated in field order.
        var parameters = callee_fusions.Select(f => f.Parameters.ToArray()).SelectMany(e => e).ToArray();
        var merged_fusion = new Fusion(name, callee_fusions[0].ModuleKind, new_fusion_body, parameters);

        var calleeInputs = callees.Select(c => c.Arguments.ToArray()).SelectMany(e => e).ToArray();
        var new_call = new Call(merged_fusion, calleeInputs);
        _mergedCache.Add(cacheKey, new_call);
        return new_call;
    }

    /// <summary>
    /// Compares arrays element by element using reference equality.
    /// </summary>
    private sealed class ReferenceSequenceComparer : IEqualityComparer<object[]>
    {
        public bool Equals(object[]? x, object[]? y)
        {
            if (ReferenceEquals(x, y))
            {
                return true;
            }

            if (x is null || y is null || x.Length != y.Length)
            {
                return false;
            }

            for (var i = 0; i < x.Length; i++)
            {
                if (!ReferenceEquals(x[i], y[i]))
                {
                    return false;
                }
            }

            return true;
        }

        public int GetHashCode(object[] obj)
        {
            var hash = default(HashCode);
            foreach (var item in obj)
            {
                hash.Add(ReferenceEqualityComparer.Instance.GetHashCode(item));
            }

            return hash.ToHashCode();
        }
    }
}
|
||||
|
||||
/// <summary>
/// Merges a concat over a tuple whose fields are all fusion calls of one module kind into a single
/// fusion whose body is the concat of the callee bodies.
/// </summary>
internal sealed class ConcatFusionMergeRule : IRewriteRule
{
    // Keyed by the tuple followed by the callee fusions, compared by reference.
    // A combined int hash is not a safe key: two different fusion sets can share a hash.
    private readonly Dictionary<object[], Call> _mergedCache = new(new ReferenceSequenceComparer());

    public IPattern Pattern { get; } =
        IsConcat(
            "concat",
            _ => true,
            IsTuple(
                "tuple",
                IsVArgsRepeat("tupleInputs", exprs =>
                {
                    var patterns = new Pattern[exprs.Length];
                    for (var i = 0; i < patterns.Length; i++)
                    {
                        patterns[i] = IsCallWildcard($"callee_{i}", IsWildcard());
                    }

                    return patterns;
                })));

    /// <summary>
    /// Builds, or returns from the cache, the call to the merged fusion.
    /// </summary>
    /// <param name="result">Match result holding "concat", "tuple", "tupleInputs" and one "callee_i" per field.</param>
    /// <param name="options">Pass context; not read by this method.</param>
    /// <returns>
    /// The merged call, or <c>null</c> when the tuple is empty, a field is not a fusion call,
    /// or the fusions have different module kinds.
    /// </returns>
    public Expr? GetReplace(IMatchResult result, RunPassContext options)
    {
        var concat = (IR.Tensors.Concat)result["concat"];
        var tuple = (IR.Tuple)result["tuple"];
        var tupleInputs = (IReadOnlyList<Expr>)result["tupleInputs"];
        var callees = new List<Call>();
        var callee_fusions = new List<Fusion>();
        for (var i = 0; i < tupleInputs.Count; i++)
        {
            if (result[$"callee_{i}"] is Call { Target: Fusion } callee)
            {
                callees.Add(callee);
                callee_fusions.Add((Fusion)callee.Target);
            }
            else
            {
                return null;
            }
        }

        // An empty tuple has nothing to merge; callee_fusions[0] below would throw.
        if (callee_fusions.Count == 0)
        {
            return null;
        }

        if (callee_fusions.Select(f => f.ModuleKind).Distinct().Count() > 1)
        {
            return null;
        }

        var cacheKey = new object[callee_fusions.Count + 1];
        cacheKey[0] = tuple;
        for (var i = 0; i < callee_fusions.Count; i++)
        {
            cacheKey[i + 1] = callee_fusions[i];
        }

        if (_mergedCache.TryGetValue(cacheKey, out var cached))
        {
            return cached;
        }

        var new_fusion_body = new Call(new Concat(concat.Axis), new IR.Tuple(callee_fusions.Select(f => f.Body).ToArray()));
        var name = "concat_" + string.Join("_", callee_fusions.Select(f => f.Name).ToArray());

        // Parameters and call arguments are concatenated in field order.
        var parameters = callee_fusions.Select(f => f.Parameters.ToArray()).SelectMany(e => e).ToArray();
        var merged_fusion = new Fusion(name, callee_fusions[0].ModuleKind, new_fusion_body, parameters);

        var calleeInputs = callees.Select(c => c.Arguments.ToArray()).SelectMany(e => e).ToArray();
        var new_call = new Call(merged_fusion, calleeInputs);
        _mergedCache.Add(cacheKey, new_call);
        return new_call;
    }

    /// <summary>
    /// Compares arrays element by element using reference equality.
    /// </summary>
    private sealed class ReferenceSequenceComparer : IEqualityComparer<object[]>
    {
        public bool Equals(object[]? x, object[]? y)
        {
            if (ReferenceEquals(x, y))
            {
                return true;
            }

            if (x is null || y is null || x.Length != y.Length)
            {
                return false;
            }

            for (var i = 0; i < x.Length; i++)
            {
                if (!ReferenceEquals(x[i], y[i]))
                {
                    return false;
                }
            }

            return true;
        }

        public int GetHashCode(object[] obj)
        {
            var hash = default(HashCode);
            foreach (var item in obj)
            {
                hash.Add(ReferenceEqualityComparer.Instance.GetHashCode(item));
            }

            return hash.ToHashCode();
        }
    }
}
|
||||
|
||||
/// <summary>
/// Merges a fusion call with those of its input fusion calls that share its module kind and have
/// no user other than the caller.
/// </summary>
internal sealed class DeterminedFusionMergeRule : IRewriteRule
{
    private int _count;

    public IPattern Pattern { get; } =
        IsCall(
            "caller",
            IsFusion("caller_fusion", _ => true, IsWildcard(), IsVArgsRepeat("inputs", exprs =>
            {
                var patterns = new Pattern[exprs.Length];
                for (var i = 0; i < patterns.Length; i++)
                {
                    patterns[i] = IsVar($"input_{i}");
                }

                return patterns;
            })),
            IsVArgsRepeat("callerInputs", exprs =>
            {
                var patterns = new Pattern[exprs.Length];
                for (var i = 0; i < patterns.Length; i++)
                {
                    patterns[i] = IsCallWildcard($"callee_{i}", IsWildcard());
                }

                return patterns;
            }));

    /// <summary>
    /// Builds the call to the merged fusion.
    /// </summary>
    /// <param name="result">Match result holding "caller", "caller_fusion", "callerInputs" and one "callee_i" per input.</param>
    /// <param name="options">Pass context; supplies the expression user analysis.</param>
    /// <returns>The merged call, or <c>null</c> when no input qualifies for merging.</returns>
    public Expr? GetReplace(IMatchResult result, RunPassContext options)
    {
        var userAnalysis = options.GetAnalysis<IExprUserAnalysisResult>();
        var caller = (Call)result["caller"];
        var caller_fusion = (Fusion)result["caller_fusion"];
        var callerInputs = (IReadOnlyList<Expr>)result["callerInputs"];
        var callees = new List<Call>();
        var callee_fusions = new List<Fusion>();
        var fusion_index = new List<int>();
        for (var i = 0; i < callerInputs.Count; i++)
        {
            if (result[$"callee_{i}"] is Call { Target: Fusion } callee)
            {
                var callee_fusion = callee.Target as Fusion;

                // Merge only callees of the same module kind whose sole user is the caller.
                if (callee_fusion!.ModuleKind == caller_fusion.ModuleKind && !userAnalysis[callee].Except(new[] { caller }).Any())
                {
                    callees.Add(callee);
                    callee_fusions.Add(callee_fusion);
                    fusion_index.Add(i);
                }
            }
        }

        if (callees.Count == 0)
        {
            return null;
        }

        // Map each caller parameter fed by a merged callee to that callee's body.
        var multiVarMap = new Dictionary<Var, Expr>(ReferenceEqualityComparer.Instance);
        for (var index = 0; index < fusion_index.Count; index++)
        {
            multiVarMap.Add(caller_fusion.Parameters[fusion_index[index]], callee_fusions[index].Body);
        }

        var new_fusion_body = new FusionMerger(multiVarMap).Clone(caller_fusion.Body, default);

        // remove duplicate callees
        var seen = new HashSet<Expr>();
        var remainingIndex = Enumerable.Range(0, callerInputs.Count).ToList();
        for (var i = callees.Count - 1; i >= 0; i--)
        {
            if (!seen.Add(callees[i]))
            {
                callees.RemoveAt(i);
                callee_fusions.RemoveAt(i);

                // Removals happen in descending position order, so the position still equals the value.
                remainingIndex.RemoveAt(fusion_index[i]);
                fusion_index.RemoveAt(i);
            }
        }

        // A merged input contributes its callee's parameters/arguments; any other input keeps its own.
        var parameters = remainingIndex.Select(i => fusion_index.Contains(i) ? callee_fusions[fusion_index.IndexOf(i)].Parameters.ToArray() : new[] { caller_fusion.Parameters[i] }).SelectMany(e => e).ToArray();
        var merged_fusion = new Fusion($"determined_fusion_{_count++}_kernel", caller_fusion.ModuleKind, new_fusion_body, parameters);

        var calleeInputs = remainingIndex.Select(i => fusion_index.Contains(i) ? callees[fusion_index.IndexOf(i)].Arguments.ToArray() : new[] { callerInputs[i] }).SelectMany(a => a).ToArray();
        return new Call(merged_fusion, calleeInputs);
    }
}
|
|
@ -1,34 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.Math;
|
||||
using Nncase.PatternMatch;
|
||||
|
||||
using static Nncase.IR.F.CPU;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
using static Nncase.PatternMatch.F.Math;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
|
||||
namespace Nncase.Passes.Rules.CPU;
|
||||
|
||||
/// <summary>
/// Replaces a binary op whose two operands are float-typed with fixed shapes by a <c>CPUKernel</c> call.
/// </summary>
[RuleGenerator]
public partial class LowerBinary : RewriteRule<Pattern>
{
    /// <inheritdoc/>
    public override Pattern Pattern { get; } = IsBinary(
        target_name: "binary",
        _ => true,
        IsWildcard("lhs") with { TypePattern = IsFloat() & HasFixedShape() },
        IsWildcard("rhs") with { TypePattern = IsFloat() & HasFixedShape() });

    // Forwards the matched op and both operands unchanged.
    private Expr? GetReplace(Binary binary, Expr lhs, Expr rhs) => CPUKernel(binary, lhs, rhs);
}
|
|
@ -21,15 +21,26 @@ namespace Nncase.Passes.Rules.CPU;
|
|||
|
||||
public abstract class PackRule : RewriteRule<Pattern>
|
||||
{
|
||||
public int Lane { get; set; } = 32;
|
||||
public PackRule(int rank, int lane)
|
||||
{
|
||||
Rank = rank;
|
||||
Lane = lane;
|
||||
}
|
||||
|
||||
public int Rank { get; set; } = 2;
|
||||
public int Lane { get; }
|
||||
|
||||
public int Rank { get; }
|
||||
|
||||
public override Expr? GetReplace(IMatchResult result, RunPassContext options) => throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public sealed class PackSoftmax : PackRule
|
||||
{
|
||||
public PackSoftmax(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsSoftmax(
|
||||
"target",
|
||||
IsWildcard("input") with { TypePattern = IsFloat() },
|
||||
|
@ -71,6 +82,11 @@ public sealed class PackSoftmax : PackRule
|
|||
|
||||
public sealed class PackResizeImage : PackRule
|
||||
{
|
||||
public PackResizeImage(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsResizeImage("target", op => op.TransformationMode == ImageResizeTransformationMode.Asymmetric && op.IsTFResize == false, IsWildcard("input"), IsNone(), IsTensorConst("newSize"), IsTensorConst("cubicCoeffA"), IsTensorConst("excludeOutside"), IsTensorConst("extrapolationValue"));
|
||||
|
||||
public override List<Expr> GetReplaceCandidates(IMatchResult result, RunPassContext context)
|
||||
|
@ -101,6 +117,11 @@ public sealed class PackResizeImage : PackRule
|
|||
|
||||
public sealed class PackInstanceNorm : PackRule
|
||||
{
|
||||
public PackInstanceNorm(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsInstanceNormalization(
|
||||
"target",
|
||||
_ => true,
|
||||
|
@ -164,6 +185,11 @@ public sealed class PackInstanceNorm : PackRule
|
|||
|
||||
public sealed class PackLayerNorm : PackRule
|
||||
{
|
||||
public PackLayerNorm(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsLayerNorm(
|
||||
"target",
|
||||
_ => true,
|
||||
|
@ -225,6 +251,11 @@ public sealed class PackLayerNorm : PackRule
|
|||
|
||||
public sealed class PackMatMul : PackRule
|
||||
{
|
||||
public PackMatMul(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsMatMul(
|
||||
"target",
|
||||
IsWildcard("lhs") with { TypePattern = IsFloat() },
|
||||
|
@ -247,16 +278,37 @@ public sealed class PackMatMul : PackRule
|
|||
var matmul = IR.F.CPU.PackedMatMul(packedLhs, packedRhs, lhsPackedAxes, lhsPadNums, rhsPackedAxes, rhsPadNums);
|
||||
var lhsAlign = System.Math.Max(lhsShape.Length, rhsShape.Length) - lhsShape.Length;
|
||||
var rhsAlign = System.Math.Max(lhsShape.Length, rhsShape.Length) - rhsShape.Length;
|
||||
var post = matmul;
|
||||
if (lhsPackedAxes.Length == 2 && rhsPackedAxes.Length == 2)
|
||||
|
||||
var mPackIndex = Array.IndexOf(lhsPackedAxes, lhsShape.Length - 2);
|
||||
var nPackIndex = Array.IndexOf(rhsPackedAxes, rhsShape.Length - 1);
|
||||
var unpackAxes = new List<int>();
|
||||
var unpadNums = new List<int>();
|
||||
if (mPackIndex != -1)
|
||||
{
|
||||
post = PackUtility.SliceForPack(IR.F.CPU.Unpack(matmul, new[] { lhsAlign + lhsPackedAxes[0], rhsAlign + rhsPackedAxes[1] }), candidate.CheckedShape.ToValueArray(), new[] { lhsPadNums[0], rhsPadNums[1] });
|
||||
unpackAxes.Add(lhsAlign + lhsPackedAxes[mPackIndex]);
|
||||
unpadNums.Add(lhsPadNums[mPackIndex]);
|
||||
}
|
||||
|
||||
if (nPackIndex != -1)
|
||||
{
|
||||
unpackAxes.Add(rhsAlign + rhsPackedAxes[nPackIndex]);
|
||||
unpadNums.Add(rhsPadNums[nPackIndex]);
|
||||
}
|
||||
|
||||
Expr post = matmul;
|
||||
if (unpackAxes.Any())
|
||||
{
|
||||
post = PackUtility.SliceForPack(IR.F.CPU.Unpack(matmul, unpackAxes.ToArray()), candidate.CheckedShape.ToValueArray(), unpadNums.ToArray());
|
||||
}
|
||||
|
||||
rets.Add(post);
|
||||
}
|
||||
|
||||
// pack A's k and B's k
|
||||
AddCandidate(new[] { lhsShape.Length - 1 }, new[] { rhsShape.Length - 2 }, new[] { Lane }, new[] { Lane });
|
||||
|
||||
// only pack A's m
|
||||
// AddCandidate(new[] { lhsShape.Length - 2 }, Array.Empty<int>(), new[] { Lane }, Array.Empty<int>());
|
||||
if (Rank > 1)
|
||||
{
|
||||
AddCandidate(new[] { lhsShape.Length - 2, lhsShape.Length - 1 }, new[] { rhsShape.Length - 2, rhsShape.Length - 1 }, new[] { Lane, Lane }, new[] { Lane, Lane });
|
||||
|
@ -268,6 +320,11 @@ public sealed class PackMatMul : PackRule
|
|||
|
||||
public sealed class PackUnary : PackRule
|
||||
{
|
||||
public PackUnary(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsUnary(
|
||||
"target",
|
||||
_ => true,
|
||||
|
@ -309,6 +366,11 @@ public sealed class PackUnary : PackRule
|
|||
|
||||
public sealed class PackBinary : PackRule
|
||||
{
|
||||
public PackBinary(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsBinary(
|
||||
"target",
|
||||
_ => true,
|
||||
|
@ -373,6 +435,11 @@ public sealed class PackBinary : PackRule
|
|||
|
||||
public sealed class PackSwish : PackRule
|
||||
{
|
||||
public PackSwish(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsSwish(
|
||||
"target",
|
||||
IsWildcard("input") with { TypePattern = IsFloat() },
|
||||
|
@ -411,6 +478,11 @@ public sealed class PackSwish : PackRule
|
|||
|
||||
public sealed class PackTranspose : PackRule
|
||||
{
|
||||
public PackTranspose(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsTranspose(
|
||||
"trans",
|
||||
IsWildcard("input") with { TypePattern = IsFloat() },
|
||||
|
@ -462,6 +534,11 @@ public sealed class PackTranspose : PackRule
|
|||
|
||||
public sealed class PackUnsqueeze : PackRule
|
||||
{
|
||||
public PackUnsqueeze(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsUnsqueeze(
|
||||
"unsq",
|
||||
IsWildcard("input") with { TypePattern = IsFloat() },
|
||||
|
@ -519,6 +596,11 @@ public sealed class PackUnsqueeze : PackRule
|
|||
|
||||
public sealed class PackConv2D : PackRule
|
||||
{
|
||||
public PackConv2D(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsConv2D(
|
||||
"conv",
|
||||
conv => conv.PadMode == PadMode.Constant,
|
||||
|
@ -591,6 +673,11 @@ public sealed class PackConv2D : PackRule
|
|||
|
||||
public sealed class PackReshape : PackRule
|
||||
{
|
||||
public PackReshape(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsReshape(
|
||||
"target",
|
||||
IsWildcard("input") with { TypePattern = IsFloat() },
|
||||
|
@ -687,6 +774,11 @@ public sealed class PackReshape : PackRule
|
|||
|
||||
public sealed class PackSlice : PackRule
|
||||
{
|
||||
public PackSlice(int rank, int lane)
|
||||
: base(rank, lane)
|
||||
{
|
||||
}
|
||||
|
||||
public override Pattern Pattern { get; } = IsSlice(
|
||||
"target",
|
||||
IsWildcard("input") with { TypePattern = IsFloat() },
|
||||
|
|
|
@ -1,90 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System.Runtime.CompilerServices;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes.Mutators;
|
||||
using Nncase.Targets;
|
||||
|
||||
[assembly: InternalsVisibleTo("Nncase.Tests.CPU")]
|
||||
|
||||
namespace Nncase.Passes.Tile;
|
||||
|
||||
/// <summary>
/// <see cref="SameInputFusionMergeRule"/> specialised with the CPU target's module kind.
/// </summary>
internal sealed class CPUSameInputFusionMergeRule : SameInputFusionMergeRule
{
    /// <summary>
    /// Gets the module kind, which is always <see cref="CPUTarget.Kind"/>.
    /// </summary>
    public override string ModuleKind
    {
        get { return CPUTarget.Kind; }
    }
}
|
||||
|
||||
/// <summary>
/// <see cref="MultiInputFusionMergeRule"/> specialised with the CPU target's module kind.
/// </summary>
internal sealed class CPUMultiInputFusionMergeRule : MultiInputFusionMergeRule
{
    /// <summary>
    /// Gets the module kind, which is always <see cref="CPUTarget.Kind"/>.
    /// </summary>
    public override string ModuleKind
    {
        get { return CPUTarget.Kind; }
    }
}
|
||||
|
||||
/// <summary>
/// <see cref="ShortCutFusionMergeRuleLeft"/> specialised with the CPU target's module kind.
/// </summary>
internal sealed class CPUShortCutFusionMergeRuleLeft : ShortCutFusionMergeRuleLeft
{
    /// <summary>
    /// Gets the module kind, which is always <see cref="CPUTarget.Kind"/>.
    /// </summary>
    public override string ModuleKind
    {
        get { return CPUTarget.Kind; }
    }
}
|
||||
|
||||
/// <summary>
/// <see cref="ShortCutFusionMergeRuleRight"/> specialised with the CPU target's module kind.
/// </summary>
internal sealed class CPUShortCutFusionMergeRuleRight : ShortCutFusionMergeRuleRight
{
    /// <summary>
    /// Gets the module kind, which is always <see cref="CPUTarget.Kind"/>.
    /// </summary>
    public override string ModuleKind
    {
        get { return CPUTarget.Kind; }
    }
}
|
||||
|
||||
/// <summary>
/// Fusion group mutator that validates each merged fusion with a <see cref="FusionChecker"/>
/// and records the checker of an accepted merge in a caller-supplied cache.
/// After the first accepted merge, no further merge is accepted and leaf calls are returned unchanged.
/// </summary>
internal sealed class CPUFusionGroupMutator : FusionGroupMutator
{
    /// <summary>Checker cache owned by the caller; keyed by fusion.</summary>
    private readonly Dictionary<Fusion, FusionChecker> _fusioncheckerCache;

    /// <summary>Becomes true once a merged fusion has passed the check.</summary>
    private bool _checked;

    /// <summary>
    /// Initializes a new instance of the <see cref="CPUFusionGroupMutator"/> class.
    /// </summary>
    /// <param name="fusioncheckerCache">Cache that receives the checker of an accepted merged fusion.</param>
    /// <param name="rule">Merge rewrite rule forwarded to the base mutator.</param>
    /// <param name="passOptions">Pass context forwarded to the base mutator.</param>
    public CPUFusionGroupMutator(
        Dictionary<Fusion, FusionChecker> fusioncheckerCache,
        IMergeRewriteRule rule,
        RunPassContext passOptions)
        : base(rule, passOptions)
    {
        _checked = false;
        _fusioncheckerCache = fusioncheckerCache;
    }

    /// <inheritdoc/>
    public override bool MergedFusionCheckCallBack(Fusion mergedFusion, HashSet<Fusion> candidateFusions)
    {
        // A merge has already been accepted by this mutator: reject everything else.
        if (_checked)
        {
            return false;
        }

        var tileVisitor = new PrimTileVisitor();
        tileVisitor.Visit(mergedFusion.Body);
        var fusionChecker = new FusionChecker(tileVisitor.TileList);

        // The merge is accepted only when the checker yields at least one result.
        var accepted = fusionChecker.Check(mergedFusion.Body).Count > 0;
        if (!accepted)
        {
            return false;
        }

        _checked = true;
        _fusioncheckerCache.Add(mergedFusion, fusionChecker);

        // The candidates are now part of the merged fusion; drop their cached checkers.
        foreach (var merged in candidateFusions)
        {
            _fusioncheckerCache.Remove(merged);
        }

        return true;
    }

    /// <summary>
    /// Rewrites the merged fusion body with the <c>FoldStoreLoad</c> rule inside a
    /// "MergedFusionClear" dump scope.
    /// </summary>
    /// <param name="mergedFusionBody">Body of the merged fusion.</param>
    /// <returns>The rewritten body.</returns>
    public override Expr MergedFusionRewriteCallBack(Expr mergedFusionBody)
    {
        using (new DumpScope("MergedFusionClear"))
        {
            var rules = new[] { new Passes.Rules.CPU.FoldStoreLoad() };
            return CompilerServices.ERewrite(mergedFusionBody, rules, new());
        }
    }

    /// <summary>
    /// Returns the call untouched once a merge has been accepted; otherwise defers to the base mutator.
    /// </summary>
    /// <param name="expr">Leaf call being visited.</param>
    /// <returns>The original or rewritten call.</returns>
    protected override Expr RewriteLeafCall(Call expr)
    {
        if (_checked)
        {
            return expr;
        }

        return base.RewriteLeafCall(expr);
    }
}
|
|
@ -60,9 +60,13 @@ internal sealed class KernelToTIRVisitor : ExprVisitor<Unit, Unit>
|
|||
{
|
||||
var arguments = expr.Arguments.AsValueEnumerable().Select(GetBuffer).ToArray();
|
||||
var ret = GetBuffer(expr);
|
||||
var op = expr.Target is IR.CPU.CPUKernelOp kop ? kop.Target : expr.Target;
|
||||
var op = expr.Target;
|
||||
switch (op)
|
||||
{
|
||||
case PrimFunctionWrapper { Target: TIR.PrimFunction { ModuleKind: string mkind } deviceFunc } when mkind == Targets.CPUTarget.Kind:
|
||||
_devices.Add(deviceFunc);
|
||||
_mainBody.Add(new Call(deviceFunc, arguments.Concat(new[] { ret }).ToArray()));
|
||||
break;
|
||||
case Fusion deviceFunc:
|
||||
{
|
||||
var r = new DeviceFusionToPrimFuncRewriter(_fusionCheckCache);
|
||||
|
@ -176,7 +180,7 @@ internal sealed class KernelToTIRVisitor : ExprVisitor<Unit, Unit>
|
|||
_mainBody.Add(TIR.F.CPU.Gather(arguments[0], arguments[1], ret, gather.Axis));
|
||||
break;
|
||||
case IR.NN.Pad pad:
|
||||
_mainBody.Add(TIR.F.CPU.Pad(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToArray<int>(), ((TensorConst)expr.Arguments[2]).Value.ToScalar<float>()));
|
||||
_mainBody.Add(TIR.F.CPU.Pad(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToArray<int>(), ((TensorConst)expr.Arguments[2]).Value.ToArray<float>()[0]));
|
||||
break;
|
||||
default:
|
||||
throw new NotSupportedException();
|
||||
|
@ -271,9 +275,6 @@ internal sealed class KernelToTIRVisitor : ExprVisitor<Unit, Unit>
|
|||
|
||||
// Appends a TIR CPU binary kernel call for `binary` to the main body, reading the two
// operand buffers from `arguments` and writing into `ret`.
private void GenerateBinary(Binary binary, Buffer[] arguments, Buffer ret, Call expr)
|
||||
{
|
||||
// The three discarded casts below act as type assertions: an explicit reference cast throws
// InvalidCastException when the checked type of an operand or of the call is not a DistributedType.
_ = (DistributedType)expr.Arguments[0].CheckedType;
|
||||
_ = (DistributedType)expr.Arguments[1].CheckedType;
|
||||
_ = (DistributedType)expr.CheckedType;
|
||||
// Emit the kernel call using the operator kind carried by the IR op.
_mainBody.Add(TIR.F.CPU.Binary(binary.BinaryOp, arguments[0], arguments[1], ret));
|
||||
}
|
||||
|
||||
|
|
|
@ -117,7 +117,7 @@ public partial class CPU
|
|||
return new Call(new Reshape(newShape), input, ret);
|
||||
}
|
||||
|
||||
public static Expr Swish(Buffer buffer, Buffer ret, float v)
|
||||
public static Expr Swish(Expr buffer, Expr ret, float v)
|
||||
{
|
||||
return new Call(new Swish(v), buffer, ret);
|
||||
}
|
||||
|
|
|
@ -32,4 +32,6 @@ public sealed partial class Swish : CPUKernelOp
|
|||
/// Gets begins.
|
||||
/// </summary>
|
||||
public float Beta { get; }
|
||||
|
||||
public override string DisplayProperty() => $"Beta: {Beta}";
|
||||
}
|
||||
|
|
|
@ -9,6 +9,25 @@ using System.Threading.Tasks;
|
|||
|
||||
namespace Nncase.Targets;
|
||||
|
||||
/// <summary>
/// Memory architecture selector, stored as a <see cref="byte"/>.
/// </summary>
public enum MemoryArch : byte
|
||||
{
|
||||
/// <summary>
|
||||
/// Unified Memory Access.
|
||||
/// </summary>
|
||||
UMA = 0,
|
||||
|
||||
/// <summary>
|
||||
/// Non-Unified Memory Access.
|
||||
/// </summary>
// NOTE(review): NUMA conventionally abbreviates "Non-Uniform Memory Access"; confirm that
// "Non-Unified" is the intended meaning here.
|
||||
NUMA = 1,
|
||||
}
|
||||
|
||||
/// <summary>
/// Interconnect architecture selector, stored as a <see cref="byte"/>.
/// </summary>
// NOTE(review): "Noc" presumably stands for network-on-chip - confirm.
public enum NocArch : byte
|
||||
{
|
||||
/// <summary>
/// Mesh layout (value 0).
/// </summary>
Mesh = 0,
|
||||
/// <summary>
/// Crossbar layout (value 1).
/// </summary>
CrossBar = 1,
|
||||
}
|
||||
|
||||
public sealed class CpuTargetOptions : ITargetOptions
|
||||
{
|
||||
public string ModelName { get; set; } = string.Empty;
|
||||
|
@ -17,6 +36,15 @@ public sealed class CpuTargetOptions : ITargetOptions
|
|||
|
||||
public int[] TargetTileSize { get; set; } = Array.Empty<int>();
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a value indicating whether Unified Memory Architecture. see https://en.wikipedia.org/wiki/Unified_Memory_Access.
|
||||
/// </summary>
|
||||
public bool UnifiedMemoryArchitecture { get; set; } = true;
|
||||
|
||||
public MemoryArch MemoryArch { get; set; } = MemoryArch.UMA;
|
||||
|
||||
public NocArch NocArch { get; set; } = NocArch.Mesh;
|
||||
|
||||
public int[] Hierarchy { get; set; } = new[] { 1 };
|
||||
|
||||
public string HierarchyNames { get; set; } = "b";
|
||||
|
|
|
@ -41,7 +41,7 @@ public class CPUTarget : ITarget
|
|||
|
||||
// Builds the CPU target options from the parsed command line.
// Only the packing option is read; `command` is not used by this function.
ITargetOptions ParseTargetCompileOptions(InvocationContext context, Command command)
{
    // `packing` was previously declared twice in this scope, which does not compile.
    // A single, explicitly typed lookup is kept so the result is always a bool.
    var packing = context.ParseResult.GetValueForOption<bool>(packingOption);
    return new CpuTargetOptions() { Packing = packing };
}
|
||||
|
||||
|
@ -61,25 +61,6 @@ public class CPUTarget : ITarget
|
|||
/// <inheritdoc/>
|
||||
public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options)
|
||||
{
|
||||
// Registers two named dataflow passes on the pass manager, in this order: "MakeFusion" and
// "CPUKernelFusion". `options` is not read by this method.
passManager.AddWithName<DataflowPass>("MakeFusion").Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Rules.CombineMHA>();
|
||||
p.Add<Passes.Rules.Neutral.FoldConstCall>();
|
||||
p.Add<Passes.Rules.FuseMHA2>();
|
||||
p.Add<Passes.Rules.FuseVAEDecRes>();
|
||||
});
|
||||
|
||||
// The "CPUDeviceFusion" pass below is compiled out by the `#if false` guard.
#if false
|
||||
passManager.AddWithName<DataflowPass>("CPUDeviceFusion").Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Rules.CPU.Affine.LowerUnary>();
|
||||
});
|
||||
#endif
|
||||
|
||||
// Second active pass: a single rule, CPUSingleKernelFusion.
passManager.AddWithName<DataflowPass>("CPUKernelFusion").Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Rules.CPUSingleKernelFusion>();
|
||||
});
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
@ -113,35 +94,67 @@ public class CPUTarget : ITarget
|
|||
|
||||
if (options.TargetCompileOptions is CpuTargetOptions { Packing: true })
|
||||
{
|
||||
passManager.AddWithName<DataflowPass>("AutoPacking").Configure(p =>
|
||||
passManager.AddWithName<EGraphRulesPass>("AutoPacking").Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Rules.AutoPacking>();
|
||||
// todo config it in the target options.
|
||||
var rank = 1;
|
||||
var lane = System.Runtime.Intrinsics.Vector256.IsHardwareAccelerated ? 8 : 4;
|
||||
p.Add<Passes.Rules.CPU.PackSoftmax>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackSwish>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackLayerNorm>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackResizeImage>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackMatMul>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackConv2D>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackUnary>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackBinary>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackTranspose>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackUnsqueeze>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackReshape>(rank, lane);
|
||||
p.Add<Passes.Rules.CPU.PackSlice>(rank, lane);
|
||||
p.Add<Passes.Rules.Neutral.FoldConstCall>();
|
||||
p.Add<Passes.Rules.CPU.FoldPackUnpack>();
|
||||
p.Add<Passes.Rules.CPU.FoldPackConcatUnpack>();
|
||||
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
|
||||
});
|
||||
}
|
||||
|
||||
passManager.AddWithName<DataflowPass>("AutoDistributed").Configure(p =>
|
||||
// need refactor tiling.
|
||||
// passManager.Add<AutoDistributedPass>();
|
||||
passManager.Add<DataflowPass>().Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Rules.AutoDistributed>();
|
||||
p.Add<Passes.Rules.CPU.CPUOutputBoxingFusion>();
|
||||
p.Add<Passes.Rules.CPU.CPUSingleFusion>();
|
||||
});
|
||||
passManager.Add<DataflowPass>().Configure(p =>
|
||||
{
|
||||
p.AddAnalysis<Passes.Analysis.IExprUserAnalysisResult>();
|
||||
p.Add<Passes.Rules.CPU.DeterminedFusionMergeRule>();
|
||||
});
|
||||
passManager.AddWithName<EGraphRulesPass>("PartitionConstruct").Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Rules.CPU.GeneralFusionMergeRule>();
|
||||
});
|
||||
passManager.AddWithName<EGraphExtractPass>("PartitionExtract").Configure(p =>
|
||||
{
|
||||
p.AddBaseFuncCostEvaluator<Passes.Rules.CPU.FusionCostEvaluator>();
|
||||
});
|
||||
|
||||
// passManager.Add<CPUFunctionPartitionPass>();
|
||||
passManager.Add<CPUFusionToModulePass>();
|
||||
|
||||
#if false
|
||||
// FIX ME: Disable macos as macho loader is buggy.
|
||||
if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
|
||||
passManager.AddWithName<DataflowPass>("LowerToAffine").Configure(p =>
|
||||
{
|
||||
passManager.AddWithName<DataflowPass>("CPUDeviceFusion").Configure(p =>
|
||||
{
|
||||
p.AddAnalysis<Passes.Analysis.IExprUserAnalysisResult>();
|
||||
p.Add<Passes.Rules.CPUDeviceFusion>();
|
||||
});
|
||||
}
|
||||
#endif
|
||||
p.Add<Passes.Rules.CPU.Affine.LowerUnary>();
|
||||
p.Add<Passes.Rules.CPU.Affine.LowerSwish>();
|
||||
});
|
||||
|
||||
// concat/reshape lower
|
||||
// tile and lower to tir.
|
||||
passManager.Add<AutoTilePass>();
|
||||
|
||||
passManager.Add<CPUFusionToTirPass>();
|
||||
|
||||
// todo add auto fusion merge pass here.
|
||||
passManager.Add<PrimFuncPass>().Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Mutators.UnFoldBlock>();
|
||||
|
|
|
@ -21,10 +21,13 @@
|
|||
namespace nncase::ntt {
|
||||
namespace slice_detail {
|
||||
template <IsFixedDims TStart, IsFixedDims TStop, IsFixedDims TStride,
|
||||
size_t... Ints>
|
||||
IsFixedDims TAxes, IsFixedDims TShape, size_t... Ints>
|
||||
inline constexpr auto compute_inner_domain(std::index_sequence<Ints...>) {
|
||||
return fixed_shape<((TStop::at(Ints) - TStart::at(Ints)) /
|
||||
TStride::at(Ints))...>{};
|
||||
return fixed_shape<(
|
||||
((std::min(TShape::at(TAxes::at(Ints)), TStop::at(Ints)) - 1 -
|
||||
TStart::at(Ints)) /
|
||||
TStride::at(Ints)) +
|
||||
1)...>{};
|
||||
}
|
||||
} // namespace slice_detail
|
||||
|
||||
|
@ -41,11 +44,11 @@ inline constexpr auto compute_inner_domain(std::index_sequence<Ints...>) {
|
|||
template <IsFixedDims TStart, IsFixedDims TStop, IsFixedDims TAxes,
|
||||
IsFixedDims TStride, IsFixedTensor TIn, IsFixedTensor TOut>
|
||||
void slice(const TIn &input, TOut &&output) {
|
||||
|
||||
constexpr auto domain = shape_infer::reduced_shape_by_axes(
|
||||
typename std::decay_t<TOut>::shape_type{}, TAxes{});
|
||||
constexpr auto inner_domain =
|
||||
slice_detail::compute_inner_domain<TStart, TStop, TStride>(
|
||||
slice_detail::compute_inner_domain<TStart, TStop, TStride, TAxes,
|
||||
typename TIn::shape_type>(
|
||||
std::make_index_sequence<TAxes::rank()>{});
|
||||
|
||||
auto in_index = ranked_shape<domain.rank()>{};
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
*/
|
||||
#pragma once
|
||||
#include "compiler_defs.h"
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <string_view>
|
||||
#include <type_traits>
|
||||
|
|
|
@ -107,6 +107,28 @@ class NNCASE_API value_type_node : public datatype_node {
|
|||
|
||||
using value_type_t = object_t<value_type_node>;
|
||||
|
||||
/// Datatype node describing a vector type: an element datatype plus a list of lane extents.
class NNCASE_API vector_type_node : public datatype_node {
    // NOTE(review): the node is registered with the object_value_type kind tag; confirm that
    // no dedicated vector kind is intended.
    DEFINE_OBJECT_KIND(datatype_node, object_value_type)

  public:
    /// Stores copies of the element type and the lane extents.
    vector_type_node(datatype_t elemtype, dims_t lanes) noexcept
        : elemtype_(elemtype), lanes_(lanes) {}

    /// Size in bytes: the element size multiplied by every lane extent.
    size_t size_bytes() const noexcept override {
        auto total = elemtype_->size_bytes();
        const size_t lane_count = lanes_.size();
        for (size_t axis = 0; axis != lane_count; ++axis) {
            total *= lanes_[axis];
        }
        return total;
    }

    /// Always reports the vector typecode.
    typecode_t typecode() const noexcept override { return dt_vectortype; }

  private:
    datatype_t elemtype_; // element datatype
    dims_t lanes_;        // lane extents, one entry per lane dimension
};
|
||||
|
||||
using vector_type_t = object_t<vector_type_node>;
|
||||
|
||||
namespace detail {
|
||||
template <class T> struct datatype_of {};
|
||||
|
||||
|
|
|
@ -14,3 +14,4 @@ DEFINE_TYPECODE(float64, f64, 0x0C)
|
|||
DEFINE_TYPECODE(bfloat16, bf16, 0x0D)
|
||||
DEFINE_TYPECODE(pointer, *, 0xF0)
|
||||
DEFINE_TYPECODE(valuetype, val, 0xF1)
|
||||
DEFINE_TYPECODE(vectortype, vec, 0xF2)
|
||||
|
|
|
@ -82,6 +82,15 @@ result<datatype_t> deserialize_datatype_impl(TReader &sr) noexcept {
|
|||
auto uuid = sr.template read_unaligned<nncase::uuid_t>();
|
||||
auto size_bytes = sr.template read_unaligned<uint32_t>();
|
||||
return ok<datatype_t>(value_type_t(std::in_place, uuid, size_bytes));
|
||||
}
|
||||
case dt_vectortype: {
|
||||
checked_try_var(elem_type, deserialize_datatype(sr));
|
||||
auto rank = (int32_t)sr.template read_unaligned<uint8_t>();
|
||||
dims_t lanes(rank);
|
||||
for (int32_t i = 0; i < rank; i++) {
|
||||
lanes[i] = sr.template read_unaligned<uint32_t>();
|
||||
}
|
||||
return ok<datatype_t>(vector_type_t(std::in_place, elem_type, lanes));
|
||||
}
|
||||
// prim types
|
||||
default: {
|
||||
|
|
|
@ -11,6 +11,7 @@ using System.IO;
|
|||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Nncase.Hosting;
|
||||
|
||||
|
|
|
@ -31,6 +31,16 @@ public static class TypeSerializer
|
|||
writer.Write((byte)Runtime.TypeCode.ValueType);
|
||||
writer.Write(t.Uuid.ToByteArray());
|
||||
writer.Write(t.SizeInBytes);
|
||||
break;
|
||||
case VectorType t:
|
||||
writer.Write((byte)Runtime.TypeCode.VectorType);
|
||||
Serialize(writer, t.ElemType);
|
||||
writer.Write(checked((byte)t.Lanes.Count));
|
||||
for (int i = 0; i < t.Lanes.Count; i++)
|
||||
{
|
||||
writer.Write(t.Lanes[i]);
|
||||
}
|
||||
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException($"Unsupported datatype: {dataType}");
|
||||
|
|
|
@ -251,11 +251,6 @@ internal class Compiler : ICompiler
|
|||
|
||||
public void ClearFixShape(IPassManager p)
|
||||
{
|
||||
if (!_compileSession.CompileOptions.ShapeBucketOptions.Enable)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
p.AddWithName<DataflowPass>("ClearUnused").Configure(c =>
|
||||
{
|
||||
c.Add<FoldFixShape>();
|
||||
|
@ -285,7 +280,12 @@ internal class Compiler : ICompiler
|
|||
"TargetDependentAfterQuantPass",
|
||||
progress,
|
||||
token);
|
||||
await RunPassAsync(p => ClearFixShape(p), "ClearFixShape", progress, token);
|
||||
|
||||
if (_compileSession.CompileOptions.ShapeBucketOptions.Enable)
|
||||
{
|
||||
await RunPassAsync(ClearFixShape, "ClearFixShape", progress, token);
|
||||
}
|
||||
|
||||
await RunPassAsync(
|
||||
p => target.RegisterTargetDependentBeforeCodeGen(p, _compileSession.CompileOptions),
|
||||
"TargetDependentBeforeCodeGen",
|
||||
|
|
|
@ -118,8 +118,9 @@ public interface ICompilerServicesProvider
|
|||
/// Evaluate cost of the expression tree.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="compileOptions">options.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
Cost EvaluateCost(Expr expr);
|
||||
Cost EvaluateCost(Expr expr, CompileOptions compileOptions);
|
||||
|
||||
/// <summary>
|
||||
/// Evaluate metric of the expression tree.
|
||||
|
@ -208,8 +209,9 @@ public interface ICompilerServicesProvider
|
|||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <param name="compileOptions">compileOptions.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
|
||||
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions);
|
||||
|
||||
/// <summary>
|
||||
/// Using EGraph rewrite expression.
|
||||
|
@ -304,10 +306,11 @@ public static class CompilerServices
|
|||
/// Evaluate cost of the expression tree.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="compileOptions">compileOptions.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
public static Cost EvaluateCost(Expr expr)
|
||||
public static Cost EvaluateCost(Expr expr, CompileOptions compileOptions)
|
||||
{
|
||||
return Provider.EvaluateCost(expr);
|
||||
return Provider.EvaluateCost(expr, compileOptions);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -414,10 +417,11 @@ public static class CompilerServices
|
|||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <param name="compileOptions">compileOptions.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions)
|
||||
{
|
||||
return Provider.ERewrite(expr, rules, options);
|
||||
return Provider.ERewrite(expr, rules, options, compileOptions);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -652,9 +656,9 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
|
|||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Cost EvaluateCost(Expr expr)
|
||||
public Cost EvaluateCost(Expr expr, CompileOptions compileOptions)
|
||||
{
|
||||
return _costEvaluateProvider.EvaluateCost(expr);
|
||||
return _costEvaluateProvider.EvaluateCost(expr, compileOptions);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
@ -696,9 +700,9 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
|
|||
return _targetProvider.GetTarget(name);
|
||||
}
|
||||
|
||||
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions)
|
||||
{
|
||||
return _eGraphrewriteProvider.ERewrite(expr, rules, options);
|
||||
return _eGraphrewriteProvider.ERewrite(expr, rules, options, compileOptions);
|
||||
}
|
||||
|
||||
public IEGraph ERewrite(IEGraph graph, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
|
|
|
@ -16,6 +16,11 @@ namespace Nncase.Evaluator;
|
|||
/// </summary>
|
||||
public interface ICostEvaluateContext
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the CompileOptions.
|
||||
/// </summary>
|
||||
public CompileOptions CompileOptions { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Get return type.
|
||||
/// </summary>
|
||||
|
|
|
@ -20,8 +20,9 @@ public interface ICostEvaluateProvider
|
|||
/// Evaluate cost of the expression tree.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="compileOptions">options.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
Cost EvaluateCost(Expr expr);
|
||||
Cost EvaluateCost(Expr expr, CompileOptions compileOptions);
|
||||
|
||||
/// <summary>
|
||||
/// Evaluate cost of operator.
|
||||
|
|
|
@ -32,14 +32,19 @@ public sealed class Function : BaseFunction
|
|||
{
|
||||
}
|
||||
|
||||
public Function(string name, Expr body, ReadOnlySpan<Var> parameters, Dictionary<Var, Expr[]>? varMap)
|
||||
: base(name, StackVMModuleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
|
||||
public Function(string name, string moduleKind, Expr body, ReadOnlySpan<Var> parameters, Dictionary<Var, Expr[]>? varMap)
|
||||
: base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
|
||||
{
|
||||
VarMap = varMap ?? new();
|
||||
var dynamicDims = VarMap.Values.SelectMany(x => x).ToArray();
|
||||
_pinner = new ExprPinner(dynamicDims);
|
||||
}
|
||||
|
||||
public Function(string name, Expr body, ReadOnlySpan<Var> parameters, Dictionary<Var, Expr[]>? varMap)
|
||||
: this(name, StackVMModuleKind, body, parameters, varMap)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Function"/> class.
|
||||
/// build function.
|
||||
|
@ -82,6 +87,6 @@ public sealed class Function : BaseFunction
|
|||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitFunction(this, context);
|
||||
|
||||
public Function With(string? name = null, Expr? body = null, Var[]? parameters = null)
|
||||
=> new Function(name ?? Name, body ?? Body, parameters ?? Parameters, VarMap);
|
||||
public Function With(string? name = null, string? moduleKind = null, Expr? body = null, Var[]? parameters = null)
|
||||
=> new Function(name ?? Name, moduleKind ?? ModuleKind, body ?? Body, parameters ?? Parameters, VarMap);
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ public sealed partial class Split : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Split), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Split), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets axis.
|
||||
|
|
|
@ -104,6 +104,7 @@ public static partial class TypePatternUtility
|
|||
x => x switch
|
||||
{
|
||||
TensorType ttype => shapeCond(ttype.Shape),
|
||||
DistributedType distributedType => shapeCond(distributedType.TensorType.Shape),
|
||||
_ => false,
|
||||
},
|
||||
reason);
|
||||
|
|
|
@ -36,8 +36,9 @@ public interface IEGraphRewriteProvider
|
|||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <param name="compileOptions">compileOptions.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
|
||||
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite egraph.
|
||||
|
|
|
@ -13,12 +13,12 @@ namespace Nncase.PatternMatch;
|
|||
/// <summary>
|
||||
/// Pattern for <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
public sealed record FusionPattern(Pattern Body, string ModuleKind, VArgsPattern Parameters, string? Name) : Pattern<Fusion>(Name)
|
||||
public sealed record FusionPattern(Pattern Body, Func<string, bool> Condition, VArgsPattern Parameters, string? Name) : Pattern<Fusion>(Name)
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override bool MatchLeafCore(Fusion expr)
|
||||
{
|
||||
return ModuleKind == expr.ModuleKind;
|
||||
return Condition(expr.ModuleKind);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -28,9 +28,11 @@ public static partial class Utility
|
|||
/// Create the Fusion pattern.
|
||||
/// </summary>
|
||||
/// <param name="name">name.</param>
|
||||
/// <param name="module_kind">module kind.</param>
|
||||
/// <param name="moduleKind">module kind.</param>
|
||||
/// <param name="body">body.</param>
|
||||
/// <param name="parameters">params.</param>
|
||||
/// <returns>FusionPattern .</returns>
|
||||
public static FusionPattern IsFusion(string? name, string module_kind, Pattern body, VArgsPattern parameters) => new FusionPattern(body, module_kind, parameters, name);
|
||||
public static FusionPattern IsFusion(string? name, string moduleKind, Pattern body, VArgsPattern parameters) => new FusionPattern(body, m => m == moduleKind, parameters, name);
|
||||
|
||||
public static FusionPattern IsFusion(string? name, Func<string, bool> condition, Pattern body, VArgsPattern parameters) => new FusionPattern(body, condition, parameters, name);
|
||||
}
|
||||
|
|
|
@ -93,4 +93,9 @@ public enum TypeCode : byte
|
|||
/// <see cref="Nncase.ValueType"/>.
|
||||
/// </summary>
|
||||
ValueType = 0xF1,
|
||||
|
||||
/// <summary>
|
||||
/// <see cref="Nncase.VectorType"/>.
|
||||
/// </summary>
|
||||
VectorType = 0xF2,
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ public static class T
|
|||
/// <returns> for builder. </returns>
|
||||
public static ISequentialBuilder<For> ForLoop(out Var loopVar, Range domain, LoopMode mode, [CallerArgumentExpression("loopVar")] string var_name = "v")
|
||||
{
|
||||
var newLoopVar = loopVar = new Var(var_name.StartsWith("var ") ? var_name[4..] : var_name, TensorType.Scalar(DataTypes.Int32));
|
||||
var newLoopVar = loopVar = new Var(var_name.StartsWith("var ") ? var_name[4..] : var_name, domain.Start.CheckedType);
|
||||
return new SequentialBuilder<For>(body => new For(newLoopVar, domain, mode, body));
|
||||
}
|
||||
|
||||
|
|
|
@ -560,7 +560,14 @@ internal sealed class ILPrintVisitor : ExprFunctor<string, string>
|
|||
|
||||
/// <summary>
/// Prints a <see cref="BufferOf"/> expression as <c>bufferof(input)</c>.
/// The text is memoized in <c>_names</c>, so visiting the same expression again returns
/// the stored text without visiting its input a second time.
/// </summary>
/// <param name="expr">The expression to print.</param>
/// <returns>The printed text for <paramref name="expr"/>.</returns>
protected override string VisitBufferOf(BufferOf expr)
{
    // A stale unconditional return used to sit here, which made the memoization below unreachable.
    if (_names.TryGetValue(expr, out var name))
    {
        return name;
    }

    name = $"bufferof({Visit(expr.Input)})";
    _names.Add(expr, name);
    return name;
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
@ -578,7 +585,8 @@ internal sealed class ILPrintVisitor : ExprFunctor<string, string>
|
|||
|
||||
// 1. For Loop signature
|
||||
_scope.Append($"{name} = Grid({string.Join(", ", reads)})");
|
||||
AppendCheckedType(expr.CheckedType, " {");
|
||||
AppendCheckedType(expr.CheckedType);
|
||||
_scope.IndWriteLine(" {");
|
||||
|
||||
using (_scope.IndentUp())
|
||||
{
|
||||
|
|
|
@ -15,6 +15,7 @@ namespace Nncase.CostModel;
|
|||
internal sealed class EGraphCostEvaluator
|
||||
{
|
||||
private readonly EClass _root;
|
||||
private readonly CompileOptions _compileOptions;
|
||||
private readonly Dictionary<ENode, Cost> _costs = new(ReferenceEqualityComparer.Instance);
|
||||
private readonly Dictionary<EClass, Cost> _eclassCosts = new();
|
||||
private readonly HashSet<EClass> _allEclasses = new();
|
||||
|
@ -22,9 +23,10 @@ internal sealed class EGraphCostEvaluator
|
|||
private readonly bool _accumulate;
|
||||
private bool _changed;
|
||||
|
||||
public EGraphCostEvaluator(EClass root, IBaseFuncCostEvaluator? basefunc_cost_evaluator, bool accumulate = true)
|
||||
public EGraphCostEvaluator(EClass root, CompileOptions compileOptions, IBaseFuncCostEvaluator? basefunc_cost_evaluator, bool accumulate = true)
|
||||
{
|
||||
_root = root;
|
||||
_compileOptions = compileOptions;
|
||||
_accumulate = accumulate;
|
||||
_baseFuncCostEvaluator = basefunc_cost_evaluator;
|
||||
PopulateAllEclasses(_root);
|
||||
|
@ -163,7 +165,7 @@ internal sealed class EGraphCostEvaluator
|
|||
Cost? newCost;
|
||||
if (targetEnode.Expr is Op op)
|
||||
{
|
||||
var context = new EGraphOpCostEvaluateContext(returnType, enode.Children.Skip(1).Select(x => x.CheckedType).ToArray(), enode.Children.Skip(1).ToArray());
|
||||
var context = new EGraphOpCostEvaluateContext(returnType, enode.Children.Skip(1).Select(x => x.CheckedType).ToArray(), enode.Children.Skip(1).ToArray(), _compileOptions);
|
||||
newCost = CompilerServices.EvaluateOpCost(op, context);
|
||||
}
|
||||
else
|
||||
|
@ -285,13 +287,16 @@ internal sealed class EGraphOpCostEvaluateContext : ICostEvaluateContext
|
|||
private readonly IRType?[] _argumentTypes;
|
||||
private readonly EClass[] _arguments;
|
||||
|
||||
public EGraphOpCostEvaluateContext(IRType? returnType, IRType?[] argumentTypes, EClass[] arguments)
|
||||
public EGraphOpCostEvaluateContext(IRType? returnType, IRType?[] argumentTypes, EClass[] arguments, CompileOptions compileOptions)
|
||||
{
|
||||
_returnType = returnType;
|
||||
_argumentTypes = argumentTypes;
|
||||
_arguments = arguments;
|
||||
CompileOptions = compileOptions;
|
||||
}
|
||||
|
||||
public CompileOptions CompileOptions { get; }
|
||||
|
||||
public T GetArgument<T>(Op op, ParameterInfo parameter)
|
||||
where T : BaseFunction
|
||||
{
|
||||
|
|
|
@ -26,9 +26,10 @@ public static class EGraphExtensions
|
|||
/// </summary>
|
||||
/// <param name="eGraph">egraph.</param>
|
||||
/// <param name="root">Root eclass.</param>
|
||||
/// <param name="compileOptions">compileOptions.</param>
|
||||
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
|
||||
/// <param name="constrains">the cp model constrains.</param>
|
||||
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[]? constrains = null)
|
||||
public static Expr Extract(this IEGraph eGraph, EClass root, CompileOptions compileOptions, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[]? constrains = null)
|
||||
{
|
||||
// 1. set enode expr with more accuracy type.
|
||||
foreach (var eclass in eGraph.Classes)
|
||||
|
@ -43,7 +44,7 @@ public static class EGraphExtensions
|
|||
}
|
||||
|
||||
// 2. start the cost evaluator
|
||||
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();
|
||||
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), compileOptions, basefunc_cost_evaluator, false).Evaluate();
|
||||
|
||||
return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains);
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ internal class EGraphRewriteProvider : IEGraphRewriteProvider
|
|||
_logger = logger;
|
||||
}
|
||||
|
||||
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions)
|
||||
{
|
||||
if (expr.CheckedType is null)
|
||||
{
|
||||
|
@ -36,7 +36,7 @@ internal class EGraphRewriteProvider : IEGraphRewriteProvider
|
|||
|
||||
var graph = new EGraph(expr);
|
||||
ERewrite(graph, rules, options);
|
||||
var post = graph.Extract(graph.Root!, null, Array.Empty<EGraphExtractConstrains>());
|
||||
var post = graph.Extract(graph.Root!, compileOptions, null, Array.Empty<EGraphExtractConstrains>());
|
||||
return post;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.Affine;
|
||||
|
||||
namespace Nncase.Evaluator;
|
||||
|
||||
internal sealed partial class TypeInferenceVisitor
|
||||
{
|
||||
protected override IRType VisitLeafGrid(Grid expr) => expr.Buffers[^1].CheckedType;
|
||||
}
|
|
@ -17,9 +17,10 @@ internal sealed class CostEvaluateContext : ICostEvaluateContext
|
|||
private readonly Dictionary<Expr, Cost> _exprMemo;
|
||||
private Call? _currentCall;
|
||||
|
||||
public CostEvaluateContext(Dictionary<Expr, Cost> exprMemo)
|
||||
public CostEvaluateContext(Dictionary<Expr, Cost> exprMemo, CompileOptions compileOptions)
|
||||
{
|
||||
_exprMemo = exprMemo;
|
||||
CompileOptions = compileOptions;
|
||||
}
|
||||
|
||||
public Call CurrentCall
|
||||
|
@ -28,6 +29,8 @@ internal sealed class CostEvaluateContext : ICostEvaluateContext
|
|||
set => _currentCall = value;
|
||||
}
|
||||
|
||||
public CompileOptions CompileOptions { get; }
|
||||
|
||||
public T GetArgument<T>(Op op, ParameterInfo parameter)
|
||||
where T : BaseFunction
|
||||
{
|
||||
|
|
|
@ -22,7 +22,7 @@ internal sealed class CostEvaluateProvider : ICostEvaluateProvider
|
|||
_serviceProvider = serviceProvider;
|
||||
}
|
||||
|
||||
public Cost EvaluateCost(Expr expr)
|
||||
public Cost EvaluateCost(Expr expr, CompileOptions compileOptions)
|
||||
{
|
||||
if (expr.CheckedType is null)
|
||||
{
|
||||
|
@ -34,7 +34,7 @@ internal sealed class CostEvaluateProvider : ICostEvaluateProvider
|
|||
throw new InvalidOperationException("Expr in Cost Evaluator need a valid type");
|
||||
}
|
||||
|
||||
var evaluatorVisitor = new CostEvaluateVisitor();
|
||||
var evaluatorVisitor = new CostEvaluateVisitor(compileOptions);
|
||||
return evaluatorVisitor.Visit(expr);
|
||||
}
|
||||
|
||||
|
|
|
@ -15,9 +15,9 @@ internal sealed class CostEvaluateVisitor : ExprVisitor<Cost, Unit>
|
|||
{
|
||||
private readonly CostEvaluateContext _context;
|
||||
|
||||
public CostEvaluateVisitor()
|
||||
public CostEvaluateVisitor(CompileOptions compileOptions)
|
||||
{
|
||||
_context = new CostEvaluateContext(ExprMemo);
|
||||
_context = new CostEvaluateContext(ExprMemo, compileOptions);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
@ -41,7 +41,7 @@ internal sealed class CostEvaluateVisitor : ExprVisitor<Cost, Unit>
|
|||
var targetCost = expr.Target switch
|
||||
{
|
||||
Op op => CompilerServices.EvaluateOpCost(op, _context),
|
||||
Function func => CompilerServices.EvaluateCost(func.Body),
|
||||
Function func => CompilerServices.EvaluateCost(func.Body, _context.CompileOptions),
|
||||
_ => throw new NotImplementedException(expr.Target.ToString()),
|
||||
};
|
||||
return argumentsCost + targetCost;
|
||||
|
|
|
@ -126,7 +126,7 @@ internal sealed class EvaluateVisitor : ExprVisitor<IValue, Unit>, IDisposable
|
|||
{
|
||||
Op op => CompilerServices.EvaluateOp(op, _context, _evaluator_cache),
|
||||
Function func => CompilerServices.Evaluate(func.Body, CreateFunctionEvaluateArguments(func.Parameters, expr.Arguments), _evaluator_cache),
|
||||
Fusion { ModuleKind: "stackvm" } fusion => CompilerServices.Evaluate(fusion.Body, CreateFunctionEvaluateArguments(fusion.Parameters, expr.Arguments), _evaluator_cache),
|
||||
Fusion fusion => CompilerServices.Evaluate(fusion.Body, CreateFunctionEvaluateArguments(fusion.Parameters, expr.Arguments), _evaluator_cache),
|
||||
_ => throw new NotImplementedException(expr.Target.ToString()),
|
||||
};
|
||||
}
|
||||
|
|
|
@ -35,7 +35,8 @@ public class SliceEvaluator : IEvaluator<Slice>, ITypeInferencer<Slice>, ICostEv
|
|||
var ends = context.GetInt64OrtTensorArgumentValue(sl, Slice.Ends);
|
||||
var axes = context.GetInt64OrtTensorArgumentValue(sl, Slice.Axes);
|
||||
var strides = context.GetInt64OrtTensorArgumentValue(sl, Slice.Strides);
|
||||
return OrtKI.Slice(input, begins, ends, axes, strides).ToValue();
|
||||
var sliced = OrtKI.Slice(input, begins, ends, axes, strides);
|
||||
return Value.FromTensor(context.CurrentCall.CheckedType is AnyType ? sliced.ToTensor() : sliced.ToTensor(context.CurrentCall.CheckedTensorType));
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
|
|
@ -242,6 +242,33 @@ internal sealed partial class TypeInferenceVisitor : ExprVisitor<IRType, Unit>
|
|||
return type;
|
||||
}
|
||||
|
||||
protected override IRType VisitLeafGrid(Grid expr)
|
||||
{
|
||||
foreach (var p in expr.BodyParameters)
|
||||
{
|
||||
VerifySubField(expr, p);
|
||||
}
|
||||
|
||||
foreach (var p in expr.AccessMaps)
|
||||
{
|
||||
VerifySubField(expr, p);
|
||||
}
|
||||
|
||||
foreach (var p in expr.Buffers)
|
||||
{
|
||||
VerifySubField(expr, p);
|
||||
}
|
||||
|
||||
foreach (var p in expr.Reads)
|
||||
{
|
||||
VerifySubField(expr, p);
|
||||
}
|
||||
|
||||
VerifySubField(expr, expr.Body);
|
||||
|
||||
return expr.Buffers[^1].CheckedType;
|
||||
}
|
||||
|
||||
protected override IRType VisitLeafAffineExpr(AffineExpr expr) => TensorType.Scalar(DataTypes.Int64);
|
||||
|
||||
protected override IRType VisitLeafAffineDomain(AffineDomain expr) => new TupleType(ImmutableArray.Create(expr.Offset.CheckedType, expr.Extent.CheckedType));
|
||||
|
|
|
@ -7,6 +7,7 @@ using System.IO;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Evaluator;
|
||||
using Nncase.IR;
|
||||
|
@ -15,16 +16,25 @@ namespace Nncase.Passes;
|
|||
|
||||
public sealed class EGraphExtractPass : Pass<IEGraph, BaseFunction>
|
||||
{
|
||||
private readonly IBaseFuncCostEvaluator? _costEvaluator;
|
||||
private readonly CompileOptions _compileOptions;
|
||||
private IBaseFuncCostEvaluator? _costEvaluator;
|
||||
|
||||
public EGraphExtractPass(IBaseFuncCostEvaluator? costEvaluator = null)
|
||||
public EGraphExtractPass(CompileOptions compileOptions)
|
||||
{
|
||||
_costEvaluator = costEvaluator;
|
||||
_compileOptions = compileOptions;
|
||||
}
|
||||
|
||||
public void AddBaseFuncCostEvaluator<T>(params object[] parameters)
|
||||
where T : IBaseFuncCostEvaluator
|
||||
{
|
||||
var compileSession = ((IPassIntern)this).CompileSession;
|
||||
using var scope = new CompileSessionScope(compileSession);
|
||||
_costEvaluator = ActivatorUtilities.CreateInstance<T>(compileSession, parameters);
|
||||
}
|
||||
|
||||
protected override Task<BaseFunction> RunCoreAsync(IEGraph input, RunPassContext context)
|
||||
{
|
||||
var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, Array.Empty<EGraphExtractConstrains>());
|
||||
var post = (BaseFunction)input.Extract(input.Root!, _compileOptions, _costEvaluator, Array.Empty<EGraphExtractConstrains>());
|
||||
IRHelpers.DCE(post);
|
||||
return Task.FromResult(post);
|
||||
}
|
||||
|
|
|
@ -203,7 +203,7 @@ internal sealed class PassManager : IPassManager
|
|||
|
||||
if (replaced && DumpScope.Current.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
DumpScope.Current.DumpModule(module, $"Epoch_{StartPassIndex}/After");
|
||||
DumpScope.Current.DumpModule(module, $"FunctionCallUpdate_{StartPassIndex}/After");
|
||||
}
|
||||
|
||||
return module;
|
||||
|
|
|
@ -46,9 +46,9 @@ internal sealed class AffineTiler
|
|||
for (int loop = 0; loop < _loopBuilders.Length; loop++)
|
||||
{
|
||||
var domain = schedule.Loops[loop].Domain.Offset.Position;
|
||||
var begin = 0;
|
||||
var end = begin + _dims[domain];
|
||||
var stride = schedule.Loops[loop].TileSize;
|
||||
var begin = 0ul;
|
||||
var end = begin + (ulong)_dims[domain];
|
||||
var stride = (ulong)schedule.Loops[loop].TileSize;
|
||||
_domainExtents[domain] = stride;
|
||||
_loopBuilders[loop] = T.ForLoop(out _domainOffsets[domain], (begin, end, stride), LoopMode.Serial, $"l{loop}");
|
||||
}
|
||||
|
@ -71,24 +71,47 @@ internal sealed class AffineTiler
|
|||
|
||||
// 3. Nest compute body
|
||||
var bodyBuffers = new Expr[_grid.Buffers.Length];
|
||||
var bufferOfVars = new Expr[_grid.Reads.Length + 1];
|
||||
var bodyVarReplaces = new Dictionary<Expr, Expr>();
|
||||
var bufferOfReplaces = new Dictionary<Expr, Expr>();
|
||||
for (int i = 0; i < bodyBuffers.Length; i++)
|
||||
{
|
||||
(bodyBuffers[i], cntBlock) = AllocateSubBuffer(cntBlock, _tempBuffers[i], schedule.BodyBufferViews[i]);
|
||||
bodyVarReplaces.Add(_grid.BodyParameters[i], bodyBuffers[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < _grid.Reads.Length; i++)
|
||||
{
|
||||
bufferOfVars[i] = AllocateBufferOf(_grid.Buffers[i], i);
|
||||
bufferOfReplaces.Add(_grid.Buffers[i], bufferOfVars[i]);
|
||||
}
|
||||
|
||||
bufferOfVars[^1] = _grid.Buffers[^1];
|
||||
|
||||
var cloner = new ReplacingExprCloner(bodyVarReplaces);
|
||||
var nestBody = cloner.Clone(_grid.Body, default);
|
||||
cntBlock.Body(nestBody);
|
||||
|
||||
// 4. Create PrimFunction
|
||||
var body = root.Build();
|
||||
var primFunc = new PrimFunction(_grid.ModuleKind, body, _grid.Buffers);
|
||||
cloner = new ReplacingExprCloner(bufferOfReplaces);
|
||||
body = cloner.Clone(body, default);
|
||||
var primFunc = new PrimFunction(_grid.ModuleKind, body, bufferOfVars);
|
||||
var wrapper = new PrimFunctionWrapper(primFunc, _grid.Buffers.Length - 1);
|
||||
module.Add(primFunc);
|
||||
module.Add(wrapper);
|
||||
return new Call(wrapper, _grid.Buffers);
|
||||
|
||||
// module.Add(primFunc);
|
||||
// module.Add(wrapper);
|
||||
return new Call(wrapper, _grid.Reads);
|
||||
}
|
||||
|
||||
private TIR.Buffer AllocateBufferOf(Expr expr, int i)
|
||||
{
|
||||
if (expr is IR.Buffers.BufferOf bufof)
|
||||
{
|
||||
return T.CreateBuffer(bufof.Input.CheckedTensorType, MemoryLocation.Input, out _, "buffer_" + i.ToString());
|
||||
}
|
||||
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
|
||||
private ISequentialBuilder<Expr> AllocateTempBuffers(GridSchedule.Place place, ISequentialBuilder<Expr> sequential)
|
||||
|
|
|
@ -112,7 +112,7 @@ public class TilingSolver
|
|||
var objeciveMonitor = _solver.MakeMinimize(_objective, 1);
|
||||
var searchLog = _solver.MakeSearchLog(100000, objeciveMonitor);
|
||||
var searchLimit = _solver.MakeImprovementLimit(_objective, false, 1, 0, 1, 2);
|
||||
var timeLimit = _solver.MakeTimeLimit(5000);
|
||||
var timeLimit = _solver.MakeTimeLimit(50000);
|
||||
|
||||
_solver.Solve(_decisionBuilder, new SearchMonitor[] { objeciveMonitor, searchLimit, timeLimit, searchLog, _solutionCollector });
|
||||
|
||||
|
@ -557,8 +557,11 @@ public class TilingSolver
|
|||
}
|
||||
}
|
||||
|
||||
var constraint = placeVar * (anyOrder ?? _solver.MakeIntConst(1)) == 0;
|
||||
_solver.Add(constraint);
|
||||
if (anyOrder != null)
|
||||
{
|
||||
var constraint = placeVar * (anyOrder ?? _solver.MakeIntConst(1)) == 0;
|
||||
_solver.Add(constraint);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -39,6 +39,13 @@ public sealed class TestVisitor : ExprVisitor<bool, IRType>
|
|||
return count;
|
||||
}
|
||||
|
||||
public int CountCallFusion<T>()
|
||||
where T : Fusion
|
||||
{
|
||||
var count = ExprMemo.Keys.OfType<Call>().Where(call => call is { Target: T }).Count();
|
||||
return count;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override bool DefaultVisitLeaf(Expr expr) => true;
|
||||
}
|
||||
|
|
|
@ -133,7 +133,8 @@ public sealed class UnitTestDumpper : TestClassBase
|
|||
new Passes.Rules.Lower.RemoveMarker(),
|
||||
new TestMulToAdd(),
|
||||
},
|
||||
new());
|
||||
new(),
|
||||
CompileOptions);
|
||||
|
||||
Assert.True(File.Exists(Path.Join(Dumpper.Directory, "Costs/Solve.txt")));
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -48,7 +48,7 @@ public class UnitTestEGraphRewrite : TestClassBase
|
|||
{
|
||||
Expr pre = (Const)10 * 11 * 12;
|
||||
var rule = new Passes.Rules.Neutral.ReassociateMul();
|
||||
CompilerServices.ERewrite(pre, new[] { rule }, new());
|
||||
CompilerServices.ERewrite(pre, new[] { rule }, new(), CompileOptions);
|
||||
|
||||
// Assert.Equal(newExpr, 10 * ((Const)11 * 12));
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ public class UnitTestEGraphRewrite : TestClassBase
|
|||
|
||||
Assert.True(pre.InferenceType());
|
||||
|
||||
var post = CompilerServices.ERewrite(pre, new[] { new Passes.Rules.Neutral.CombineBinaryTranspose() }, new());
|
||||
var post = CompilerServices.ERewrite(pre, new[] { new Passes.Rules.Neutral.CombineBinaryTranspose() }, new(), CompileOptions);
|
||||
|
||||
Assert.True(post.InferenceType());
|
||||
Assert.Equal(pre.Evaluate(), post.Evaluate());
|
||||
|
@ -106,7 +106,8 @@ public class UnitTestEGraphRewrite : TestClassBase
|
|||
new Passes.Rules.Lower.RemoveMarker(),
|
||||
new TestMulToAdd(),
|
||||
},
|
||||
new());
|
||||
new(),
|
||||
CompileOptions);
|
||||
|
||||
Assert.True(post.InferenceType());
|
||||
|
||||
|
|
|
@ -53,14 +53,14 @@ public class UnitTestFoldSlice : TransformTestBase
|
|||
new object[]
|
||||
{
|
||||
new[] { 4, 4, 6, 8 },
|
||||
new[] { 0 }, new[] { 6 }, new[] { 3 }, new[] { -3 },
|
||||
new[] { 0 }, new[] { 4 }, new[] { 2 }, new[] { -2 },
|
||||
new[] { 0 }, new[] { 6 }, new[] { 3 }, new[] { 3 },
|
||||
new[] { 0 }, new[] { 4 }, new[] { 2 }, new[] { 2 },
|
||||
}, // negative axis
|
||||
new object[]
|
||||
{
|
||||
new[] { 3, 4, 6, 8 },
|
||||
new[] { 0 }, new[] { -1 }, new[] { 3 }, new[] { -3 },
|
||||
new[] { -5 }, new[] { 4 }, new[] { 2 }, new[] { -2 },
|
||||
new[] { 0 }, new[] { -1 }, new[] { 3 }, new[] { 3 },
|
||||
new[] { -5 }, new[] { 4 }, new[] { 2 }, new[] { 2 },
|
||||
}, // negative begin|end
|
||||
}.Select((o, i) => o.Concat(new object[] { i }).ToArray());
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
|
|||
public static int Rank => 1;
|
||||
|
||||
[Theory]
|
||||
[InlineData(new object[] { new[] { 1, 512, 64, 64 }, 0 })]
|
||||
[InlineData(new object[] { new[] { 32, 512, 64, 64 }, 0 })]
|
||||
public async Task TestSwish(int[] shape, int count)
|
||||
{
|
||||
var input = new Var(new TensorType(DataTypes.Float32, shape));
|
||||
|
@ -69,7 +69,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
|
|||
{ input, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, shape).Evaluate() },
|
||||
};
|
||||
|
||||
var rule = new Passes.Rules.CPU.PackSwish() { Lane = Lane, Rank = Rank };
|
||||
var rule = new Passes.Rules.CPU.PackSwish(Rank, Lane);
|
||||
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
|
||||
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext()).Where(e => e is not Call { Target: Slice }));
|
||||
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
|
||||
|
@ -89,7 +89,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
|
|||
{ rhs, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 3, rhsShape).Evaluate() },
|
||||
};
|
||||
|
||||
var rule = new Passes.Rules.CPU.PackBinary() { Lane = Lane, Rank = Rank };
|
||||
var rule = new Passes.Rules.CPU.PackBinary(Rank, Lane);
|
||||
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
|
||||
var posts = rule.GetReplaceCandidates(result!, new Passes.RunPassContext());
|
||||
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
|
||||
|
@ -112,7 +112,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
|
|||
{ bias, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, pshape).Evaluate() },
|
||||
};
|
||||
|
||||
var rule = new Passes.Rules.CPU.PackLayerNorm() { Lane = Lane, Rank = Rank };
|
||||
var rule = new Passes.Rules.CPU.PackLayerNorm(Rank, Lane);
|
||||
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
|
||||
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext())).Where(e => e is not Call { Target: Slice });
|
||||
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
|
||||
|
@ -135,7 +135,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
|
|||
{ bias, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, pshape).Evaluate() },
|
||||
};
|
||||
|
||||
var rule = new Passes.Rules.CPU.PackInstanceNorm() { Lane = Lane, Rank = Rank };
|
||||
var rule = new Passes.Rules.CPU.PackInstanceNorm(Rank, Lane);
|
||||
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
|
||||
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext())).Where(e => e is not Call { Target: Slice });
|
||||
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
|
||||
|
@ -153,7 +153,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
|
|||
{ input, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, shape).Evaluate() },
|
||||
};
|
||||
|
||||
var rule = new Passes.Rules.CPU.PackResizeImage() { Lane = Lane, Rank = Rank };
|
||||
var rule = new Passes.Rules.CPU.PackResizeImage(Rank, Lane);
|
||||
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
|
||||
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext())).Where(e => e is not Call { Target: Slice });
|
||||
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
|
||||
|
@ -173,7 +173,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
|
|||
{ rhs, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 3, rhsShape).Evaluate() },
|
||||
};
|
||||
|
||||
var rule = new Passes.Rules.CPU.PackMatMul() { Lane = Lane, Rank = Rank };
|
||||
var rule = new Passes.Rules.CPU.PackMatMul(Rank, Lane);
|
||||
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
|
||||
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext()));
|
||||
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
|
||||
|
@ -336,13 +336,14 @@ public sealed class UnitTestCPUKernels : TestClassBase
|
|||
|
||||
internal async Task RunCases(string dumpDir, Dictionary<Var, IValue> feedDict, IEnumerable<Expr> posts)
|
||||
{
|
||||
int count = 0;
|
||||
foreach (var post in posts)
|
||||
var postArray = posts.ToArray();
|
||||
using var pinner = new ExprPinner(postArray);
|
||||
for (int i = 0; i < postArray.Length; i++)
|
||||
{
|
||||
#if DEBUG
|
||||
System.Console.WriteLine(CompilerServices.Print(post));
|
||||
System.Console.WriteLine(CompilerServices.Print(postArray[i]));
|
||||
#endif
|
||||
var kernelCase = new CpuKernelCase($"Case{count++}", new Fusion("kernel", CPUTarget.Kind, post, feedDict.Keys.ToArray()), feedDict.Keys.ToArray(), feedDict.Values.Select(v => v.AsTensor()).ToArray());
|
||||
var kernelCase = new CpuKernelCase($"Case{i}", new Fusion("kernel", CPUTarget.Kind, postArray[i], feedDict.Keys.ToArray()), feedDict.Keys.ToArray(), feedDict.Values.Select(v => v.AsTensor()).ToArray());
|
||||
await Run(dumpDir, kernelCase);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue