Refactor Compiler Pipeline (#1195)

* refactor pipeline

* fix distributed

* fix typo

* first version of graph partition using egraph

* add graph partition

* enable affine

* fix circle case

* add fusion test

* graph partition support tuple

* graph partition support concat

* fix affine pipeline

* fix ci bugs

* Fix Tiling.AddReductionPlacesConstraints

* fix concat fusion when all inputs are Var

* fix ci

* fix slice test

* fix upsample

* fix complex merge

* add GetItem merge

* pass llama

* remove hashcode in fusion name

* fix compile

* skip split

* fix possible merge error

* fix build under gcc13

* fix slice

* Apply code-format changes

* skip conv2d with relu likes

* skip non-fixed shape in partition
* change vector type serializer struct

* change the lane type to size_t

* fixed type serializer
dev/3.0
郑启航 2024-05-07 10:45:51 +08:00 committed by GitHub
parent 7a682a3dd8
commit 952f89bb72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
65 changed files with 2486 additions and 595 deletions

1
.gitignore vendored
View File

@ -308,3 +308,4 @@ cmake-build-*
*.ipynb_checkpoints*
# Auto generated files
# generated/
.history/

View File

@ -13,7 +13,6 @@ using System.Reactive;
using System.Runtime.InteropServices;
using System.Text;
using DryIoc;
using Google.OrTools.Sat;
using NetFabric.Hyperlinq;
using Nncase.IR;
using Nncase.Runtime;
@ -176,7 +175,7 @@ internal sealed class DeviceCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
TensorType { Shape: { IsRanked: true } } x => x.Shape.IsFixed switch
{
true => $"tensor_view<{x.DType.ToC()}, fixed_shape<{x.Shape.ToString()[1..^1]}>>",
false => $"tensor_view<{x.DType.ToC()}, ranked_shape<{x.Shape.Rank}>>",
false => "auto",
},
_ => throw new NotSupportedException(),
};
@ -217,6 +216,26 @@ internal sealed class DeviceCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
break;
case IR.Buffers.Allocate op:
str = $"({type})runtime_util->malloc({arguments[0].Name})";
break;
case IR.Buffers.BufferSubview op:
{
var arg0 = expr.Arguments[1] switch
{
TupleConst => $"fixed_shape<{arguments[1].Name}>{{}}",
IR.Tuple tc => $"ranked_shape<{tc.Count}>{{{arguments[1].Name}}}",
_ => throw new ArgumentOutOfRangeException(nameof(expr)),
};
var arg1 = expr.Arguments[2] switch
{
TupleConst => $"fixed_shape<{arguments[2].Name}>{{}}",
IR.Tuple tc => $"ranked_shape<{tc.Count}>{{{arguments[2].Name}}}",
_ => throw new ArgumentOutOfRangeException(nameof(expr)),
};
str = $"{arguments[0].Name}.view({arg0}, {arg1})";
}
break;
case IR.Buffers.AllocateBufferView op:
{
@ -252,6 +271,14 @@ internal sealed class DeviceCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
BinaryOp = op.BinaryOp,
}).Result);
break;
case TIR.CPU.Swish swish:
if (swish.Beta != 1.0f)
{
throw new NotSupportedException();
}
IndentScope.Writer.IndWrite($"unary<ops::swish>({arguments[0].Name}, {arguments[1].Name});\n");
break;
default:
throw new NotSupportedException();
}
@ -298,6 +325,34 @@ internal sealed class DeviceCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
return symbol;
}
/// <summary>
/// Produces the symbol for a constant tuple: an empty type string and a name made of the
/// visited element names joined with commas. Results are cached in the expression memo.
/// </summary>
/// <param name="tp">Constant tuple to convert.</param>
/// <returns>The cached or newly created symbol.</returns>
protected override CSymbol VisitTupleConst(TupleConst tp)
{
    if (!_exprMemo.TryGetValue(tp, out var symbol))
    {
        // Each element value is wrapped as a constant and visited so it gets its own symbol name.
        var elementNames = tp.Value.Select(x => Visit(Const.FromValue(x)).Name);
        symbol = new(string.Empty, string.Join(",", elementNames));
        _exprMemo.Add(tp, symbol);
    }

    return symbol;
}
/// <summary>
/// Produces the symbol for a tuple expression: an empty type string and a name made of the
/// visited field names joined with commas. Results are cached in the expression memo.
/// </summary>
/// <param name="tp">Tuple expression to convert.</param>
/// <returns>The cached or newly created symbol.</returns>
protected override CSymbol VisitTuple(IR.Tuple tp)
{
    if (!_exprMemo.TryGetValue(tp, out var symbol))
    {
        var fieldNames = tp.Fields.AsValueEnumerable().Select(x => Visit(x).Name).ToArray();
        symbol = new(string.Empty, string.Join(",", fieldNames));
        _exprMemo.Add(tp, symbol);
    }

    return symbol;
}
/// <inheritdoc/>
protected override CSymbol VisitSequential(Sequential expr)
{

View File

@ -62,12 +62,11 @@ internal class FunctionBuilder
return new LinkableKernelFunction(_id, function, functionCSource, _sectionManager.GetContent(WellknownSectionNames.Text)!, new LinkedSection(_sectionManager.GetContent(KernelHeaderSectionName), KernelHeaderSectionName, 0, 8, (uint)sizeof(DescHeader)));
}
else if (function.Name.EndsWith("device"))
else
{
var visitor = new DeviceCSourceConvertVisitor();
visitor.Visit(function);
var header = visitor.GetHeader();
return new LinkableDeviceFunction(_id, function, header, _sectionManager.GetContent(WellknownSectionNames.Text)!);
}

View File

@ -132,7 +132,7 @@ internal sealed class KernelCSourceConvertVisitor : ExprFunctor<CSymbol, Unit>,
public KernelCSource GetCSource()
{
var ctype = $"void {VisitEntry.Name}({string.Join(", ", VisitEntry.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray().Concat(_exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location == MemoryLocation.Rdata).Select(Visit).Select(s => $" {s.Type} {s.Name}").ToArray()))}, uint8_t* l1_data)";
var ctype = $"void {VisitEntry.Name}({string.Join(", ", VisitEntry.Parameters.AsValueEnumerable().Select(Visit).Select(s => $"{s.Type} {s.Name}").ToArray().Concat(_exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location == MemoryLocation.Rdata).Select(Visit).Select(s => $" {s.Type} {s.Name}").ToArray()))}, uint8_t* data)";
return new(
CSourceBuiltn.MakeMain(VisitEntry, _exprMemo.Keys.OfType<TIR.Buffer>().Where(b => b.MemSpan.Location == MemoryLocation.Rdata)),
CSourceBuiltn.MakeKernel(ctype, _kernelBuilder.ToString()));

View File

@ -32,18 +32,34 @@ public sealed class BoxingEvaluator : ITypeInferencer<Boxing>, ICostEvaluator<Bo
switch (inType, returnType)
{
case (TensorType tensorType, DistributedType distTensorType):
cost = new Cost()
switch (context.CompileOptions.TargetCompileOptions)
{
[CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(tensorType),
[CostFactorNames.MemoryStore] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)),
};
case Targets.CpuTargetOptions { UnifiedMemoryArchitecture: true }:
break;
default:
cost = new Cost()
{
[CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(tensorType),
[CostFactorNames.MemoryStore] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)),
};
break;
}
break;
case (DistributedType distTensorType, TensorType tensorType):
cost = new Cost()
switch (context.CompileOptions.TargetCompileOptions)
{
[CostFactorNames.MemoryLoad] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)),
[CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(tensorType),
};
case Targets.CpuTargetOptions { UnifiedMemoryArchitecture: true }:
break;
default:
cost = new Cost()
{
[CostFactorNames.MemoryLoad] = (UInt128)((float)CostUtility.GetMemoryAccess(distTensorType) / DistributedUtility.GetDividedTensorEfficiency(distTensorType, _burstLength)),
[CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(tensorType),
};
break;
}
break;
case (DistributedType a, DistributedType b) when a.Placement == b.Placement && a.NdSBP != b.NdSBP:

View File

@ -1,34 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Linq;
using Nncase.CostModel;
using Nncase.IR;
using Nncase.IR.CPU;
namespace Nncase.Evaluator.IR.CPU;
/// <summary>
/// Evaluator for <see cref="CPUKernelOp"/>; every request is forwarded to the wrapped target operator.
/// </summary>
public class CPUKernelOpEvaluator : IEvaluator<CPUKernelOp>, ITypeInferencer<CPUKernelOp>, ICostEvaluator<CPUKernelOp>
{
    /// <inheritdoc/>
    public IValue Visit(IEvaluateContext context, CPUKernelOp target) =>
        CompilerServices.EvaluateOp(target.Target, context);

    /// <inheritdoc/>
    public IRType Visit(ITypeInferenceContext context, CPUKernelOp target) =>
        CompilerServices.InferenceOp(target.Target, context, new());

    /// <inheritdoc/>
    public Cost Visit(ICostEvaluateContext context, CPUKernelOp target) =>
        CompilerServices.EvaluateOpCost(target.Target, context);
}

View File

@ -14,7 +14,6 @@ internal class CPUModule : IApplicationPart
public void ConfigureServices(IRegistrator registrator)
{
registrator.RegisterManyInterface<BoxingEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<CPUKernelOpEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<LoadEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<StoreEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PackEvaluator>(reuse: Reuse.Singleton);

View File

@ -38,10 +38,6 @@ public sealed class PackedMatMulEvaluator : IEvaluator<PackedMatMul>, ITypeInfer
rhs = rhs.Unpack(axis);
}
// lhs = OrtKI.Unsqueeze(lhs, new long[] { -4, -1 }); // [x,m/32,k/32, 1 , m' ,k', 1 ]
// rhs = OrtKI.Unsqueeze(rhs, new long[] { -6, -3 }); // [x, 1 ,k/32,n/32, 1 ,k', n']
// var matmul = OrtKI.Mul(lhs, rhs); // [x, m/32,k/32,n/32,m',k',n']
// matmul = OrtKI.ReduceSum(matmul, new long[] { -2, -5 }, 0, 1);
var matmul = OrtKI.MatMul(lhs, rhs);
if (target.LhsPackedAxes.Count == 2)
{
@ -112,7 +108,7 @@ public sealed class PackedMatMulEvaluator : IEvaluator<PackedMatMul>, ITypeInfer
rType = Math.MatMulEvaluator.VisitTensorType(a, b);
break;
default:
ERROR: rType = new InvalidType($"{lhs} {rhs} not support");
ERROR: rType = new InvalidType($"lhs: {lhs}, rhs: {rhs} not support");
break;
}

View File

@ -1,33 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR.Math;
using Nncase.PatternMatch;
namespace Nncase.IR.CPU;
/// <summary>
/// Operator that wraps another operator and exposes that operator's parameters as its own.
/// </summary>
public sealed class CPUKernelOp : Op
{
    // Holds a pinner created over the wrapped operator for the lifetime of this instance.
    // NOTE(review): presumably this keeps the target alive while it is only referenced from here; confirm against ExprPinner.
    private readonly ExprPinner _exprPinner;

    /// <summary>
    /// Initializes a new instance of the <see cref="CPUKernelOp"/> class.
    /// </summary>
    /// <param name="target">The operator to wrap.</param>
    public CPUKernelOp(Op target)
    {
        Target = target;
        _exprPinner = new ExprPinner(target);
    }

    /// <summary>
    /// Gets the wrapped operator.
    /// </summary>
    public Op Target { get; }

    /// <inheritdoc/>
    public override IEnumerable<ParameterInfo> Parameters
    {
        get { return Target.Parameters; }
    }

    /// <summary>
    /// Returns the type name of the wrapped operator.
    /// </summary>
    /// <returns>The simple type name of <see cref="Target"/>.</returns>
    public override string DisplayProperty()
    {
        return Target.GetType().Name;
    }
}

View File

@ -12,17 +12,6 @@ namespace Nncase.IR.F;
public partial class CPU
{
/// <summary>
/// Call cpu kernel.
/// </summary>
/// <param name="target">Unary operator.</param>
/// <param name="inputs">Source inputs.</param>
/// <returns>Result expression.</returns>
public static Call CPUKernel(Op target, params Expr[] inputs)
{
return new Call(new CPUKernelOp(target), inputs);
}
public static Call Boxing(Expr input, IRType type)
{
return new Call(new Boxing(type), input);

View File

@ -14,49 +14,42 @@ using static Nncase.PatternMatch.Utility;
[assembly: InternalsVisibleTo("Nncase.Tests")]
namespace Nncase.Passes.Rules;
namespace Nncase.Passes;
/// <summary>
/// auto distributed the xpu fusion.
/// </summary>
[RuleGenerator]
public sealed partial class AutoDistributed : IRewriteRule
public sealed partial class AutoDistributedPass : FunctionPass
{
private readonly CompileOptions _compileOptions;
public AutoDistributed(CompileOptions compileOptions)
public AutoDistributedPass(CompileOptions compileOptions)
{
_compileOptions = compileOptions;
}
public IPattern Pattern { get; } = IsCallWildcard("call", IsFusion("fusion", CPUTarget.Kind, IsWildcard("body"), IsVArgsRepeat("parameters", () => IsVar())));
private Expr? GetReplace(Call call, Fusion fusion, Expr body, IReadOnlyList<Expr> parameters, IReadOnlyList<Expr> callParams)
protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassContext context)
{
// 1. convert to distribute graph
if (body is Call { Target: Boxing } || (body is IR.Tuple tp && tp.Fields.AsValueEnumerable().Any(e => e is Call { Target: Boxing })))
{
return null;
}
var distConverter = new AutoDistributedConvertVisitor(_compileOptions.TargetCompileOptions is CpuTargetOptions options ? options : new CpuTargetOptions());
var newbody = distConverter.Convert(body);
var newFusion = fusion.With(moduleKind: CPUTarget.Kind, body: newbody, parameters: parameters.Cast<Var>().ToArray());
return new Call(newFusion, callParams.ToArray());
var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetCompileOptions is CpuTargetOptions options ? options : new CpuTargetOptions());
return Task.FromResult(rewriter.Rewirte(input));
}
}
internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRType, List<Expr>>, Unit>
internal sealed class AutoDistributedRewriter : ExprVisitor<Dictionary<IRType, List<Expr>>, Unit>
{
public AutoDistributedConvertVisitor(CpuTargetOptions compileOptions)
public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions)
{
Placement = new Placement(compileOptions.Hierarchy, compileOptions.HierarchyNames);
Placement = new Placement(targetOptions.Hierarchy, targetOptions.HierarchyNames);
CompileOptions = compileOptions;
TargetOptions = targetOptions;
}
public Placement Placement { get; }
public CpuTargetOptions CompileOptions { get; }
public CompileOptions CompileOptions { get; }
public CpuTargetOptions TargetOptions { get; }
public static IReadOnlyList<Expr> GetLeafCandidateBoxings(Expr expr, Placement placement)
{
@ -65,41 +58,31 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
ToArray();
}
public Expr Convert(Expr body)
public BaseFunction Rewirte(BaseFunction input)
{
var createFinalBoxing = (Expr e, TensorType type) =>
if (input is Function function)
{
var d = (DistributedType)e.CheckedType;
if (d.NdSBP.Any(s => s is SBPPartialSum))
var equivalents = Visit(function.Body).Select(g => InstertTerminator(g.Value[0])).ToArray();
using (new ExprPinner(equivalents))
{
var boxingP2B = IR.F.CPU.Boxing(e, new DistributedType(type, d.NdSBP.Select(s => s is SBPPartialSum ? SBP.B : s).ToArray(), Placement));
return IR.F.CPU.Boxing(boxingP2B, type);
BranchCut();
}
return IR.F.CPU.Boxing(e, type);
};
var equivalents = Visit(body).Select(g => g.Value[0] switch
{
IR.Tuple tp => new IR.Tuple(tp.Fields.ToArray().Select((f, i) => createFinalBoxing(f, (TensorType)((IR.Tuple)body).Fields[i].CheckedType)).ToArray()),
Expr e => (Expr)createFinalBoxing(e, (TensorType)body.CheckedType),
}).ToArray();
using (new ExprPinner(equivalents))
{
BranchCut();
}
var graph = new EGraph();
foreach (var (exprKey, buckets) in ExprMemo.Where(kv => kv.Key is not Op))
{
foreach (var (typeKey, bucket) in buckets.Where(kv => kv.Value.Any()))
var graph = new EGraph();
foreach (var (exprKey, buckets) in ExprMemo.Where(kv => kv.Key is not Op))
{
Unions(graph, bucket);
foreach (var (typeKey, bucket) in buckets.Where(kv => kv.Value.Any()))
{
Unions(graph, bucket);
}
}
var root = Unions(graph, equivalents);
var post = graph.Extract(root, CompileOptions, null, Array.Empty<EGraphExtractConstrains>());
return function.With(body: post);
}
var root = Unions(graph, equivalents);
return graph.Extract(root, null, Array.Empty<EGraphExtractConstrains>());
return input;
}
protected override Dictionary<IRType, List<Expr>> DefaultVisitLeaf(Expr expr)
@ -124,16 +107,18 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
throw new NotSupportedException("not support auto distributed call function");
}
var isSupported = PassUtility.IsCpuSupported(op, expr.Arguments.ToArray());
foreach (var param in op.Parameters)
{
VisitLeafArgument(param.ParameterKind, expr.Arguments[param.Index]);
VisitLeafArgument(param.ParameterKind, expr.Arguments[param.Index], isSupported);
}
var results = expr.Arguments.ToArray().
Select(Visit).
CartesianProduct().
Select(args => args.ToArray()).
Select(args => BuildEquivalCalls(op, args.Select(kv => kv.Value[0]).ToArray()).ToArray()).
Select(args => isSupported ? BuildEquivalCalls(op, args.Select(kv => kv.Value[0]).ToArray()).ToArray() :
BuildNotSupportedCalls(op, args.Select(kv => kv.Value[0]).ToArray())).
SelectMany(i => i).
GroupBy(c => c.CheckedType).
ToDictionary(g => g.Key, g => new List<Expr>(g.ToList<Expr>()));
@ -157,7 +142,7 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
return results;
}
private Dictionary<IRType, List<Expr>> VisitLeafArgument(ParameterKind parameterKind, Expr expr)
private Dictionary<IRType, List<Expr>> VisitLeafArgument(ParameterKind parameterKind, Expr expr, bool isSupported)
{
var updateBuckets = (Dictionary<IRType, List<Expr>> buckets, IEnumerable<Expr> equivalents) =>
{
@ -179,12 +164,12 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
switch (parameterKind, expr)
{
case (ParameterKind.Input, Expr e) when e is Const or Var:
updateBuckets(buckets, GetLeafCandidateBoxings(e, Placement));
updateBuckets(buckets, isSupported ? GetLeafCandidateBoxings(e, Placement) : new[] { e });
break;
case (ParameterKind.Input, Expr e) when e is IR.Tuple tp:
foreach (var f in tp.Fields)
{
VisitLeafArgument(parameterKind, f);
VisitLeafArgument(parameterKind, f, isSupported);
}
foreach (var (k, v) in VisitLeafTuple(tp))
@ -206,15 +191,44 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
throw new InvalidOperationException();
}
}
else if (parameterKind == ParameterKind.Input)
{
if (isSupported)
{
if (!buckets.Keys.Any(IsDistributed))
{
var results = buckets.Select(kv => GetLeafCandidateBoxings(kv.Value[0], Placement)).SelectMany(i => i).ToArray();
updateBuckets(buckets, results);
}
}
else
{
if (buckets.Keys.All(IsDistributed))
{
var results = buckets.Select(kv => InstertTerminator(kv.Value[0])).ToArray();
updateBuckets(buckets, results);
}
}
}
return buckets;
}
/// <summary>
/// Builds the single plain call for an operator that is not CPU-supported.
/// </summary>
/// <param name="target">Operator being called.</param>
/// <param name="args">Candidate arguments, indexed by parameter index.</param>
/// <returns>
/// An empty array when any input-kind argument has a distributed type; otherwise a one-element
/// array holding the call of <paramref name="target"/> on <paramref name="args"/>.
/// </returns>
private Call[] BuildNotSupportedCalls(Op target, Expr[] args)
{
    var anyDistributedInput = target.Parameters.Any(
        p => p.ParameterKind == ParameterKind.Input && IsDistributed(args[p.Index].CheckedType));

    return anyDistributedInput
        ? Array.Empty<Call>()
        : new[] { new Call(target, args) };
}
private IEnumerable<Call> BuildEquivalCalls(Op target, Expr[] args)
{
if (!target.Parameters.Where(p => p.ParameterKind == ParameterKind.Input).All(p => IsDistributed(args[p.Index].CheckedType)))
{
throw new ArgumentException("the some arg have no distributed type.", nameof(args));
return Array.Empty<Call>();
}
var calls = new List<Call>();
@ -226,7 +240,7 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
using var pinner = new ExprPinner(args);
call.Dispose();
if (target is CPUKernelOp { Target: Reshape } || target is Reshape)
if (target is Reshape)
{
// the reshape need force boxing.
var newShape = ((TensorConst)args[1]).Value.ToArray<int>();
@ -306,6 +320,28 @@ internal sealed class AutoDistributedConvertVisitor : ExprVisitor<Dictionary<IRT
_ => false,
};
/// <summary>
/// Wraps an expression so that its result is a plain tensor again.
/// Tuples are processed field by field, distributed values are boxed to their tensor type,
/// and tensor-typed values are returned untouched.
/// </summary>
/// <param name="expr">Expression to terminate.</param>
/// <returns>The terminated expression.</returns>
/// <exception cref="NotSupportedException">The checked type is none of the handled kinds.</exception>
private Expr InstertTerminator(Expr expr)
{
    var checkedType = expr.CheckedType;

    if (expr is IR.Tuple tuple && checkedType is TupleType)
    {
        var fields = tuple.Fields.ToArray();
        var terminated = new Expr[fields.Length];
        for (int i = 0; i < fields.Length; i++)
        {
            terminated[i] = InstertTerminator(fields[i]);
        }

        return new IR.Tuple(terminated);
    }

    if (checkedType is DistributedType distributed)
    {
        Expr source = expr;
        if (distributed.NdSBP.Any(s => s is SBPPartialSum))
        {
            // Partial sums are first boxed to the same layout with every partial-sum entry replaced by broadcast.
            source = IR.F.CPU.Boxing(expr, new DistributedType(distributed.TensorType, distributed.NdSBP.Select(s => s is SBPPartialSum ? SBP.B : s).ToArray(), Placement));
        }

        return IR.F.CPU.Boxing(source, distributed.TensorType);
    }

    if (checkedType is TensorType)
    {
        return expr;
    }

    throw new NotSupportedException();
}
private EClass Unions(EGraph graph, IEnumerable<Expr> equivalents)
{
var eids = equivalents.Select(graph.Add).ToArray();

View File

@ -0,0 +1,97 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.CommandLine;
using System.CommandLine.Invocation;
using System.Linq;
using System.Reactive;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase.Passes;
/// <summary>
/// Module pass that repeatedly applies <see cref="PartitionRewriter"/> to every <see cref="Function"/>
/// in the module until a rewrite reports no mutation, then stores the result back into the module.
/// </summary>
public sealed class CPUFunctionPartitionPass : ModulePass
{
    /// <inheritdoc/>
    protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext context)
    {
        // The count is captured once, so only the functions present at entry are visited.
        var funcs = module.Functions.Count;
        for (int i = 0; i < funcs; i++)
        {
            if (module.Functions[i] is not Function function)
            {
                continue;
            }

            Function current = function;
            bool mutated;
            do
            {
                // A fresh rewriter per round: each instance rewrites at most one boxing call.
                var rewriter = new PartitionRewriter();
                current = (Function)rewriter.Rewrite(current, default);
                mutated = rewriter.IsMutated;
            }
            while (mutated);

            module.Replace(i, current);
        }

        return Task.FromResult(module);
    }
}
/// <summary>
/// Collects the expressions that enter a partition, compared by reference.
/// </summary>
internal sealed class PartitionContext
{
    private readonly HashSet<Expr> _entryPoints = new(ReferenceEqualityComparer.Instance);

    /// <summary>
    /// Records an entry-point expression; adding the same instance again has no effect.
    /// </summary>
    /// <param name="expr">Expression that feeds the partition.</param>
    public void AddEntryPoint(Expr expr)
    {
        _entryPoints.Add(expr);
    }

    /// <summary>
    /// Creates one new variable per recorded entry point, typed with that entry point's checked type.
    /// </summary>
    /// <returns>A reference-keyed map from each entry point to its replacement variable.</returns>
    public Dictionary<Expr, Expr> CreateVarMaps()
    {
        IEqualityComparer<Expr> byReference = ReferenceEqualityComparer.Instance;
        var map = new Dictionary<Expr, Expr>(byReference);
        foreach (var entry in _entryPoints)
        {
            map.Add(entry, new Var(entry.CheckedType));
        }

        return map;
    }
}
/// <summary>
/// Visitor that registers, in the supplied <see cref="PartitionContext"/>, the first argument of
/// every boxing call whose new type is a <see cref="DistributedType"/>.
/// </summary>
internal sealed class PartitionVisitor : ExprVisitor<Unit, Unit, PartitionContext>
{
    /// <inheritdoc/>
    protected override Unit DefaultVisitLeaf(Expr expr, PartitionContext context)
    {
        return default;
    }

    /// <inheritdoc/>
    protected override Unit VisitLeafCall(Call expr, PartitionContext context)
    {
        if (expr is { Target: IR.CPU.Boxing { NewType: DistributedType } })
        {
            // The value being boxed into a distributed type is where data enters the partition.
            context.AddEntryPoint(expr.Arguments[0]);
        }

        return default;
    }
}
/// <summary>
/// Rewriter that outlines one boxing-to-tensor call, together with everything it reaches up to the
/// recorded entry points, into a new CPU-module function and replaces it with a call to that function.
/// </summary>
internal sealed class PartitionRewriter : ExprRewriter<Unit>
{
    /// <inheritdoc/>
    protected override Expr RewriteLeafCall(Call expr, Unit context)
    {
        // Only an un-mutated rewriter acts, and only on a boxing call whose new type is a tensor type.
        if (IsMutated || expr is not Call { Target: IR.CPU.Boxing { NewType: TensorType } })
        {
            return expr;
        }

        var partition = new PartitionContext();
        new PartitionVisitor().Visit(expr, partition);

        // Entry points become parameters of the outlined function; the original expressions become call arguments.
        var replacements = partition.CreateVarMaps();
        var body = new ReplacingExprCloner(replacements).Clone(expr, default);
        var parameters = replacements.Values.OfType<Var>().ToArray();
        var outlined = new Function(body, parameters).With(moduleKind: Targets.CPUTarget.Kind);
        return new Call(outlined, replacements.Keys.ToArray());
    }
}

View File

@ -19,8 +19,6 @@ namespace Nncase.Passes;
internal sealed class CPUFusionToTirPass : ModulePass
{
private IAnalyzerManager AnalyzerManager => CompileSession.GetRequiredService<IAnalyzerManager>();
/// <inheritdoc/>
protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext options)
{
@ -29,7 +27,7 @@ internal sealed class CPUFusionToTirPass : ModulePass
for (int i = 0; i < module.Functions.Count; i++)
{
if (module.Functions[i] is Fusion { ModuleKind: CPUTarget.Kind } fusion && fusion.Name.EndsWith("kernel"))
if (module.Functions[i] is Fusion { ModuleKind: CPUTarget.Kind } fusion)
{
// var analysis = new Dictionary<Type, IAnalysisResult>
// {

View File

@ -0,0 +1,76 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase.Passes;
/// <summary>
/// Helpers used by the CPU passes to decide whether an operator call can be handled by the CPU backend.
/// </summary>
public static class PassUtility
{
    /// <summary>
    /// Checks whether an operator is supported, looking only at the operator itself.
    /// </summary>
    /// <param name="op">Operator to test.</param>
    /// <returns>
    /// <c>true</c> when the operator type lives in the <c>Nncase.IR.CPU</c> namespace or is one of the
    /// listed operator kinds (some with attribute restrictions); otherwise <c>false</c>.
    /// </returns>
    public static bool IsCpuSupported(Op op)
    {
        if (op.GetType().Namespace == "Nncase.IR.CPU")
        {
            return true;
        }

        return op is IR.Math.Unary or IR.Math.MatMul or IR.NN.Conv2D { PadMode: PadMode.Constant } or IR.NN.Softmax or IR.NN.LayerNorm or IR.NN.InstanceNormalization or IR.Imaging.ResizeImage { IsTFResize: false } or IR.Tensors.Unsqueeze or IR.Tensors.Reshape or IR.Tensors.Slice or IR.Tensors.Concat or IR.Tensors.Transpose or IR.NN.Swish or IR.Tensors.Gather or IR.NN.Pad { PadMode: PadMode.Constant };
    }

    /// <summary>
    /// Checks whether a concrete call of an operator is supported, looking at the arguments as well.
    /// </summary>
    /// <param name="op">Operator to test.</param>
    /// <param name="arguments">Call arguments, in parameter order.</param>
    /// <returns><c>true</c> when the call passes every check below; otherwise <c>false</c>.</returns>
    public static bool IsCpuSupported(Op op, IEnumerable<Expr> arguments)
    {
        if (!IsCpuSupported(op))
        {
            return false;
        }

        // Every parameter must be an input, or an attribute whose argument is a tensor constant.
        if (!op.Parameters.Zip(arguments).All(p => p.First.ParameterKind == ParameterKind.Input || (p.First.ParameterKind == ParameterKind.Attribute && p.Second is TensorConst)))
        {
            return false;
        }

        switch (op)
        {
            case IR.Imaging.ResizeImage:
                // A region of interest is not supported.
                if (arguments.Skip(IR.Imaging.ResizeImage.Roi.Index).First() is not IR.None)
                {
                    return false;
                }

                break;
            case IR.Tensors.Slice:
                // Negative strides, begins or ends are rejected.
                if (((TensorConst)arguments.Skip(IR.Tensors.Slice.Strides.Index).First()).Value.ToArray<int>().Any(s => s < 0))
                {
                    return false;
                }

                if (((TensorConst)arguments.Skip(IR.Tensors.Slice.Begins.Index).First()).Value.ToArray<int>().Any(s => s < 0))
                {
                    return false;
                }

                if (((TensorConst)arguments.Skip(IR.Tensors.Slice.Ends.Index).First()).Value.ToArray<int>().Any(s => s < 0))
                {
                    return false;
                }

                break;
            case IR.NN.Conv2D:
                {
                    // Only an unbounded fused clamp (-inf, +inf) is accepted. Arrays must be compared
                    // element by element: `==` on two arrays compares references and would never match
                    // a freshly allocated array.
                    var clamp = ((TensorConst)arguments.Skip(IR.NN.Conv2D.FusedClamp.Index).First()).Value.ToArray<float>();
                    return clamp.Length == 2 && float.IsNegativeInfinity(clamp[0]) && float.IsPositiveInfinity(clamp[1]);
                }

            default:
                break;
        }

        return true;
    }
}

View File

@ -0,0 +1,40 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.Utility;
namespace Nncase.Passes.Rules.CPU.Affine;
/// <summary>
/// Rewrites a swish call on a fixed-shape input into an affine grid whose body is the TIR CPU swish kernel.
/// </summary>
[RuleGenerator]
public partial class LowerSwish : RewriteRule<Pattern>
{
    // Counter used to give each created output buffer a distinct "swish_N" name.
    private int _count;

    /// <inheritdoc/>
    public override Pattern Pattern { get; } = PatternMatch.F.NN.IsSwish(
        "swish",
        "call",
        IsWildcard("input") with { TypePattern = HasFixedShape() },
        IsTensorConst("beta") with { TypePattern = IsFloat() & (IsScalar() | HasShape(s => s.Rank == 1 && s[0].FixedValue == 1, "scalar")) });

    /// <summary>
    /// Builds the replacement grid for a matched swish call.
    /// </summary>
    /// <param name="call">The matched call.</param>
    /// <param name="input">The swish input.</param>
    /// <param name="beta">The swish beta value.</param>
    /// <returns>The grid expression that replaces the call.</returns>
    private Expr GetReplace(Call call, Expr input, float beta)
    {
        var rank = input.CheckedShape.Rank;

        // The output buffer uses the input tensor type, or the divided tensor type for distributed inputs.
        var checkedType = input.CheckedType;
        TensorType bufferType;
        if (checkedType is TensorType tensorType)
        {
            bufferType = tensorType;
        }
        else if (checkedType is DistributedType distributedType)
        {
            bufferType = Utilities.DistributedUtility.GetDividedTensorType(distributedType);
        }
        else
        {
            throw new ArgumentOutOfRangeException(nameof(input));
        }

        // Input and output are both accessed through the identity map of the input rank.
        return IR.F.Affine.Grid(CPUTarget.Kind)
            .Read(input, AffineMap.Identity(rank), out var inTile)
            .Write(TIR.T.CreateBuffer(bufferType, TIR.MemoryLocation.Data, out _, $"swish_{_count++}"), AffineMap.Identity(rank), out var outTile)
            .Body(TIR.F.CPU.Swish(inTile, outTile, beta))
            .Build();
    }
}

View File

@ -13,7 +13,6 @@ using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.IR.F.CPU;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
namespace Nncase.Passes.Rules.CPU.Affine;
@ -21,18 +20,27 @@ namespace Nncase.Passes.Rules.CPU.Affine;
[RuleGenerator]
public partial class LowerUnary : RewriteRule<Pattern>
{
private int _count;
/// <inheritdoc/>
public override Pattern Pattern { get; } = IsUnary(
target_name: "unary",
public override Pattern Pattern { get; } = PatternMatch.F.Math.IsUnary(
"unary",
"call",
_ => true,
IsWildcard("input") with { TypePattern = HasFixedShape() });
IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") });
private Expr GetReplace(Unary unary, Expr input)
{
var bufferType = input.CheckedType switch
{
TensorType t => t,
DistributedType dt => Utilities.DistributedUtility.GetDividedTensorType(dt),
_ => throw new ArgumentOutOfRangeException(nameof(input)),
};
var rank = input.CheckedShape.Rank;
return IR.F.Affine.Grid(CPUTarget.Kind)
.Read(input, AffineMap.Identity(rank), out var inTile)
.Write(TIR.T.CreateBuffer(input.CheckedTensorType, TIR.MemoryLocation.Data, out _), AffineMap.Identity(rank), out var outTile)
.Write(TIR.T.CreateBuffer(bufferType, TIR.MemoryLocation.Data, out _, $"unary_{_count++}"), AffineMap.Identity(rank), out var outTile)
.Body(TIR.F.CPU.Unary(unary.UnaryOp, inTile, outTile))
.Build();
}

View File

@ -1,66 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System.Reactive;
using System.Runtime.CompilerServices;
using NetFabric.Hyperlinq;
using Nncase.CodeGen;
using Nncase.IR;
using Nncase.IR.CPU;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.PatternMatch.Utility;
[assembly: InternalsVisibleTo("Nncase.Tests")]
namespace Nncase.Passes.Rules;
/// <summary>
/// Applies the CPU packing rewrite rules to the body of a CPU fusion call, once per fusion.
/// </summary>
[RuleGenerator]
public sealed partial class AutoPacking : IRewriteRule
{
    public IPattern Pattern { get; } = IsCallWildcard("call", IsFusion("fusion", CPUTarget.Kind, IsWildcard("body"), IsVArgsRepeat("parameters", () => IsVar())));

    /// <summary>
    /// Rewrites the fusion body with the packing rule set and tags the new fusion as packed.
    /// </summary>
    /// <returns>The call to the packed fusion, or <c>null</c> when the fusion is already tagged.</returns>
    private Expr? GetReplace(Call call, Fusion fusion, Expr body, IReadOnlyList<Expr> parameters, IReadOnlyList<Expr> callParams)
    {
        // A fusion carrying the marker metadata has already been through this rule.
        if (fusion.Metadata is PackMetaData)
        {
            return null;
        }

        var packRank = 1;

        // Lane count is 8 when 256-bit vectors are hardware accelerated, otherwise 4.
        var packLane = System.Runtime.Intrinsics.Vector256.IsHardwareAccelerated ? 8 : 4;

        var packingRules = new IRewriteRule[]
        {
            new Passes.Rules.CPU.PackSoftmax() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackSwish() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackLayerNorm() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackResizeImage() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackMatMul() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackConv2D() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackUnary() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackBinary() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackTranspose() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackUnsqueeze() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackReshape() { Rank = packRank, Lane = packLane },
            new Passes.Rules.CPU.PackSlice() { Rank = packRank, Lane = packLane },
            new Passes.Rules.Neutral.FoldConstCall(),
            new Passes.Rules.CPU.FoldPackUnpack(),
            new Passes.Rules.CPU.FoldPackConcatUnpack(),
        };

        var packedBody = CompilerServices.ERewrite(body, packingRules, new());
        var packedFusion = fusion.With(moduleKind: CPUTarget.Kind, body: packedBody, parameters: parameters.Cast<Var>().ToArray());
        packedFusion.Metadata = new PackMetaData();
        return new Call(packedFusion, callParams.ToArray());
    }

    /// <summary>
    /// Marker metadata attached to fusions that this rule has already rewritten.
    /// </summary>
    private sealed class PackMetaData : IR.IRMetadata
    {
    }
}

View File

@ -0,0 +1,687 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
using NetFabric.Hyperlinq;
using Nncase.CostModel;
using Nncase.IR;
using Nncase.IR.CPU;
using Nncase.IR.Math;
using Nncase.IR.NN;
using Nncase.IR.Tensors;
using Nncase.Passes.Analysis;
using Nncase.Passes.Rules.Neutral;
using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.CPU;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.F.Tensors;
using static Nncase.PatternMatch.Utility;
using static Nncase.Utilities.ReplaceUtility;
namespace Nncase.Passes.Rules.CPU;
/// <summary>
/// Wraps a boxing-to-tensor call over a CPU-supported operator call into a CPU fusion.
/// </summary>
[RuleGenerator]
internal sealed partial class CPUOutputBoxingFusion : FusionMaker
{
    public override string ModuleKind { get; } = CPUTarget.Kind;

    public override Pattern Pattern { get; } = IsBoxing(
        target_name: "boxing",
        op => op.NewType is TensorType,
        IsCallWildcard("call", IsOp<Op>("op", PassUtility.IsCpuSupported))) with
    { TypePattern = HasFixedShape() };

    /// <summary>
    /// Builds the call to a new fusion whose body is the boxing applied to the operator call.
    /// </summary>
    /// <returns>The fusion call, or <c>null</c> when the operator call fails the argument-level support check.</returns>
    private Call? GetReplace(Call call, Op op, Boxing boxing, IReadOnlyList<Expr> callParams)
    {
        if (!PassUtility.IsCpuSupported(op, callParams))
        {
            return null;
        }

        var innerArgs = new List<Expr>(callParams.Count);
        var fusionParams = new List<Var>();
        var outerArgs = new List<Expr>();
        foreach (var original in callParams)
        {
            switch (original)
            {
                case Call or Var:
                    // Calls and variables become fusion parameters; the original expression is passed from outside.
                    var parameter = new Var(original.CheckedType!);
                    innerArgs.Add(parameter);
                    fusionParams.Add(parameter);
                    outerArgs.Add(original);
                    break;
                case TensorConst { Value: Tensor { Shape.IsScalar: true } } tc:
                    // Scalar constants are rebuilt from the same bytes with shape [1].
                    innerArgs.Add(Const.FromTensor(Tensor.FromBytes(tc.CheckedDataType, tc.Value.BytesBuffer.ToArray(), new[] { 1 })));
                    break;
                default:
                    innerArgs.Add(original);
                    break;
            }
        }

        var innerCall = new Call(op, innerArgs.ToArray());
        var boxedCall = new Call(boxing, innerCall);
        var fusion = new Fusion($"{op.GetType().Name}_{Count++}_kernel", ModuleKind, boxedCall, fusionParams.ToArray());
        return new Call(fusion, outerArgs.ToArray());
    }
}
/// <summary>
/// Wraps a single fixed-shape, CPU-supported op call into its own CPU-module fusion.
/// </summary>
[RuleGenerator]
internal sealed partial class CPUSingleFusion : FusionMaker
{
    public override string ModuleKind { get; } = CPUTarget.Kind;

    public override Pattern Pattern { get; } = IsCallWildcard(
        "call",
        IsOp<Op>("op", PassUtility.IsCpuSupported)) with
    { TypePattern = HasFixedShape() };

    /// <summary>
    /// Builds the fusion call replacing the matched call, or returns null when the op is not
    /// CPU-supported for these arguments or when a Concat has a non-Var tuple field.
    /// </summary>
    private Call? GetReplace(Call call, Op op, IReadOnlyList<Expr> callParams)
    {
        if (!PassUtility.IsCpuSupported(op, callParams))
        {
            return null;
        }

        var innerArgs = new List<Expr>();
        var fusionVars = new List<Var>();

        if (op is Concat concat)
        {
            // Concat takes its operands as one tuple; only fuse when every field is a Var.
            var fields = ((IR.Tuple)call.Arguments[0]).Fields.ToArray();
            if (fields.Any(field => field is not Var))
            {
                return null;
            }

            foreach (var field in fields)
            {
                var placeholder = new Var(field.CheckedType!);
                innerArgs.Add(placeholder);
                fusionVars.Add(placeholder);
            }

            var concatBody = new Call(new IR.Tensors.Concat(concat.Axis), new IR.Tuple(innerArgs.ToArray()));
            var concatFusion = new Fusion($"{op.GetType().Name}_{Count++}_kernel", ModuleKind, concatBody, fusionVars.ToArray());
            return new Call(concatFusion, fields.ToArray());
        }

        // Original expressions fed to the fusion call, aligned with fusionVars.
        var outerArgs = new List<Expr>();
        foreach (var param in callParams)
        {
            switch (param)
            {
                case Call or Var:
                    {
                        var placeholder = new Var(param.CheckedType!);
                        innerArgs.Add(placeholder);
                        fusionVars.Add(placeholder);
                        outerArgs.Add(param);
                        break;
                    }

                case TensorConst { Value: Tensor { Shape.IsScalar: true } } scalar:
                    // Scalar constants are re-created with shape [1] before being captured in the body.
                    innerArgs.Add(Const.FromTensor(Tensor.FromBytes(scalar.CheckedDataType, scalar.Value.BytesBuffer.ToArray(), new[] { 1 })));
                    break;

                default:
                    innerArgs.Add(param);
                    break;
            }
        }

        var body = new Call(op, innerArgs.ToArray());
        var fusion = new Fusion($"{op.GetType().Name}_{Count++}_kernel", ModuleKind, body, fusionVars.ToArray());
        return new Call(fusion, outerArgs.ToArray());
    }
}
/// <summary>
/// Base-function cost evaluator that prices a <see cref="Fusion"/> as the sum of the costs
/// recorded for every expression visited inside it. Any other function kind is rejected.
/// </summary>
internal sealed class FusionCostEvaluator : Evaluator.IBaseFuncCostEvaluator
{
    private readonly CompileOptions _compileOptions;

    public FusionCostEvaluator(CompileOptions compileOptions)
    {
        _compileOptions = compileOptions;
    }

    /// <summary>
    /// Returns the accumulated cost of <paramref name="target"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The target is not a fusion.</exception>
    public Cost VisitLeaf(IR.BaseFunction target)
    {
        if (target is not Fusion fusion)
        {
            throw new NotSupportedException();
        }

        var visitor = new FusionGraphCostVisitor(_compileOptions);
        visitor.Visit(fusion);

        // Sum every memoized per-expression cost, starting from Cost.Zero.
        var total = Cost.Zero;
        foreach (var itemCost in visitor.ExprMemo.Values)
        {
            total = total + itemCost;
        }

        return total;
    }

    /// <summary>
    /// Cost-evaluation context backed by the checked types and argument expressions of one call.
    /// </summary>
    private sealed class GraphOpCostEvaluateContext : Evaluator.ICostEvaluateContext
    {
        private readonly IRType? _returnType;
        private readonly IRType?[] _argumentTypes;
        private readonly Expr[] _arguments;

        public GraphOpCostEvaluateContext(IRType? returnType, IRType?[] argumentTypes, ReadOnlySpan<Expr> arguments, CompileOptions compileOptions)
        {
            _returnType = returnType;
            _argumentTypes = argumentTypes;
            CompileOptions = compileOptions;
            _arguments = arguments.ToArray();
        }

        public CompileOptions CompileOptions { get; }

        public T GetArgument<T>(Op op, ParameterInfo parameter)
            where T : IR.BaseFunction
            => (T)_arguments[parameter.Index];

        public T GetArgumentType<T>(Op op, ParameterInfo parameter)
            where T : IRType
        {
            // The parameter must be declared by exactly this operator type.
            if (op.GetType() != parameter.OwnerType)
            {
                throw new ArgumentOutOfRangeException($"Operator {op} doesn't have parameter: {parameter.Name}.");
            }

            return (T?)_argumentTypes[parameter.Index] ?? throw new InvalidOperationException("Run type infer first.");
        }

        public T GetReturnType<T>()
            where T : IRType
            => (T?)_returnType ?? throw new InvalidOperationException("Run type infer first.");
    }

    /// <summary>
    /// Visitor assigning a cost to each expression of a fusion: vars cost a memory load,
    /// op calls use the registered op cost, the fusion itself costs a memory store of its
    /// body plus its parameters, and every other leaf is free.
    /// </summary>
    private sealed class FusionGraphCostVisitor : ExprVisitor<Cost, IRType>
    {
        public FusionGraphCostVisitor(CompileOptions compileOptions)
        {
            CompileOptions = compileOptions;
        }

        public CompileOptions CompileOptions { get; }

        protected override Cost VisitLeafVar(Var var) => new Cost()
        {
            [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(var.CheckedType!),
        };

        protected override Cost DefaultVisitLeaf(Expr expr) => Cost.Zero;

        protected override Cost VisitLeafCall(Call call)
        {
            // Only calls whose target is an Op can be priced.
            if (call.Target is not Op op)
            {
                throw new NotSupportedException();
            }

            var argumentTypes = call.Arguments.AsValueEnumerable().Select(p => p.CheckedType).ToArray();
            var context = new GraphOpCostEvaluateContext(call.CheckedType, argumentTypes, call.Arguments, CompileOptions);
            return CompilerServices.EvaluateOpCost(op, context) ?? Cost.Zero;
        }

        protected override Cost VisitLeafFusion(Fusion fusion)
        {
            var storeCost = new Cost()
            {
                [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(fusion.Body.CheckedType!),
            };
            var parameterCost = fusion.Parameters.AsValueEnumerable().Select(Visit).Sum() ?? Cost.Zero;
            return storeCost + parameterCost;
        }
    }
}
/// <summary>
/// Expression cloner that substitutes mapped variables: whenever a visited <see cref="Var"/>
/// has an entry in the supplied map, the mapped expression is visited in its place.
/// </summary>
internal sealed class FusionMerger : ExprCloner<Unit>
{
    private readonly Dictionary<Var, Expr> _varMap;

    public FusionMerger(Dictionary<Var, Expr> varMap)
    {
        _varMap = varMap;
    }

    /// <summary>
    /// Returns the visited replacement for <paramref name="v"/> when one is mapped,
    /// otherwise <paramref name="v"/> itself.
    /// </summary>
    protected override Expr VisitLeafVar(Var v, Unit context)
        => _varMap.TryGetValue(v, out var replacement) ? Visit(replacement, context) : v;
}
// Rewrite rule that merges a call to a fusion (the "caller") with the same-module fusion
// calls (the "callees") that feed its arguments, either directly or through a GetItem,
// producing one call to a newly built merged fusion.
internal sealed class GeneralFusionMergeRule : IRewriteRule
{
// Merged calls already produced, keyed by a combined hash of the reference hash codes of
// the caller fusion and the selected callee fusions.
// NOTE(review): the key is only a hash code, so two different fusion sets whose hashes
// collide would share one cached call — confirm this is acceptable.
private readonly Dictionary<int, Call> _mergedCache = new();
// Running counter used to name merged fusions ("mfusion_{n}_kernel").
private int _count;
// Matches any call whose target is a fusion. Each fusion parameter is bound as
// "input_{i}" and each call argument as "callee_{i}".
public IPattern Pattern { get; } =
IsCall(
"caller",
IsFusion("caller_fusion", _ => true, IsWildcard(), IsVArgsRepeat("inputs", exprs =>
{
var patterns = new Pattern[exprs.Length];
for (var i = 0; i < patterns.Length; i++)
{
patterns[i] = IsWildcard($"input_{i}");
}
return patterns;
})),
IsVArgsRepeat("callerInputs", exprs =>
{
var patterns = new Pattern[exprs.Length];
for (var i = 0; i < patterns.Length; i++)
{
patterns[i] = IsWildcard($"callee_{i}");
}
return patterns;
}));
// Returns the merged fusion call, or null when no argument of the caller is a mergeable
// fusion call.
public Expr? GetReplace(IMatchResult result, RunPassContext options)
{
var caller = (Call)result["caller"];
var caller_fusion = (Fusion)result["caller_fusion"];
var callerInputs = (IReadOnlyList<Expr>)result["callerInputs"];
// Parallel lists: callees[k] is a fusion call, callee_fusions[k] its target, and
// fusion_index[k] the caller argument position it was found at.
var callees = new List<Call>();
var callee_fusions = new List<Fusion>();
var fusion_index = new List<int>();
for (var i = 0; i < callerInputs.Count; i++)
{
// Case 1: the argument is itself a call to a fusion of the caller's module kind.
if (result[$"callee_{i}"] is Call { Target: Fusion })
{
var callee = (Call)result[$"callee_{i}"];
var callee_fusion = callee.Target as Fusion;
if (callee_fusion!.ModuleKind == caller_fusion.ModuleKind)
{
callees.Add(callee);
callee_fusions.Add(callee_fusion);
fusion_index.Add(i);
}
}
// Case 2: the argument is a GetItem whose first argument is such a fusion call.
else if (result[$"callee_{i}"] is Call { Target: GetItem })
{
var expr = ((Call)result[$"callee_{i}"]).Arguments[0];
if (expr is Call { Target: Fusion } callee && ((Fusion)callee.Target)!.ModuleKind == caller_fusion.ModuleKind)
{
callees.Add(callee);
callee_fusions.Add((Fusion)callee.Target);
fusion_index.Add(i);
}
}
}
// Drop every callee that is also an argument of another collected callee.
// Iterating backwards keeps the remaining indices valid while removing.
for (var i = callees.Count - 1; i >= 0; i--)
{
var callee = callees[i];
var callee_fusion = callee_fusions[i];
if (callees.Except(new[] { callee }).Any(c => c.Arguments.ToArray().Any(a => a == callee)))
{
callees.RemoveAt(i);
callee_fusions.RemoveAt(i);
fusion_index.RemoveAt(i);
}
}
if (callees.Count == 0)
{
return null;
}
// Build the cache key from the reference hash codes of the caller fusion followed by
// each remaining callee fusion.
var hashCodes = new List<int>
{
ReferenceEqualityComparer.Instance.GetHashCode(caller_fusion),
};
foreach (var fusion in callee_fusions)
{
hashCodes.Add(ReferenceEqualityComparer.Instance.GetHashCode(fusion));
}
var hash = default(HashCode);
foreach (var subHash in hashCodes)
{
hash.Add(subHash);
}
var hashcode = hash.ToHashCode();
if (!_mergedCache.TryGetValue(hashcode, out var new_call))
{
// Map each merged caller-fusion parameter to the expression that replaces it: the
// callee fusion body, or GetItem(body, index) when the caller argument was a GetItem.
var multiVarMap = new Dictionary<Var, Expr>(ReferenceEqualityComparer.Instance);
for (var index = 0; index < fusion_index.Count; index++)
{
var callee = (Call)caller.Arguments[fusion_index[index]];
if (callee is Call { Target: Fusion })
{
multiVarMap.Add(caller_fusion.Parameters[fusion_index[index]], callee_fusions[index].Body);
}
else
{
var newCallee = IR.F.Tensors.GetItem(callee_fusions[index].Body, callee.Arguments[1]);
multiVarMap.Add(caller_fusion.Parameters[fusion_index[index]], newCallee);
}
}
// Clone the caller body with the mapped parameters substituted.
var new_fusion_body = new FusionMerger(multiVarMap).Clone(caller_fusion.Body, default);
// remove duplicate callees
// remindIndex holds the caller argument positions that still contribute parameters and
// inputs; the position of each duplicate callee is removed from it. Walking backwards
// keeps the not-yet-visited positions of remindIndex unshifted.
var seen = new HashSet<Expr>();
var remindIndex = Enumerable.Range(0, callerInputs.Count).ToList();
for (var i = callees.Count - 1; i >= 0; i--)
{
if (!seen.Add(callees[i]))
{
callees.RemoveAt(i);
callee_fusions.RemoveAt(i);
remindIndex.RemoveAt(fusion_index[i]);
fusion_index.RemoveAt(i);
}
}
// For every remaining position: a merged position contributes the callee fusion's
// parameters / the callee call's arguments; any other position contributes the caller
// fusion's own parameter / the original caller argument.
var parameters = remindIndex.Select(i => fusion_index.Contains(i) ? callee_fusions[fusion_index.IndexOf(i)].Parameters.ToArray() : new[] { caller_fusion.Parameters[i] }).SelectMany(e => e).ToArray();
var calleeInputs = remindIndex.Select(i => fusion_index.Contains(i) ? callees[fusion_index.IndexOf(i)].Arguments.ToArray() : new[] { callerInputs[i] }).SelectMany(a => a).ToList();
// De-duplicate the parameters and remove the inputs at the positions of the dropped
// duplicates. Indices are removed in descending order so earlier ones stay valid.
var indexedParameters = parameters.Select((value, index) => new { value, index }).ToList();
parameters = parameters.Distinct().ToArray();
var distinctedList = parameters.Select(x => new { value = x, index = indexedParameters.First(i => i.value == x).index }).ToList();
var removedIndexes = indexedParameters.Where(x => !distinctedList.Any(d => d.index == x.index)).Select(x => x.index).ToList();
removedIndexes.Sort((a, b) => b.CompareTo(a));
foreach (var index in removedIndexes)
{
calleeInputs.RemoveAt(index);
}
// Run an e-graph rewrite with an empty rule set over the merged body, inside a
// null-dumper scope. NOTE(review): presumably this normalizes/de-duplicates the cloned
// body without emitting dumps — confirm.
using (new Diagnostics.DumpScope(new Diagnostics.NullDumpper()))
{
new_fusion_body = CompilerServices.ERewrite(new_fusion_body, Array.Empty<IRewriteRule>(), new(), new());
}
var merged_fusion = new Fusion($"mfusion_{_count++}_kernel", caller_fusion.ModuleKind, new_fusion_body, parameters);
new_call = new Call(merged_fusion, calleeInputs.ToArray());
_mergedCache.Add(hashcode, new_call);
}
else
{
// Cache hit: the previously merged call is returned unchanged.
// System.Console.WriteLine("Re Add Merged Two Fusion Call");
}
return new_call;
}
}
/// <summary>
/// Merges a tuple whose fields are all calls to fusions of one module kind into a single
/// call to a new fusion whose body is the tuple of the original fusion bodies.
/// </summary>
internal sealed class TupleFusionMergeRule : IRewriteRule
{
    // Merged calls already produced, keyed by a combined hash of the reference hash codes
    // of the tuple and of every callee fusion.
    private readonly Dictionary<int, Call> _mergedCache = new();

    /// <summary>
    /// Gets the pattern: a tuple whose every field is a call (bound as "call_{i}").
    /// </summary>
    public IPattern Pattern { get; } =
        IsTuple(
            "tuple",
            IsVArgsRepeat("tupleInputs", exprs =>
            {
                var patterns = new Pattern[exprs.Length];
                for (var i = 0; i < patterns.Length; i++)
                {
                    patterns[i] = IsCallWildcard($"call_{i}", IsWildcard());
                }

                return patterns;
            }));

    /// <summary>
    /// Returns the merged fusion call, or null when the tuple is empty, when any field is
    /// not a fusion call, or when the fusions belong to different module kinds.
    /// </summary>
    public Expr? GetReplace(IMatchResult result, RunPassContext options)
    {
        var tuple = (IR.Tuple)result["tuple"];
        var tupleInputs = (IReadOnlyList<Expr>)result["tupleInputs"];

        var callees = new List<Call>();
        var callee_fusions = new List<Fusion>();
        for (var i = 0; i < tupleInputs.Count; i++)
        {
            if (result[$"call_{i}"] is not Call { Target: Fusion } callee)
            {
                return null;
            }

            callees.Add(callee);
            callee_fusions.Add((Fusion)callee.Target);
        }

        // An empty tuple has nothing to merge. Without this guard the module-kind lookup
        // on callee_fusions[0] below would throw ArgumentOutOfRangeException.
        if (callee_fusions.Count == 0)
        {
            return null;
        }

        if (callee_fusions.Select(f => f.ModuleKind).Distinct().Count() > 1)
        {
            return null;
        }

        // Cache key: reference hash of the tuple followed by each callee fusion, in order.
        var hash = default(HashCode);
        hash.Add(ReferenceEqualityComparer.Instance.GetHashCode(tuple));
        foreach (var fusion in callee_fusions)
        {
            hash.Add(ReferenceEqualityComparer.Instance.GetHashCode(fusion));
        }

        var hashcode = hash.ToHashCode();
        if (_mergedCache.TryGetValue(hashcode, out var new_call))
        {
            return new_call;
        }

        // NOTE(review): parameters and inputs are concatenated per field without
        // de-duplication, so a fusion called by several fields contributes its parameters
        // once per call — confirm this is intended.
        var new_fusion_body = new IR.Tuple(callee_fusions.Select(f => f.Body).ToArray());
        var name = "tuple_" + string.Join("_", callee_fusions.Select(f => f.Name).ToArray());
        var parameters = callee_fusions.Select(f => f.Parameters.ToArray()).SelectMany(e => e).ToArray();
        var merged_fusion = new Fusion(name, callee_fusions[0].ModuleKind, new_fusion_body, parameters);
        var calleeInputs = callees.Select(c => c.Arguments.ToArray()).SelectMany(e => e).ToArray();
        new_call = new Call(merged_fusion, calleeInputs);
        _mergedCache.Add(hashcode, new_call);
        return new_call;
    }
}
/// <summary>
/// Merges a concat over a tuple whose fields are all calls to fusions of one module kind
/// into a single call to a new fusion whose body concatenates the original fusion bodies.
/// </summary>
internal sealed class ConcatFusionMergeRule : IRewriteRule
{
    // Merged calls already produced, keyed by a combined hash of the reference hash codes
    // of the tuple and of every callee fusion.
    private readonly Dictionary<int, Call> _mergedCache = new();

    /// <summary>
    /// Gets the pattern: a concat of a tuple whose every field is a call (bound as "callee_{i}").
    /// </summary>
    public IPattern Pattern { get; } =
        IsConcat(
            "concat",
            _ => true,
            IsTuple(
                "tuple",
                IsVArgsRepeat("tupleInputs", exprs =>
                {
                    var patterns = new Pattern[exprs.Length];
                    for (var i = 0; i < patterns.Length; i++)
                    {
                        patterns[i] = IsCallWildcard($"callee_{i}", IsWildcard());
                    }

                    return patterns;
                })));

    /// <summary>
    /// Returns the merged fusion call, or null when the tuple is empty, when any field is
    /// not a fusion call, or when the fusions belong to different module kinds.
    /// </summary>
    public Expr? GetReplace(IMatchResult result, RunPassContext options)
    {
        var concat = (IR.Tensors.Concat)result["concat"];
        var tuple = (IR.Tuple)result["tuple"];
        var tupleInputs = (IReadOnlyList<Expr>)result["tupleInputs"];

        var callees = new List<Call>();
        var callee_fusions = new List<Fusion>();
        for (var i = 0; i < tupleInputs.Count; i++)
        {
            if (result[$"callee_{i}"] is not Call { Target: Fusion } callee)
            {
                return null;
            }

            callees.Add(callee);
            callee_fusions.Add((Fusion)callee.Target);
        }

        // An empty tuple has nothing to merge. Without this guard the module-kind lookup
        // on callee_fusions[0] below would throw ArgumentOutOfRangeException.
        if (callee_fusions.Count == 0)
        {
            return null;
        }

        if (callee_fusions.Select(f => f.ModuleKind).Distinct().Count() > 1)
        {
            return null;
        }

        // Cache key: reference hash of the tuple followed by each callee fusion, in order.
        var hash = default(HashCode);
        hash.Add(ReferenceEqualityComparer.Instance.GetHashCode(tuple));
        foreach (var fusion in callee_fusions)
        {
            hash.Add(ReferenceEqualityComparer.Instance.GetHashCode(fusion));
        }

        var hashcode = hash.ToHashCode();
        if (_mergedCache.TryGetValue(hashcode, out var new_call))
        {
            return new_call;
        }

        // NOTE(review): parameters and inputs are concatenated per field without
        // de-duplication, so a fusion called by several fields contributes its parameters
        // once per call — confirm this is intended.
        var new_fusion_body = new Call(new Concat(concat.Axis), new IR.Tuple(callee_fusions.Select(f => f.Body).ToArray()));
        var name = "concat_" + string.Join("_", callee_fusions.Select(f => f.Name).ToArray());
        var parameters = callee_fusions.Select(f => f.Parameters.ToArray()).SelectMany(e => e).ToArray();
        var merged_fusion = new Fusion(name, callee_fusions[0].ModuleKind, new_fusion_body, parameters);
        var calleeInputs = callees.Select(c => c.Arguments.ToArray()).SelectMany(e => e).ToArray();
        new_call = new Call(merged_fusion, calleeInputs);
        _mergedCache.Add(hashcode, new_call);
        return new_call;
    }
}
/// <summary>
/// Merges a fusion call with those argument fusion calls that share its module kind and
/// have no user other than this caller, so the merge never duplicates work.
/// Requires the <see cref="IExprUserAnalysisResult"/> analysis on the pass context.
/// </summary>
internal sealed class DeterminedFusionMergeRule : IRewriteRule
{
    // Running counter used to name merged fusions ("determined_fusion_{n}_kernel").
    private int _count;

    /// <summary>
    /// Gets the pattern: a call to a fusion whose parameters are all vars ("input_{i}") and
    /// whose arguments are all calls ("callee_{i}").
    /// </summary>
    public IPattern Pattern { get; } =
        IsCall(
            "caller",
            IsFusion("caller_fusion", _ => true, IsWildcard(), IsVArgsRepeat("inputs", exprs =>
            {
                var patterns = new Pattern[exprs.Length];
                for (var i = 0; i < patterns.Length; i++)
                {
                    patterns[i] = IsVar($"input_{i}");
                }

                return patterns;
            })),
            IsVArgsRepeat("callerInputs", exprs =>
            {
                var patterns = new Pattern[exprs.Length];
                for (var i = 0; i < patterns.Length; i++)
                {
                    patterns[i] = IsCallWildcard($"callee_{i}", IsWildcard());
                }

                return patterns;
            }));

    /// <summary>
    /// Returns the merged fusion call, or null when no argument qualifies for merging.
    /// </summary>
    public Expr? GetReplace(IMatchResult result, RunPassContext options)
    {
        var userAnalysis = options.GetAnalysis<IExprUserAnalysisResult>();
        var caller = (Call)result["caller"];
        var caller_fusion = (Fusion)result["caller_fusion"];
        var callerInputs = (IReadOnlyList<Expr>)result["callerInputs"];

        // Parallel lists: callees[k] is a fusion call, callee_fusions[k] its target, and
        // fusion_index[k] the caller argument position it was found at.
        var callees = new List<Call>();
        var callee_fusions = new List<Fusion>();
        var fusion_index = new List<int>();
        for (var i = 0; i < callerInputs.Count; i++)
        {
            if (result[$"callee_{i}"] is Call { Target: Fusion } callee)
            {
                var callee_fusion = (Fusion)callee.Target;

                // Only merge same-module callees whose sole user is this caller.
                if (callee_fusion.ModuleKind == caller_fusion.ModuleKind && !userAnalysis[callee].Except(new[] { caller }).Any())
                {
                    callees.Add(callee);
                    callee_fusions.Add(callee_fusion);
                    fusion_index.Add(i);
                }
            }
        }

        if (callees.Count == 0)
        {
            return null;
        }

        // Substitute each merged caller-fusion parameter with the callee fusion body.
        var multiVarMap = new Dictionary<Var, Expr>(ReferenceEqualityComparer.Instance);
        for (var index = 0; index < fusion_index.Count; index++)
        {
            multiVarMap.Add(caller_fusion.Parameters[fusion_index[index]], callee_fusions[index].Body);
        }

        var new_fusion_body = new FusionMerger(multiVarMap).Clone(caller_fusion.Body, default);

        // Remove duplicate callees. remindIndex holds the caller argument positions that
        // still contribute parameters and inputs; walking backwards keeps the positions
        // not yet visited unshifted.
        var seen = new HashSet<Expr>();
        var remindIndex = Enumerable.Range(0, callerInputs.Count).ToList();
        for (var i = callees.Count - 1; i >= 0; i--)
        {
            if (!seen.Add(callees[i]))
            {
                callees.RemoveAt(i);
                callee_fusions.RemoveAt(i);
                remindIndex.RemoveAt(fusion_index[i]);
                fusion_index.RemoveAt(i);
            }
        }

        // A merged position contributes the callee fusion's parameters / the callee call's
        // arguments; any other position keeps the caller's own parameter / argument.
        var parameters = remindIndex.Select(i => fusion_index.Contains(i) ? callee_fusions[fusion_index.IndexOf(i)].Parameters.ToArray() : new[] { caller_fusion.Parameters[i] }).SelectMany(e => e).ToArray();
        var merged_fusion = new Fusion($"determined_fusion_{_count++}_kernel", caller_fusion.ModuleKind, new_fusion_body, parameters);
        var calleeInputs = remindIndex.Select(i => fusion_index.Contains(i) ? callees[fusion_index.IndexOf(i)].Arguments.ToArray() : new[] { callerInputs[i] }).SelectMany(a => a).ToArray();
        return new Call(merged_fusion, calleeInputs);
    }
}

View File

@ -1,34 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
using Nncase.IR.Math;
using Nncase.PatternMatch;
using static Nncase.IR.F.CPU;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
namespace Nncase.Passes.Rules.CPU;
/// <summary>
/// Rewrites a binary op whose two operands are float typed with fixed shapes into a
/// CPU kernel call on the same operands.
/// </summary>
[RuleGenerator]
public partial class LowerBinary : RewriteRule<Pattern>
{
    /// <inheritdoc/>
    public override Pattern Pattern { get; } = IsBinary(
        target_name: "binary",
        _ => true,
        IsWildcard("lhs") with { TypePattern = IsFloat() & HasFixedShape() },
        IsWildcard("rhs") with { TypePattern = IsFloat() & HasFixedShape() });

    // Wraps the matched binary op and its operands in a CPUKernel expression.
    private Expr? GetReplace(Binary binary, Expr lhs, Expr rhs)
        => CPUKernel(binary, lhs, rhs);
}

View File

@ -21,15 +21,26 @@ namespace Nncase.Passes.Rules.CPU;
public abstract class PackRule : RewriteRule<Pattern>
{
public int Lane { get; set; } = 32;
/// <summary>
/// Initializes a new instance of the <see cref="PackRule"/> class.
/// </summary>
/// <param name="rank">Value stored in <see cref="Rank"/>.</param>
/// <param name="lane">Value stored in <see cref="Lane"/>.</param>
public PackRule(int rank, int lane)
{
Rank = rank;
Lane = lane;
}
public int Rank { get; set; } = 2;
public int Lane { get; }
public int Rank { get; }
public override Expr? GetReplace(IMatchResult result, RunPassContext options) => throw new NotImplementedException();
}
public sealed class PackSoftmax : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackSoftmax"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackSoftmax(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsSoftmax(
"target",
IsWildcard("input") with { TypePattern = IsFloat() },
@ -71,6 +82,11 @@ public sealed class PackSoftmax : PackRule
public sealed class PackResizeImage : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackResizeImage"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackResizeImage(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsResizeImage("target", op => op.TransformationMode == ImageResizeTransformationMode.Asymmetric && op.IsTFResize == false, IsWildcard("input"), IsNone(), IsTensorConst("newSize"), IsTensorConst("cubicCoeffA"), IsTensorConst("excludeOutside"), IsTensorConst("extrapolationValue"));
public override List<Expr> GetReplaceCandidates(IMatchResult result, RunPassContext context)
@ -101,6 +117,11 @@ public sealed class PackResizeImage : PackRule
public sealed class PackInstanceNorm : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackInstanceNorm"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackInstanceNorm(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsInstanceNormalization(
"target",
_ => true,
@ -164,6 +185,11 @@ public sealed class PackInstanceNorm : PackRule
public sealed class PackLayerNorm : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackLayerNorm"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackLayerNorm(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsLayerNorm(
"target",
_ => true,
@ -225,6 +251,11 @@ public sealed class PackLayerNorm : PackRule
public sealed class PackMatMul : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackMatMul"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackMatMul(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsMatMul(
"target",
IsWildcard("lhs") with { TypePattern = IsFloat() },
@ -247,16 +278,37 @@ public sealed class PackMatMul : PackRule
var matmul = IR.F.CPU.PackedMatMul(packedLhs, packedRhs, lhsPackedAxes, lhsPadNums, rhsPackedAxes, rhsPadNums);
var lhsAlign = System.Math.Max(lhsShape.Length, rhsShape.Length) - lhsShape.Length;
var rhsAlign = System.Math.Max(lhsShape.Length, rhsShape.Length) - rhsShape.Length;
var post = matmul;
if (lhsPackedAxes.Length == 2 && rhsPackedAxes.Length == 2)
var mPackIndex = Array.IndexOf(lhsPackedAxes, lhsShape.Length - 2);
var nPackIndex = Array.IndexOf(rhsPackedAxes, rhsShape.Length - 1);
var unpackAxes = new List<int>();
var unpadNums = new List<int>();
if (mPackIndex != -1)
{
post = PackUtility.SliceForPack(IR.F.CPU.Unpack(matmul, new[] { lhsAlign + lhsPackedAxes[0], rhsAlign + rhsPackedAxes[1] }), candidate.CheckedShape.ToValueArray(), new[] { lhsPadNums[0], rhsPadNums[1] });
unpackAxes.Add(lhsAlign + lhsPackedAxes[mPackIndex]);
unpadNums.Add(lhsPadNums[mPackIndex]);
}
if (nPackIndex != -1)
{
unpackAxes.Add(rhsAlign + rhsPackedAxes[nPackIndex]);
unpadNums.Add(rhsPadNums[nPackIndex]);
}
Expr post = matmul;
if (unpackAxes.Any())
{
post = PackUtility.SliceForPack(IR.F.CPU.Unpack(matmul, unpackAxes.ToArray()), candidate.CheckedShape.ToValueArray(), unpadNums.ToArray());
}
rets.Add(post);
}
// pack A's k and B's k
AddCandidate(new[] { lhsShape.Length - 1 }, new[] { rhsShape.Length - 2 }, new[] { Lane }, new[] { Lane });
// only pack A's m
// AddCandidate(new[] { lhsShape.Length - 2 }, Array.Empty<int>(), new[] { Lane }, Array.Empty<int>());
if (Rank > 1)
{
AddCandidate(new[] { lhsShape.Length - 2, lhsShape.Length - 1 }, new[] { rhsShape.Length - 2, rhsShape.Length - 1 }, new[] { Lane, Lane }, new[] { Lane, Lane });
@ -268,6 +320,11 @@ public sealed class PackMatMul : PackRule
public sealed class PackUnary : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackUnary"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackUnary(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsUnary(
"target",
_ => true,
@ -309,6 +366,11 @@ public sealed class PackUnary : PackRule
public sealed class PackBinary : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackBinary"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackBinary(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsBinary(
"target",
_ => true,
@ -373,6 +435,11 @@ public sealed class PackBinary : PackRule
public sealed class PackSwish : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackSwish"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackSwish(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsSwish(
"target",
IsWildcard("input") with { TypePattern = IsFloat() },
@ -411,6 +478,11 @@ public sealed class PackSwish : PackRule
public sealed class PackTranspose : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackTranspose"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackTranspose(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsTranspose(
"trans",
IsWildcard("input") with { TypePattern = IsFloat() },
@ -462,6 +534,11 @@ public sealed class PackTranspose : PackRule
public sealed class PackUnsqueeze : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackUnsqueeze"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackUnsqueeze(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsUnsqueeze(
"unsq",
IsWildcard("input") with { TypePattern = IsFloat() },
@ -519,6 +596,11 @@ public sealed class PackUnsqueeze : PackRule
public sealed class PackConv2D : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackConv2D"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackConv2D(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsConv2D(
"conv",
conv => conv.PadMode == PadMode.Constant,
@ -591,6 +673,11 @@ public sealed class PackConv2D : PackRule
public sealed class PackReshape : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackReshape"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackReshape(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsReshape(
"target",
IsWildcard("input") with { TypePattern = IsFloat() },
@ -687,6 +774,11 @@ public sealed class PackReshape : PackRule
public sealed class PackSlice : PackRule
{
/// <summary>
/// Initializes a new instance of the <see cref="PackSlice"/> class, forwarding
/// <paramref name="rank"/> and <paramref name="lane"/> to <see cref="PackRule"/>.
/// </summary>
public PackSlice(int rank, int lane)
: base(rank, lane)
{
}
public override Pattern Pattern { get; } = IsSlice(
"target",
IsWildcard("input") with { TypePattern = IsFloat() },

View File

@ -1,90 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System.Runtime.CompilerServices;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.Passes.Mutators;
using Nncase.Targets;
[assembly: InternalsVisibleTo("Nncase.Tests.CPU")]
namespace Nncase.Passes.Tile;
// SameInputFusionMergeRule specialized for the CPU module: reports CPUTarget.Kind as its module kind.
internal sealed class CPUSameInputFusionMergeRule : SameInputFusionMergeRule
{
public override string ModuleKind => CPUTarget.Kind;
}
// MultiInputFusionMergeRule specialized for the CPU module: reports CPUTarget.Kind as its module kind.
internal sealed class CPUMultiInputFusionMergeRule : MultiInputFusionMergeRule
{
public override string ModuleKind => CPUTarget.Kind;
}
// ShortCutFusionMergeRuleLeft specialized for the CPU module: reports CPUTarget.Kind as its module kind.
internal sealed class CPUShortCutFusionMergeRuleLeft : ShortCutFusionMergeRuleLeft
{
public override string ModuleKind => CPUTarget.Kind;
}
// ShortCutFusionMergeRuleRight specialized for the CPU module: reports CPUTarget.Kind as its module kind.
internal sealed class CPUShortCutFusionMergeRuleRight : ShortCutFusionMergeRuleRight
{
public override string ModuleKind => CPUTarget.Kind;
}
/// <summary>
/// Fusion group mutator for the CPU module. A merged fusion is accepted only when the
/// <see cref="FusionChecker"/> built from its tile list reports at least one result, and
/// once one merge has been accepted this instance stops rewriting further calls.
/// </summary>
internal sealed class CPUFusionGroupMutator : FusionGroupMutator
{
    // Checker per live fusion; shared with the caller that supplied the dictionary.
    private readonly Dictionary<Fusion, FusionChecker> _fusioncheckerCache;

    // Becomes true after the first accepted merge.
    private bool _checked;

    public CPUFusionGroupMutator(
        Dictionary<Fusion, FusionChecker> fusioncheckerCache,
        IMergeRewriteRule rule,
        RunPassContext passOptions)
        : base(rule, passOptions)
    {
        _fusioncheckerCache = fusioncheckerCache;
        _checked = false;
    }

    /// <inheritdoc/>
    public override bool MergedFusionCheckCallBack(Fusion mergedFusion, HashSet<Fusion> candidateFusions)
    {
        // Only one merge is ever accepted per mutator instance.
        if (_checked)
        {
            return false;
        }

        var tileVisitor = new PrimTileVisitor();
        tileVisitor.Visit(mergedFusion.Body);
        var checker = new FusionChecker(tileVisitor.TileList);
        var checkResult = checker.Check(mergedFusion.Body);
        if (!(checkResult.Count > 0))
        {
            return false;
        }

        _checked = true;
        _fusioncheckerCache.Add(mergedFusion, checker);

        // The candidates are now part of the merged fusion; release their checkers.
        foreach (var candidate in candidateFusions)
        {
            _fusioncheckerCache.Remove(candidate);
        }

        return true;
    }

    /// <summary>
    /// Runs an e-graph rewrite with the FoldStoreLoad rule over the merged body inside a
    /// "MergedFusionClear" dump scope.
    /// </summary>
    public override Expr MergedFusionRewriteCallBack(Expr mergedFusionBody)
    {
        using (new DumpScope("MergedFusionClear"))
        {
            var rules = new[] { new Passes.Rules.CPU.FoldStoreLoad() };
            return CompilerServices.ERewrite(mergedFusionBody, rules, new());
        }
    }

    /// <summary>
    /// Leaves calls untouched after a merge has been accepted; otherwise defers to the base rewrite.
    /// </summary>
    protected override Expr RewriteLeafCall(Call expr)
    {
        if (_checked)
        {
            return expr;
        }

        return base.RewriteLeafCall(expr);
    }
}

View File

@ -60,9 +60,13 @@ internal sealed class KernelToTIRVisitor : ExprVisitor<Unit, Unit>
{
var arguments = expr.Arguments.AsValueEnumerable().Select(GetBuffer).ToArray();
var ret = GetBuffer(expr);
var op = expr.Target is IR.CPU.CPUKernelOp kop ? kop.Target : expr.Target;
var op = expr.Target;
switch (op)
{
case PrimFunctionWrapper { Target: TIR.PrimFunction { ModuleKind: string mkind } deviceFunc } when mkind == Targets.CPUTarget.Kind:
_devices.Add(deviceFunc);
_mainBody.Add(new Call(deviceFunc, arguments.Concat(new[] { ret }).ToArray()));
break;
case Fusion deviceFunc:
{
var r = new DeviceFusionToPrimFuncRewriter(_fusionCheckCache);
@ -176,7 +180,7 @@ internal sealed class KernelToTIRVisitor : ExprVisitor<Unit, Unit>
_mainBody.Add(TIR.F.CPU.Gather(arguments[0], arguments[1], ret, gather.Axis));
break;
case IR.NN.Pad pad:
_mainBody.Add(TIR.F.CPU.Pad(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToArray<int>(), ((TensorConst)expr.Arguments[2]).Value.ToScalar<float>()));
_mainBody.Add(TIR.F.CPU.Pad(arguments[0], ret, ((TensorConst)expr.Arguments[1]).Value.ToArray<int>(), ((TensorConst)expr.Arguments[2]).Value.ToArray<float>()[0]));
break;
default:
throw new NotSupportedException();
@ -271,9 +275,6 @@ internal sealed class KernelToTIRVisitor : ExprVisitor<Unit, Unit>
// Appends a TIR CPU Binary call to _mainBody, passing binary.BinaryOp, the first two
// argument buffers and the result buffer.
private void GenerateBinary(Binary binary, Buffer[] arguments, Buffer ret, Call expr)
{
// Discarded casts: their results are unused, so they only act as runtime checks that
// the operand and result checked types are DistributedType.
_ = (DistributedType)expr.Arguments[0].CheckedType;
_ = (DistributedType)expr.Arguments[1].CheckedType;
_ = (DistributedType)expr.CheckedType;
_mainBody.Add(TIR.F.CPU.Binary(binary.BinaryOp, arguments[0], arguments[1], ret));
}

View File

@ -117,7 +117,7 @@ public partial class CPU
return new Call(new Reshape(newShape), input, ret);
}
public static Expr Swish(Buffer buffer, Buffer ret, float v)
public static Expr Swish(Expr buffer, Expr ret, float v)
{
return new Call(new Swish(v), buffer, ret);
}

View File

@ -32,4 +32,6 @@ public sealed partial class Swish : CPUKernelOp
/// Gets beta.
/// </summary>
public float Beta { get; }
/// <summary>
/// Returns the display text for this op's properties in the form "Beta: {Beta}".
/// </summary>
public override string DisplayProperty() => $"Beta: {Beta}";
}

View File

@ -9,6 +9,25 @@ using System.Threading.Tasks;
namespace Nncase.Targets;
/// <summary>
/// Memory architecture kinds, stored as a byte.
/// </summary>
/// <remarks>
/// NOTE(review): UMA/NUMA conventionally expand to "Uniform"/"Non-Uniform" Memory Access;
/// confirm whether "Unified" is the intended wording here.
/// </remarks>
public enum MemoryArch : byte
{
/// <summary>
/// Unified Memory Access.
/// </summary>
UMA = 0,
/// <summary>
/// Non-Unified Memory Access.
/// </summary>
NUMA = 1,
}
/// <summary>
/// NoC architecture kinds, stored as a byte.
/// NOTE(review): "Noc" presumably means network-on-chip and the members its topology — confirm.
/// </summary>
public enum NocArch : byte
{
/// <summary>
/// Mesh.
/// </summary>
Mesh = 0,
/// <summary>
/// Crossbar.
/// </summary>
CrossBar = 1,
}
public sealed class CpuTargetOptions : ITargetOptions
{
public string ModelName { get; set; } = string.Empty;
@ -17,6 +36,15 @@ public sealed class CpuTargetOptions : ITargetOptions
public int[] TargetTileSize { get; set; } = Array.Empty<int>();
/// <summary>
/// Gets or sets a value indicating whether the target uses a Unified Memory Architecture. See https://en.wikipedia.org/wiki/Unified_Memory_Access.
/// </summary>
public bool UnifiedMemoryArchitecture { get; set; } = true;
public MemoryArch MemoryArch { get; set; } = MemoryArch.UMA;
public NocArch NocArch { get; set; } = NocArch.Mesh;
public int[] Hierarchy { get; set; } = new[] { 1 };
public string HierarchyNames { get; set; } = "b";

View File

@ -41,7 +41,7 @@ public class CPUTarget : ITarget
ITargetOptions ParseTargetCompileOptions(InvocationContext context, Command command)
{
var packing = context.ParseResult.GetValueForOption<bool>(packingOption);
var packing = context.ParseResult.GetValueForOption(packingOption);
return new CpuTargetOptions() { Packing = packing };
}
@ -61,25 +61,6 @@ public class CPUTarget : ITarget
/// <inheritdoc/>
public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options)
{
passManager.AddWithName<DataflowPass>("MakeFusion").Configure(p =>
{
p.Add<Passes.Rules.CombineMHA>();
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.FuseMHA2>();
p.Add<Passes.Rules.FuseVAEDecRes>();
});
#if false
passManager.AddWithName<DataflowPass>("CPUDeviceFusion").Configure(p =>
{
p.Add<Passes.Rules.CPU.Affine.LowerUnary>();
});
#endif
passManager.AddWithName<DataflowPass>("CPUKernelFusion").Configure(p =>
{
p.Add<Passes.Rules.CPUSingleKernelFusion>();
});
}
/// <inheritdoc/>
@ -113,35 +94,67 @@ public class CPUTarget : ITarget
if (options.TargetCompileOptions is CpuTargetOptions { Packing: true })
{
passManager.AddWithName<DataflowPass>("AutoPacking").Configure(p =>
passManager.AddWithName<EGraphRulesPass>("AutoPacking").Configure(p =>
{
p.Add<Passes.Rules.AutoPacking>();
// todo config it in the target options.
var rank = 1;
var lane = System.Runtime.Intrinsics.Vector256.IsHardwareAccelerated ? 8 : 4;
p.Add<Passes.Rules.CPU.PackSoftmax>(rank, lane);
p.Add<Passes.Rules.CPU.PackSwish>(rank, lane);
p.Add<Passes.Rules.CPU.PackLayerNorm>(rank, lane);
p.Add<Passes.Rules.CPU.PackResizeImage>(rank, lane);
p.Add<Passes.Rules.CPU.PackMatMul>(rank, lane);
p.Add<Passes.Rules.CPU.PackConv2D>(rank, lane);
p.Add<Passes.Rules.CPU.PackUnary>(rank, lane);
p.Add<Passes.Rules.CPU.PackBinary>(rank, lane);
p.Add<Passes.Rules.CPU.PackTranspose>(rank, lane);
p.Add<Passes.Rules.CPU.PackUnsqueeze>(rank, lane);
p.Add<Passes.Rules.CPU.PackReshape>(rank, lane);
p.Add<Passes.Rules.CPU.PackSlice>(rank, lane);
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.CPU.FoldPackUnpack>();
p.Add<Passes.Rules.CPU.FoldPackConcatUnpack>();
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
});
}
passManager.AddWithName<DataflowPass>("AutoDistributed").Configure(p =>
// need refactor tiling.
// passManager.Add<AutoDistributedPass>();
passManager.Add<DataflowPass>().Configure(p =>
{
p.Add<Passes.Rules.AutoDistributed>();
p.Add<Passes.Rules.CPU.CPUOutputBoxingFusion>();
p.Add<Passes.Rules.CPU.CPUSingleFusion>();
});
passManager.Add<DataflowPass>().Configure(p =>
{
p.AddAnalysis<Passes.Analysis.IExprUserAnalysisResult>();
p.Add<Passes.Rules.CPU.DeterminedFusionMergeRule>();
});
passManager.AddWithName<EGraphRulesPass>("PartitionConstruct").Configure(p =>
{
p.Add<Passes.Rules.CPU.GeneralFusionMergeRule>();
});
passManager.AddWithName<EGraphExtractPass>("PartitionExtract").Configure(p =>
{
p.AddBaseFuncCostEvaluator<Passes.Rules.CPU.FusionCostEvaluator>();
});
// passManager.Add<CPUFunctionPartitionPass>();
passManager.Add<CPUFusionToModulePass>();
#if false
// FIX ME: Disable macos as macho loader is buggy.
if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
passManager.AddWithName<DataflowPass>("LowerToAffine").Configure(p =>
{
passManager.AddWithName<DataflowPass>("CPUDeviceFusion").Configure(p =>
{
p.AddAnalysis<Passes.Analysis.IExprUserAnalysisResult>();
p.Add<Passes.Rules.CPUDeviceFusion>();
});
}
#endif
p.Add<Passes.Rules.CPU.Affine.LowerUnary>();
p.Add<Passes.Rules.CPU.Affine.LowerSwish>();
});
// concat/reshape lower
// tile and lower to tir.
passManager.Add<AutoTilePass>();
passManager.Add<CPUFusionToTirPass>();
// todo add auto fusion merge pass here.
passManager.Add<PrimFuncPass>().Configure(p =>
{
p.Add<Passes.Mutators.UnFoldBlock>();

View File

@ -21,10 +21,13 @@
namespace nncase::ntt {
namespace slice_detail {
template <IsFixedDims TStart, IsFixedDims TStop, IsFixedDims TStride,
size_t... Ints>
IsFixedDims TAxes, IsFixedDims TShape, size_t... Ints>
inline constexpr auto compute_inner_domain(std::index_sequence<Ints...>) {
return fixed_shape<((TStop::at(Ints) - TStart::at(Ints)) /
TStride::at(Ints))...>{};
return fixed_shape<(
((std::min(TShape::at(TAxes::at(Ints)), TStop::at(Ints)) - 1 -
TStart::at(Ints)) /
TStride::at(Ints)) +
1)...>{};
}
} // namespace slice_detail
@ -41,11 +44,11 @@ inline constexpr auto compute_inner_domain(std::index_sequence<Ints...>) {
template <IsFixedDims TStart, IsFixedDims TStop, IsFixedDims TAxes,
IsFixedDims TStride, IsFixedTensor TIn, IsFixedTensor TOut>
void slice(const TIn &input, TOut &&output) {
constexpr auto domain = shape_infer::reduced_shape_by_axes(
typename std::decay_t<TOut>::shape_type{}, TAxes{});
constexpr auto inner_domain =
slice_detail::compute_inner_domain<TStart, TStop, TStride>(
slice_detail::compute_inner_domain<TStart, TStop, TStride, TAxes,
typename TIn::shape_type>(
std::make_index_sequence<TAxes::rank()>{});
auto in_index = ranked_shape<domain.rank()>{};

View File

@ -14,6 +14,7 @@
*/
#pragma once
#include "compiler_defs.h"
#include <cstdint>
#include <stdexcept>
#include <string_view>
#include <type_traits>

View File

@ -107,6 +107,28 @@ class NNCASE_API value_type_node : public datatype_node {
using value_type_t = object_t<value_type_node>;
// Datatype node describing a vector of `elemtype` elements laid out over one
// or more lane dimensions.
class NNCASE_API vector_type_node : public datatype_node {
    // NOTE(review): the kind passed here is object_value_type rather than a
    // vector-specific kind -- confirm this is intended.
    DEFINE_OBJECT_KIND(datatype_node, object_value_type)
  public:
    // Stores the element datatype and the per-dimension lane counts.
    vector_type_node(datatype_t elemtype, dims_t lanes) noexcept
        : elemtype_(elemtype), lanes_(lanes) {}

    // Size in bytes of one vector value: the element size multiplied by every
    // lane extent. With no lanes this is just the element size.
    size_t size_bytes() const noexcept override {
        auto total = elemtype_->size_bytes();
        for (const auto lane : lanes_) {
            total *= lane;
        }
        return total;
    }

    // Type code identifying this node as a vector type.
    typecode_t typecode() const noexcept override { return dt_vectortype; }

  private:
    datatype_t elemtype_; // datatype of a single element
    dims_t lanes_;        // lane extent of each vector dimension
};
using vector_type_t = object_t<vector_type_node>;
namespace detail {
template <class T> struct datatype_of {};

View File

@ -14,3 +14,4 @@ DEFINE_TYPECODE(float64, f64, 0x0C)
DEFINE_TYPECODE(bfloat16, bf16, 0x0D)
DEFINE_TYPECODE(pointer, *, 0xF0)
DEFINE_TYPECODE(valuetype, val, 0xF1)
DEFINE_TYPECODE(vectortype, vec, 0xF2)

View File

@ -82,6 +82,15 @@ result<datatype_t> deserialize_datatype_impl(TReader &sr) noexcept {
auto uuid = sr.template read_unaligned<nncase::uuid_t>();
auto size_bytes = sr.template read_unaligned<uint32_t>();
return ok<datatype_t>(value_type_t(std::in_place, uuid, size_bytes));
}
case dt_vectortype: {
// Vector type payload, in read order: the element datatype (read through
// deserialize_datatype), a one-byte rank, then one unaligned 32-bit
// value per lane dimension.
checked_try_var(elem_type, deserialize_datatype(sr));
// The rank is stored as a single byte and widened for use as loop bound.
auto rank = (int32_t)sr.template read_unaligned<uint8_t>();
dims_t lanes(rank);
for (int32_t i = 0; i < rank; i++) {
lanes[i] = sr.template read_unaligned<uint32_t>();
}
return ok<datatype_t>(vector_type_t(std::in_place, elem_type, lanes));
}
// prim types
default: {

View File

@ -11,6 +11,7 @@ using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Nncase.Hosting;

View File

@ -31,6 +31,16 @@ public static class TypeSerializer
writer.Write((byte)Runtime.TypeCode.ValueType);
writer.Write(t.Uuid.ToByteArray());
writer.Write(t.SizeInBytes);
break;
case VectorType t:
// Layout written for a vector type: the VectorType type code as one byte,
// the element type (serialized recursively), the number of lane dimensions
// as one byte, then each lane value in order.
writer.Write((byte)Runtime.TypeCode.VectorType);
Serialize(writer, t.ElemType);
// The checked cast makes a lane count that does not fit in a byte throw
// instead of being silently truncated.
writer.Write(checked((byte)t.Lanes.Count));
for (int i = 0; i < t.Lanes.Count; i++)
{
// NOTE(review): the width written per lane depends on the element type of
// Lanes, which is not visible here -- confirm it matches the reader.
writer.Write(t.Lanes[i]);
}
break;
default:
throw new ArgumentException($"Unsupported datatype: {dataType}");

View File

@ -251,11 +251,6 @@ internal class Compiler : ICompiler
public void ClearFixShape(IPassManager p)
{
if (!_compileSession.CompileOptions.ShapeBucketOptions.Enable)
{
return;
}
p.AddWithName<DataflowPass>("ClearUnused").Configure(c =>
{
c.Add<FoldFixShape>();
@ -285,7 +280,12 @@ internal class Compiler : ICompiler
"TargetDependentAfterQuantPass",
progress,
token);
await RunPassAsync(p => ClearFixShape(p), "ClearFixShape", progress, token);
if (_compileSession.CompileOptions.ShapeBucketOptions.Enable)
{
await RunPassAsync(ClearFixShape, "ClearFixShape", progress, token);
}
await RunPassAsync(
p => target.RegisterTargetDependentBeforeCodeGen(p, _compileSession.CompileOptions),
"TargetDependentBeforeCodeGen",

View File

@ -118,8 +118,9 @@ public interface ICompilerServicesProvider
/// Evaluate cost of the expression tree.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="compileOptions">options.</param>
/// <returns>Evaluate result.</returns>
Cost EvaluateCost(Expr expr);
Cost EvaluateCost(Expr expr, CompileOptions compileOptions);
/// <summary>
/// Evaluate metric of the expression tree.
@ -208,8 +209,9 @@ public interface ICompilerServicesProvider
/// <param name="expr">Expression.</param>
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <param name="compileOptions">compileOptions.</param>
/// <returns>Rewritten expression.</returns>
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions);
/// <summary>
/// Using EGraph rewrite expression.
@ -304,10 +306,11 @@ public static class CompilerServices
/// Evaluate cost of the expression tree.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="compileOptions">compileOptions.</param>
/// <returns>Evaluate result.</returns>
public static Cost EvaluateCost(Expr expr)
public static Cost EvaluateCost(Expr expr, CompileOptions compileOptions)
{
return Provider.EvaluateCost(expr);
return Provider.EvaluateCost(expr, compileOptions);
}
/// <summary>
@ -414,10 +417,11 @@ public static class CompilerServices
/// <param name="expr">Expression.</param>
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <param name="compileOptions">compileOptions.</param>
/// <returns>Rewritten expression.</returns>
public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions)
{
return Provider.ERewrite(expr, rules, options);
return Provider.ERewrite(expr, rules, options, compileOptions);
}
/// <summary>
@ -652,9 +656,9 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
}
/// <inheritdoc/>
public Cost EvaluateCost(Expr expr)
public Cost EvaluateCost(Expr expr, CompileOptions compileOptions)
{
return _costEvaluateProvider.EvaluateCost(expr);
return _costEvaluateProvider.EvaluateCost(expr, compileOptions);
}
/// <inheritdoc/>
@ -696,9 +700,9 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
return _targetProvider.GetTarget(name);
}
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions)
{
return _eGraphrewriteProvider.ERewrite(expr, rules, options);
return _eGraphrewriteProvider.ERewrite(expr, rules, options, compileOptions);
}
public IEGraph ERewrite(IEGraph graph, IEnumerable<IRewriteRule> rules, RunPassContext options)

View File

@ -16,6 +16,11 @@ namespace Nncase.Evaluator;
/// </summary>
public interface ICostEvaluateContext
{
/// <summary>
/// Gets the CompileOptions.
/// </summary>
public CompileOptions CompileOptions { get; }
/// <summary>
/// Get return type.
/// </summary>

View File

@ -20,8 +20,9 @@ public interface ICostEvaluateProvider
/// Evaluate cost of the expression tree.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="compileOptions">options.</param>
/// <returns>Evaluate result.</returns>
Cost EvaluateCost(Expr expr);
Cost EvaluateCost(Expr expr, CompileOptions compileOptions);
/// <summary>
/// Evaluate cost of operator.

View File

@ -32,14 +32,19 @@ public sealed class Function : BaseFunction
{
}
public Function(string name, Expr body, ReadOnlySpan<Var> parameters, Dictionary<Var, Expr[]>? varMap)
: base(name, StackVMModuleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
public Function(string name, string moduleKind, Expr body, ReadOnlySpan<Var> parameters, Dictionary<Var, Expr[]>? varMap)
: base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
{
VarMap = varMap ?? new();
var dynamicDims = VarMap.Values.SelectMany(x => x).ToArray();
_pinner = new ExprPinner(dynamicDims);
}
/// <summary>
/// Initializes a new instance of the <see cref="Function"/> class
/// using <c>StackVMModuleKind</c> as the module kind.
/// </summary>
/// <param name="name">Function name.</param>
/// <param name="body">Function body.</param>
/// <param name="parameters">Function parameters.</param>
/// <param name="varMap">Variable map, forwarded unchanged to the overload that takes a module kind; may be null.</param>
public Function(string name, Expr body, ReadOnlySpan<Var> parameters, Dictionary<Var, Expr[]>? varMap)
: this(name, StackVMModuleKind, body, parameters, varMap)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="Function"/> class.
/// build function.
@ -82,6 +87,6 @@ public sealed class Function : BaseFunction
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitFunction(this, context);
public Function With(string? name = null, Expr? body = null, Var[]? parameters = null)
=> new Function(name ?? Name, body ?? Body, parameters ?? Parameters, VarMap);
/// <summary>
/// Creates a new <see cref="Function"/> from this one, replacing each member
/// for which a non-null argument is given. <c>VarMap</c> is always carried over.
/// </summary>
/// <param name="name">Replacement name, or null to keep the current name.</param>
/// <param name="moduleKind">Replacement module kind, or null to keep the current module kind.</param>
/// <param name="body">Replacement body, or null to keep the current body.</param>
/// <param name="parameters">Replacement parameters, or null to keep the current parameters.</param>
/// <returns>The newly constructed function.</returns>
public Function With(string? name = null, string? moduleKind = null, Expr? body = null, Var[]? parameters = null)
=> new Function(name ?? Name, moduleKind ?? ModuleKind, body ?? Body, parameters ?? Parameters, VarMap);
}

View File

@ -21,7 +21,7 @@ public sealed partial class Split : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Split), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Split), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets axis.

View File

@ -104,6 +104,7 @@ public static partial class TypePatternUtility
x => x switch
{
TensorType ttype => shapeCond(ttype.Shape),
DistributedType distributedType => shapeCond(distributedType.TensorType.Shape),
_ => false,
},
reason);

View File

@ -36,8 +36,9 @@ public interface IEGraphRewriteProvider
/// <param name="expr">Expression.</param>
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <param name="compileOptions">compileOptions.</param>
/// <returns>Rewritten expression.</returns>
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions);
/// <summary>
/// Rewrite egraph.

View File

@ -13,12 +13,12 @@ namespace Nncase.PatternMatch;
/// <summary>
/// Pattern for <see cref="Fusion"/>.
/// </summary>
public sealed record FusionPattern(Pattern Body, string ModuleKind, VArgsPattern Parameters, string? Name) : Pattern<Fusion>(Name)
public sealed record FusionPattern(Pattern Body, Func<string, bool> Condition, VArgsPattern Parameters, string? Name) : Pattern<Fusion>(Name)
{
/// <inheritdoc/>
protected override bool MatchLeafCore(Fusion expr)
{
return ModuleKind == expr.ModuleKind;
return Condition(expr.ModuleKind);
}
}
@ -28,9 +28,11 @@ public static partial class Utility
/// Create the Fusion pattern.
/// </summary>
/// <param name="name">name.</param>
/// <param name="module_kind">module kind.</param>
/// <param name="moduleKind">module kind.</param>
/// <param name="body">body.</param>
/// <param name="parameters">params.</param>
/// <returns>FusionPattern .</returns>
public static FusionPattern IsFusion(string? name, string module_kind, Pattern body, VArgsPattern parameters) => new FusionPattern(body, module_kind, parameters, name);
public static FusionPattern IsFusion(string? name, string moduleKind, Pattern body, VArgsPattern parameters) => new FusionPattern(body, m => m == moduleKind, parameters, name);
/// <summary>
/// Create the Fusion pattern whose module kind is checked by a predicate.
/// </summary>
/// <param name="name">name.</param>
/// <param name="condition">predicate that receives the fusion's module kind; the leaf match succeeds only when it returns true.</param>
/// <param name="body">body.</param>
/// <param name="parameters">params.</param>
/// <returns>FusionPattern .</returns>
public static FusionPattern IsFusion(string? name, Func<string, bool> condition, Pattern body, VArgsPattern parameters) => new FusionPattern(body, condition, parameters, name);
}

View File

@ -93,4 +93,9 @@ public enum TypeCode : byte
/// <see cref="Nncase.ValueType"/>.
/// </summary>
ValueType = 0xF1,
/// <summary>
/// <see cref="Nncase.VectorType"/>.
/// </summary>
VectorType = 0xF2,
}

View File

@ -88,7 +88,7 @@ public static class T
/// <returns> for builder. </returns>
public static ISequentialBuilder<For> ForLoop(out Var loopVar, Range domain, LoopMode mode, [CallerArgumentExpression("loopVar")] string var_name = "v")
{
var newLoopVar = loopVar = new Var(var_name.StartsWith("var ") ? var_name[4..] : var_name, TensorType.Scalar(DataTypes.Int32));
var newLoopVar = loopVar = new Var(var_name.StartsWith("var ") ? var_name[4..] : var_name, domain.Start.CheckedType);
return new SequentialBuilder<For>(body => new For(newLoopVar, domain, mode, body));
}

View File

@ -560,7 +560,14 @@ internal sealed class ILPrintVisitor : ExprFunctor<string, string>
protected override string VisitBufferOf(BufferOf expr)
{
return $"bufferof({Visit(expr.Input)})";
if (_names.TryGetValue(expr, out var name))
{
return name;
}
name = $"bufferof({Visit(expr.Input)})";
_names.Add(expr, name);
return name;
}
/// <inheritdoc/>
@ -578,7 +585,8 @@ internal sealed class ILPrintVisitor : ExprFunctor<string, string>
// 1. For Loop signature
_scope.Append($"{name} = Grid({string.Join(", ", reads)})");
AppendCheckedType(expr.CheckedType, " {");
AppendCheckedType(expr.CheckedType);
_scope.IndWriteLine(" {");
using (_scope.IndentUp())
{

View File

@ -15,6 +15,7 @@ namespace Nncase.CostModel;
internal sealed class EGraphCostEvaluator
{
private readonly EClass _root;
private readonly CompileOptions _compileOptions;
private readonly Dictionary<ENode, Cost> _costs = new(ReferenceEqualityComparer.Instance);
private readonly Dictionary<EClass, Cost> _eclassCosts = new();
private readonly HashSet<EClass> _allEclasses = new();
@ -22,9 +23,10 @@ internal sealed class EGraphCostEvaluator
private readonly bool _accumulate;
private bool _changed;
public EGraphCostEvaluator(EClass root, IBaseFuncCostEvaluator? basefunc_cost_evaluator, bool accumulate = true)
public EGraphCostEvaluator(EClass root, CompileOptions compileOptions, IBaseFuncCostEvaluator? basefunc_cost_evaluator, bool accumulate = true)
{
_root = root;
_compileOptions = compileOptions;
_accumulate = accumulate;
_baseFuncCostEvaluator = basefunc_cost_evaluator;
PopulateAllEclasses(_root);
@ -163,7 +165,7 @@ internal sealed class EGraphCostEvaluator
Cost? newCost;
if (targetEnode.Expr is Op op)
{
var context = new EGraphOpCostEvaluateContext(returnType, enode.Children.Skip(1).Select(x => x.CheckedType).ToArray(), enode.Children.Skip(1).ToArray());
var context = new EGraphOpCostEvaluateContext(returnType, enode.Children.Skip(1).Select(x => x.CheckedType).ToArray(), enode.Children.Skip(1).ToArray(), _compileOptions);
newCost = CompilerServices.EvaluateOpCost(op, context);
}
else
@ -285,13 +287,16 @@ internal sealed class EGraphOpCostEvaluateContext : ICostEvaluateContext
private readonly IRType?[] _argumentTypes;
private readonly EClass[] _arguments;
public EGraphOpCostEvaluateContext(IRType? returnType, IRType?[] argumentTypes, EClass[] arguments)
public EGraphOpCostEvaluateContext(IRType? returnType, IRType?[] argumentTypes, EClass[] arguments, CompileOptions compileOptions)
{
_returnType = returnType;
_argumentTypes = argumentTypes;
_arguments = arguments;
CompileOptions = compileOptions;
}
public CompileOptions CompileOptions { get; }
public T GetArgument<T>(Op op, ParameterInfo parameter)
where T : BaseFunction
{

View File

@ -26,9 +26,10 @@ public static class EGraphExtensions
/// </summary>
/// <param name="eGraph">egraph.</param>
/// <param name="root">Root eclass.</param>
/// <param name="compileOptions">compileOptions.</param>
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
/// <param name="constrains">the cp model constrains.</param>
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[]? constrains = null)
public static Expr Extract(this IEGraph eGraph, EClass root, CompileOptions compileOptions, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[]? constrains = null)
{
// 1. set enode expr with more accuracy type.
foreach (var eclass in eGraph.Classes)
@ -43,7 +44,7 @@ public static class EGraphExtensions
}
// 2. start the cost evaluator
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), compileOptions, basefunc_cost_evaluator, false).Evaluate();
return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains);
}

View File

@ -27,7 +27,7 @@ internal class EGraphRewriteProvider : IEGraphRewriteProvider
_logger = logger;
}
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options, CompileOptions compileOptions)
{
if (expr.CheckedType is null)
{
@ -36,7 +36,7 @@ internal class EGraphRewriteProvider : IEGraphRewriteProvider
var graph = new EGraph(expr);
ERewrite(graph, rules, options);
var post = graph.Extract(graph.Root!, null, Array.Empty<EGraphExtractConstrains>());
var post = graph.Extract(graph.Root!, compileOptions, null, Array.Empty<EGraphExtractConstrains>());
return post;
}

View File

@ -1,12 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Nncase.IR;
using Nncase.IR.Affine;
namespace Nncase.Evaluator;
internal sealed partial class TypeInferenceVisitor
{
protected override IRType VisitLeafGrid(Grid expr) => expr.Buffers[^1].CheckedType;
}

View File

@ -17,9 +17,10 @@ internal sealed class CostEvaluateContext : ICostEvaluateContext
private readonly Dictionary<Expr, Cost> _exprMemo;
private Call? _currentCall;
public CostEvaluateContext(Dictionary<Expr, Cost> exprMemo)
public CostEvaluateContext(Dictionary<Expr, Cost> exprMemo, CompileOptions compileOptions)
{
_exprMemo = exprMemo;
CompileOptions = compileOptions;
}
public Call CurrentCall
@ -28,6 +29,8 @@ internal sealed class CostEvaluateContext : ICostEvaluateContext
set => _currentCall = value;
}
public CompileOptions CompileOptions { get; }
public T GetArgument<T>(Op op, ParameterInfo parameter)
where T : BaseFunction
{

View File

@ -22,7 +22,7 @@ internal sealed class CostEvaluateProvider : ICostEvaluateProvider
_serviceProvider = serviceProvider;
}
public Cost EvaluateCost(Expr expr)
public Cost EvaluateCost(Expr expr, CompileOptions compileOptions)
{
if (expr.CheckedType is null)
{
@ -34,7 +34,7 @@ internal sealed class CostEvaluateProvider : ICostEvaluateProvider
throw new InvalidOperationException("Expr in Cost Evaluator need a valid type");
}
var evaluatorVisitor = new CostEvaluateVisitor();
var evaluatorVisitor = new CostEvaluateVisitor(compileOptions);
return evaluatorVisitor.Visit(expr);
}

View File

@ -15,9 +15,9 @@ internal sealed class CostEvaluateVisitor : ExprVisitor<Cost, Unit>
{
private readonly CostEvaluateContext _context;
public CostEvaluateVisitor()
public CostEvaluateVisitor(CompileOptions compileOptions)
{
_context = new CostEvaluateContext(ExprMemo);
_context = new CostEvaluateContext(ExprMemo, compileOptions);
}
/// <inheritdoc/>
@ -41,7 +41,7 @@ internal sealed class CostEvaluateVisitor : ExprVisitor<Cost, Unit>
var targetCost = expr.Target switch
{
Op op => CompilerServices.EvaluateOpCost(op, _context),
Function func => CompilerServices.EvaluateCost(func.Body),
Function func => CompilerServices.EvaluateCost(func.Body, _context.CompileOptions),
_ => throw new NotImplementedException(expr.Target.ToString()),
};
return argumentsCost + targetCost;

View File

@ -126,7 +126,7 @@ internal sealed class EvaluateVisitor : ExprVisitor<IValue, Unit>, IDisposable
{
Op op => CompilerServices.EvaluateOp(op, _context, _evaluator_cache),
Function func => CompilerServices.Evaluate(func.Body, CreateFunctionEvaluateArguments(func.Parameters, expr.Arguments), _evaluator_cache),
Fusion { ModuleKind: "stackvm" } fusion => CompilerServices.Evaluate(fusion.Body, CreateFunctionEvaluateArguments(fusion.Parameters, expr.Arguments), _evaluator_cache),
Fusion fusion => CompilerServices.Evaluate(fusion.Body, CreateFunctionEvaluateArguments(fusion.Parameters, expr.Arguments), _evaluator_cache),
_ => throw new NotImplementedException(expr.Target.ToString()),
};
}

View File

@ -35,7 +35,8 @@ public class SliceEvaluator : IEvaluator<Slice>, ITypeInferencer<Slice>, ICostEv
var ends = context.GetInt64OrtTensorArgumentValue(sl, Slice.Ends);
var axes = context.GetInt64OrtTensorArgumentValue(sl, Slice.Axes);
var strides = context.GetInt64OrtTensorArgumentValue(sl, Slice.Strides);
return OrtKI.Slice(input, begins, ends, axes, strides).ToValue();
var sliced = OrtKI.Slice(input, begins, ends, axes, strides);
return Value.FromTensor(context.CurrentCall.CheckedType is AnyType ? sliced.ToTensor() : sliced.ToTensor(context.CurrentCall.CheckedTensorType));
}
/// <inheritdoc/>

View File

@ -242,6 +242,33 @@ internal sealed partial class TypeInferenceVisitor : ExprVisitor<IRType, Unit>
return type;
}
/// <summary>
/// Infers the type of a <see cref="Grid"/>: every sub field is passed to
/// <c>VerifySubField</c> first, then the checked type of the last buffer is returned.
/// </summary>
/// <param name="expr">Grid expression.</param>
/// <returns>The checked type of the grid's last buffer.</returns>
protected override IRType VisitLeafGrid(Grid expr)
{
    // Operand groups are verified in a fixed order: body parameters,
    // access maps, buffers, reads, and finally the body itself.
    foreach (var bodyParameter in expr.BodyParameters)
    {
        VerifySubField(expr, bodyParameter);
    }

    foreach (var accessMap in expr.AccessMaps)
    {
        VerifySubField(expr, accessMap);
    }

    foreach (var buffer in expr.Buffers)
    {
        VerifySubField(expr, buffer);
    }

    foreach (var read in expr.Reads)
    {
        VerifySubField(expr, read);
    }

    VerifySubField(expr, expr.Body);

    // The grid's type is that of its last buffer.
    var resultBuffer = expr.Buffers[^1];
    return resultBuffer.CheckedType;
}
protected override IRType VisitLeafAffineExpr(AffineExpr expr) => TensorType.Scalar(DataTypes.Int64);
protected override IRType VisitLeafAffineDomain(AffineDomain expr) => new TupleType(ImmutableArray.Create(expr.Offset.CheckedType, expr.Extent.CheckedType));

View File

@ -7,6 +7,7 @@ using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Nncase.Diagnostics;
using Nncase.Evaluator;
using Nncase.IR;
@ -15,16 +16,25 @@ namespace Nncase.Passes;
public sealed class EGraphExtractPass : Pass<IEGraph, BaseFunction>
{
private readonly IBaseFuncCostEvaluator? _costEvaluator;
private readonly CompileOptions _compileOptions;
private IBaseFuncCostEvaluator? _costEvaluator;
public EGraphExtractPass(IBaseFuncCostEvaluator? costEvaluator = null)
public EGraphExtractPass(CompileOptions compileOptions)
{
_costEvaluator = costEvaluator;
_compileOptions = compileOptions;
}
/// <summary>
/// Creates the base function cost evaluator of type <typeparamref name="T"/>
/// and stores it for use during extraction.
/// </summary>
/// <typeparam name="T">Evaluator type to instantiate.</typeparam>
/// <param name="parameters">Extra constructor arguments handed to the activator.</param>
public void AddBaseFuncCostEvaluator<T>(params object[] parameters)
    where T : IBaseFuncCostEvaluator
{
    var session = ((IPassIntern)this).CompileSession;

    // Keep a compile session scope open while the evaluator is constructed.
    using (new CompileSessionScope(session))
    {
        _costEvaluator = ActivatorUtilities.CreateInstance<T>(session, parameters);
    }
}
protected override Task<BaseFunction> RunCoreAsync(IEGraph input, RunPassContext context)
{
var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, Array.Empty<EGraphExtractConstrains>());
var post = (BaseFunction)input.Extract(input.Root!, _compileOptions, _costEvaluator, Array.Empty<EGraphExtractConstrains>());
IRHelpers.DCE(post);
return Task.FromResult(post);
}

View File

@ -203,7 +203,7 @@ internal sealed class PassManager : IPassManager
if (replaced && DumpScope.Current.IsEnabled(DumpFlags.PassIR))
{
DumpScope.Current.DumpModule(module, $"Epoch_{StartPassIndex}/After");
DumpScope.Current.DumpModule(module, $"FunctionCallUpdate_{StartPassIndex}/After");
}
return module;

View File

@ -46,9 +46,9 @@ internal sealed class AffineTiler
for (int loop = 0; loop < _loopBuilders.Length; loop++)
{
var domain = schedule.Loops[loop].Domain.Offset.Position;
var begin = 0;
var end = begin + _dims[domain];
var stride = schedule.Loops[loop].TileSize;
var begin = 0ul;
var end = begin + (ulong)_dims[domain];
var stride = (ulong)schedule.Loops[loop].TileSize;
_domainExtents[domain] = stride;
_loopBuilders[loop] = T.ForLoop(out _domainOffsets[domain], (begin, end, stride), LoopMode.Serial, $"l{loop}");
}
@ -71,24 +71,47 @@ internal sealed class AffineTiler
// 3. Nest compute body
var bodyBuffers = new Expr[_grid.Buffers.Length];
var bufferOfVars = new Expr[_grid.Reads.Length + 1];
var bodyVarReplaces = new Dictionary<Expr, Expr>();
var bufferOfReplaces = new Dictionary<Expr, Expr>();
for (int i = 0; i < bodyBuffers.Length; i++)
{
(bodyBuffers[i], cntBlock) = AllocateSubBuffer(cntBlock, _tempBuffers[i], schedule.BodyBufferViews[i]);
bodyVarReplaces.Add(_grid.BodyParameters[i], bodyBuffers[i]);
}
for (int i = 0; i < _grid.Reads.Length; i++)
{
bufferOfVars[i] = AllocateBufferOf(_grid.Buffers[i], i);
bufferOfReplaces.Add(_grid.Buffers[i], bufferOfVars[i]);
}
bufferOfVars[^1] = _grid.Buffers[^1];
var cloner = new ReplacingExprCloner(bodyVarReplaces);
var nestBody = cloner.Clone(_grid.Body, default);
cntBlock.Body(nestBody);
// 4. Create PrimFunction
var body = root.Build();
var primFunc = new PrimFunction(_grid.ModuleKind, body, _grid.Buffers);
cloner = new ReplacingExprCloner(bufferOfReplaces);
body = cloner.Clone(body, default);
var primFunc = new PrimFunction(_grid.ModuleKind, body, bufferOfVars);
var wrapper = new PrimFunctionWrapper(primFunc, _grid.Buffers.Length - 1);
module.Add(primFunc);
module.Add(wrapper);
return new Call(wrapper, _grid.Buffers);
// module.Add(primFunc);
// module.Add(wrapper);
return new Call(wrapper, _grid.Reads);
}
/// <summary>
/// Creates an input buffer that stands in for a <c>BufferOf</c> expression.
/// </summary>
/// <param name="expr">Expression to replace; must be an <c>IR.Buffers.BufferOf</c>.</param>
/// <param name="i">Index of the buffer, used to build the buffer name <c>buffer_{i}</c>.</param>
/// <returns>A buffer created from the checked tensor type of the <c>BufferOf</c> input, located in <c>MemoryLocation.Input</c>.</returns>
/// <exception cref="NotSupportedException">Thrown when <paramref name="expr"/> is not a <c>BufferOf</c>.</exception>
private TIR.Buffer AllocateBufferOf(Expr expr, int i)
{
    // Reject anything else up front, and say which buffer and which
    // expression type caused the failure so it can be diagnosed.
    if (expr is not IR.Buffers.BufferOf bufof)
    {
        throw new NotSupportedException($"Buffer {i} must be a BufferOf expression, but got {expr?.GetType().Name}.");
    }

    return T.CreateBuffer(bufof.Input.CheckedTensorType, MemoryLocation.Input, out _, $"buffer_{i}");
}
private ISequentialBuilder<Expr> AllocateTempBuffers(GridSchedule.Place place, ISequentialBuilder<Expr> sequential)

View File

@ -112,7 +112,7 @@ public class TilingSolver
var objeciveMonitor = _solver.MakeMinimize(_objective, 1);
var searchLog = _solver.MakeSearchLog(100000, objeciveMonitor);
var searchLimit = _solver.MakeImprovementLimit(_objective, false, 1, 0, 1, 2);
var timeLimit = _solver.MakeTimeLimit(5000);
var timeLimit = _solver.MakeTimeLimit(50000);
_solver.Solve(_decisionBuilder, new SearchMonitor[] { objeciveMonitor, searchLimit, timeLimit, searchLog, _solutionCollector });
@ -557,8 +557,11 @@ public class TilingSolver
}
}
var constraint = placeVar * (anyOrder ?? _solver.MakeIntConst(1)) == 0;
_solver.Add(constraint);
if (anyOrder != null)
{
var constraint = placeVar * (anyOrder ?? _solver.MakeIntConst(1)) == 0;
_solver.Add(constraint);
}
}
}

View File

@ -39,6 +39,13 @@ public sealed class TestVisitor : ExprVisitor<bool, IRType>
return count;
}
/// <summary>
/// Counts the visited calls whose target is a fusion of type <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Fusion type to look for.</typeparam>
/// <returns>Number of <see cref="Call"/> keys in the memo whose target is a <typeparamref name="T"/>.</returns>
public int CountCallFusion<T>()
    where T : Fusion
{
    // Only Call expressions recorded in the memo are considered.
    var calls = ExprMemo.Keys.OfType<Call>();
    return calls.Count(call => call.Target is T);
}
/// <inheritdoc/>
protected override bool DefaultVisitLeaf(Expr expr) => true;
}

View File

@ -133,7 +133,8 @@ public sealed class UnitTestDumpper : TestClassBase
new Passes.Rules.Lower.RemoveMarker(),
new TestMulToAdd(),
},
new());
new(),
CompileOptions);
Assert.True(File.Exists(Path.Join(Dumpper.Directory, "Costs/Solve.txt")));
}

File diff suppressed because it is too large Load Diff

View File

@ -48,7 +48,7 @@ public class UnitTestEGraphRewrite : TestClassBase
{
Expr pre = (Const)10 * 11 * 12;
var rule = new Passes.Rules.Neutral.ReassociateMul();
CompilerServices.ERewrite(pre, new[] { rule }, new());
CompilerServices.ERewrite(pre, new[] { rule }, new(), CompileOptions);
// Assert.Equal(newExpr, 10 * ((Const)11 * 12));
}
@ -76,7 +76,7 @@ public class UnitTestEGraphRewrite : TestClassBase
Assert.True(pre.InferenceType());
var post = CompilerServices.ERewrite(pre, new[] { new Passes.Rules.Neutral.CombineBinaryTranspose() }, new());
var post = CompilerServices.ERewrite(pre, new[] { new Passes.Rules.Neutral.CombineBinaryTranspose() }, new(), CompileOptions);
Assert.True(post.InferenceType());
Assert.Equal(pre.Evaluate(), post.Evaluate());
@ -106,7 +106,8 @@ public class UnitTestEGraphRewrite : TestClassBase
new Passes.Rules.Lower.RemoveMarker(),
new TestMulToAdd(),
},
new());
new(),
CompileOptions);
Assert.True(post.InferenceType());

View File

@ -53,14 +53,14 @@ public class UnitTestFoldSlice : TransformTestBase
new object[]
{
new[] { 4, 4, 6, 8 },
new[] { 0 }, new[] { 6 }, new[] { 3 }, new[] { -3 },
new[] { 0 }, new[] { 4 }, new[] { 2 }, new[] { -2 },
new[] { 0 }, new[] { 6 }, new[] { 3 }, new[] { 3 },
new[] { 0 }, new[] { 4 }, new[] { 2 }, new[] { 2 },
}, // negative axis
new object[]
{
new[] { 3, 4, 6, 8 },
new[] { 0 }, new[] { -1 }, new[] { 3 }, new[] { -3 },
new[] { -5 }, new[] { 4 }, new[] { 2 }, new[] { -2 },
new[] { 0 }, new[] { -1 }, new[] { 3 }, new[] { 3 },
new[] { -5 }, new[] { 4 }, new[] { 2 }, new[] { 2 },
}, // negative begin|end
}.Select((o, i) => o.Concat(new object[] { i }).ToArray());

View File

@ -60,7 +60,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
public static int Rank => 1;
[Theory]
[InlineData(new object[] { new[] { 1, 512, 64, 64 }, 0 })]
[InlineData(new object[] { new[] { 32, 512, 64, 64 }, 0 })]
public async Task TestSwish(int[] shape, int count)
{
var input = new Var(new TensorType(DataTypes.Float32, shape));
@ -69,7 +69,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
{ input, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, shape).Evaluate() },
};
var rule = new Passes.Rules.CPU.PackSwish() { Lane = Lane, Rank = Rank };
var rule = new Passes.Rules.CPU.PackSwish(Rank, Lane);
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext()).Where(e => e is not Call { Target: Slice }));
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
@ -89,7 +89,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
{ rhs, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 3, rhsShape).Evaluate() },
};
var rule = new Passes.Rules.CPU.PackBinary() { Lane = Lane, Rank = Rank };
var rule = new Passes.Rules.CPU.PackBinary(Rank, Lane);
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
var posts = rule.GetReplaceCandidates(result!, new Passes.RunPassContext());
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
@ -112,7 +112,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
{ bias, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, pshape).Evaluate() },
};
var rule = new Passes.Rules.CPU.PackLayerNorm() { Lane = Lane, Rank = Rank };
var rule = new Passes.Rules.CPU.PackLayerNorm(Rank, Lane);
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext())).Where(e => e is not Call { Target: Slice });
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
@ -135,7 +135,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
{ bias, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, pshape).Evaluate() },
};
var rule = new Passes.Rules.CPU.PackInstanceNorm() { Lane = Lane, Rank = Rank };
var rule = new Passes.Rules.CPU.PackInstanceNorm(Rank, Lane);
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext())).Where(e => e is not Call { Target: Slice });
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
@ -153,7 +153,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
{ input, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, shape).Evaluate() },
};
var rule = new Passes.Rules.CPU.PackResizeImage() { Lane = Lane, Rank = Rank };
var rule = new Passes.Rules.CPU.PackResizeImage(Rank, Lane);
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext())).Where(e => e is not Call { Target: Slice });
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
@ -173,7 +173,7 @@ public sealed class UnitTestCPUKernels : TestClassBase
{ rhs, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 3, rhsShape).Evaluate() },
};
var rule = new Passes.Rules.CPU.PackMatMul() { Lane = Lane, Rank = Rank };
var rule = new Passes.Rules.CPU.PackMatMul(Rank, Lane);
CompilerServices.TryMatch(pre, rule.Pattern, out var result);
var posts = new[] { pre }.Concat(rule.GetReplaceCandidates(result!, new Passes.RunPassContext()));
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
@ -336,13 +336,14 @@ public sealed class UnitTestCPUKernels : TestClassBase
/// <summary>
/// Builds one CPU kernel case per candidate expression and runs them sequentially.
/// </summary>
/// <param name="dumpDir">Directory passed through to <c>Run</c> for every case.</param>
/// <param name="feedDict">Input variables and their values; the keys become the fusion parameters and the values are converted to tensors.</param>
/// <param name="posts">Candidate expressions; each one becomes the body of a fusion named "kernel".</param>
/// <returns>A task that completes once every case has been run.</returns>
internal async Task RunCases(string dumpDir, Dictionary<Var, IValue> feedDict, IEnumerable<Expr> posts)
{
    // Materialize the sequence once so the (possibly lazy) LINQ query behind
    // `posts` is not re-evaluated while the cases are running.
    var postArray = posts.ToArray();

    // NOTE(review): ExprPinner presumably keeps the candidate expressions alive
    // until every case has run - confirm against its definition.
    using var pinner = new ExprPinner(postArray);

    for (int i = 0; i < postArray.Length; i++)
    {
#if DEBUG
        System.Console.WriteLine(CompilerServices.Print(postArray[i]));
#endif

        // The loop index doubles as the case number, so no separate counter is needed.
        var kernelCase = new CpuKernelCase(
            $"Case{i}",
            new Fusion("kernel", CPUTarget.Kind, postArray[i], feedDict.Keys.ToArray()),
            feedDict.Keys.ToArray(),
            feedDict.Values.Select(v => v.AsTensor()).ToArray());
        await Run(dumpDir, kernelCase);
    }
}