Revert "add Razor.Templating.Core (#1169)" (#1173)

This reverts commit 2498b1ba0c.
pull/1175/head
sunnycase 2024-03-08 14:40:03 +08:00 committed by GitHub
parent 2498b1ba0c
commit bb47ea5803
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 457 additions and 300 deletions

View File

@ -2,11 +2,13 @@
<configuration>
<packageSources>
<clear />
<add key="nuget.cnblogs.com" value="https://nuget.cnblogs.com/v3/index.json" protocolVersion="3" />
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
<add key="design-packages" value="tools/design-packages" />
<add key="sunnycase" value="https://nuget.sunnycase.moe/v3/index.json" />
</packageSources>
<activePackageSource>
<add key="nuget.cnblogs.com" value="https://nuget.cnblogs.com/v3/index.json" protocolVersion="3" />
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
<add key="Nncase.Libs" value="https://www.myget.org/F/magicallibs/api/v3/index.json" protocolVersion="3" />
<add key="myget-xunit" value="https://www.myget.org/F/xunit/api/v3/index.json" />

View File

@ -121,8 +121,7 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
"Nncase.IO": "[1.0.0, )"
}
},
"nncase.core": {
@ -267,12 +266,6 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",

View File

@ -179,7 +179,7 @@ internal partial class Program
private static void ConfigureAppConfiguration(HostBuilderContext context, IConfigurationBuilder builder)
{
var baseDirectory = Path.GetDirectoryName(typeof(Program).Assembly.Location);
builder.SetBasePath(baseDirectory!)
builder.SetBasePath(baseDirectory)
.AddJsonFile("config.json", true, false);
}
}

View File

@ -664,8 +664,7 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
"Nncase.IO": "[1.0.0, )"
}
},
"nncase.compiler": {
@ -933,12 +932,6 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",

View File

@ -8,7 +8,6 @@
<ItemGroup>
<PackageReference Include="Extension.Mathematics" />
<PackageReference Include="Razor.Templating.Core" />
</ItemGroup>
<ItemGroup>

View File

@ -8,12 +8,6 @@
"resolved": "1.2.12",
"contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg=="
},
"Razor.Templating.Core": {
"type": "Direct",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"StyleCop.Analyzers": {
"type": "Direct",
"requested": "[1.2.0-beta.435, )",

View File

@ -100,6 +100,7 @@ internal class Compiler : ICompiler
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
@ -140,7 +141,6 @@ internal class Compiler : ICompiler
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
});

View File

@ -661,8 +661,7 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
"Nncase.IO": "[1.0.0, )"
}
},
"nncase.core": {
@ -881,12 +880,6 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",

View File

@ -71,21 +71,4 @@ public static class TIRUtilities
IR.F.Math.Max(0, t.First.Start),
IR.F.Math.Min(t.Second.FixedValue, t.First.Stop),
t.First.Step)).ToArray();
public static bool TryGetFixedRegions(TIR.BufferRegion region, out (int Start, int Stop, int Step)[] slice)
{
slice = new (int Start, int Stop, int Step)[region.Region.Length];
for (int i = 0; i < region.Region.Length; i++)
{
var rg = region.Region[i];
if (rg is not Range { Start: IR.TensorConst start, Stop: IR.TensorConst stop, Step: IR.TensorConst step })
{
return false;
}
slice[i] = (start.Value.ToScalar<int>(), stop.Value.ToScalar<int>(), step.Value.ToScalar<int>());
}
return true;
}
}

View File

@ -1,6 +1,7 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using GiGraph.Dot.Output.Writers.Edges;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.IR.Tensors;

View File

@ -32,25 +32,6 @@ public partial class EGraphPrinter
return printer.SaveToStream(file);
}
/// <summary>
/// find the minCostEnode in eclass.
/// <remarks>
/// the marker first.
/// </remarks>
/// </summary>
internal static ENode MinByWithMarker(EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? CostModel.Cost.Zero : costModel[x])!;
}
/// <summary>
/// find the minCostEnode in eclass skip marker.
/// </summary>
internal static ENode MinByWithOutMarker(EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!;
}
private DotGraph AttachEGraphCost(CostModel.EGraphCostModel costModel, EClass entry)
{
// 1. display each enode costs.
@ -91,12 +72,12 @@ public partial class EGraphPrinter
continue;
}
var minCostEnode = MinByWithMarker(parent, costModel);
var minCostEnode = parent.MinByWithMarker(costModel);
// when this marker ecalss has been visited, skip it.
if (markerEclassMemo.Contains(parent))
{
minCostEnode = MinByWithOutMarker(parent, costModel);
minCostEnode = parent.MinByWithOutMarker(costModel);
}
var (minCostDotnode, table) = NodesMap[minCostEnode];
@ -112,7 +93,7 @@ public partial class EGraphPrinter
if (minCostEnode.Expr is Marker && child == parent)
{
markerEclassMemo.Add(child);
var otherminCostENode = MinByWithOutMarker(child, costModel);
var otherminCostENode = child.MinByWithOutMarker(costModel);
var (childDotNode, _) = NodesMap[otherminCostENode];
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
{
@ -122,7 +103,7 @@ public partial class EGraphPrinter
}
else
{
var childEnode = MinByWithMarker(child.Find(), costModel);
var childEnode = child.Find().MinByWithMarker(costModel);
var (childDotNode, _) = NodesMap[childEnode];
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
{
@ -145,23 +126,3 @@ public partial class EGraphPrinter
return _dotGraph;
}
}
internal sealed class ENodeTypeComparer : IComparer<Expr>
{
public static readonly ENodeTypeComparer Instance = new();
public int Compare(Expr? x, Expr? y) => (x, y) switch
{
(null, null) => 0,
(Expr, null) => 1,
(null, Expr) => -1,
(Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)),
};
private int GetPriority(Expr x) => x switch
{
Marker => 0,
Const => 1,
_ => 2,
};
}

View File

@ -1,50 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Google.OrTools.Sat;
using Nncase.CostModel;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
namespace Nncase.Passes;
/// <summary>
/// EGraph extract extensions.
/// </summary>
public static class EGraphExtensions
{
/// <summary>
/// Extract egraph.
/// </summary>
/// <param name="eGraph">egraph.</param>
/// <param name="root">Root eclass.</param>
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
/// <param name="constrains">the cp model constrains.</param>
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[] constrains)
{
// 1. set enode expr with more accuracy type.
foreach (var eclass in eGraph.Classes)
{
foreach (var nodes in eclass.Nodes)
{
if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0)
{
nodes.Expr.CheckedType = eclass.CheckedType;
}
}
}
// 2. start the cost evaluator
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();
return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains);
}
}

View File

@ -0,0 +1,95 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Nncase.CostModel;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
namespace Nncase.Passes;
/// <summary>
/// EGraph extract extensions.
/// </summary>
public static class EGraphExtractExtensions
{
/// <summary>
/// Extract egraph.
/// </summary>
/// <param name="eGraph">eGraph.</param>
/// <param name="root">Root eclass.</param>
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
/// <param name="picks">the picks.</param>
/// <returns>Extracted root expression.</returns>
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, out IReadOnlyDictionary<ENode, bool> picks)
{
// 1. set enode expr with more accuracy type.
foreach (var eclass in eGraph.Classes)
{
foreach (var nodes in eclass.Nodes)
{
if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0)
{
nodes.Expr.CheckedType = eclass.CheckedType;
}
}
}
// 2. start the cost evaluator
var costModel = new EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();
// if (DumpScope.Current.IsEnabled(DumpFlags.EGraphCost))
// {
// using var fs = DumpScope.Current.OpenFile(Path.Combine("Costs", $"V{eGraph.Version}.dot"));
// EGraphPrinter.DumpEgraphAsDot(eGraph, costModel, root.Find(), fs);
// }
// return new EGraphExtractor(costModel).Extract(root.Find(), eGraph);
return new EGraphExtractors.SatExtractor(costModel).Extract(root.Find(), eGraph, out picks);
}
/// <summary>
/// find the minCostEnode in eclass.
/// <remarks>
/// the marker first.
/// </remarks>
/// </summary>
internal static ENode MinByWithMarker(this EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? Cost.Zero : costModel[x])!;
}
/// <summary>
/// find the minCostEnode in eclass skip marker.
/// </summary>
internal static ENode MinByWithOutMarker(this EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!;
}
internal sealed class ENodeTypeComparer : IComparer<Expr>
{
public static readonly ENodeTypeComparer Instance = new();
public int Compare(Expr? x, Expr? y) => (x, y) switch
{
(null, null) => 0,
(Expr, null) => 1,
(null, Expr) => -1,
(Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)),
};
private int GetPriority(Expr x) => x switch
{
Marker => 0,
Const => 1,
_ => 2,
};
}
}

View File

@ -0,0 +1,200 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Nncase.CostModel;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
namespace Nncase.Passes.EGraphExtractors;
internal interface IExtractor
{
Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode, bool> picks);
}
internal class Extractor : IExtractor
{
private readonly EGraphCostModel _costModel;
private readonly Dictionary<EClass, Expr> _eclassMemo = new();
private readonly Dictionary<EClass, Expr> _markerEclassMemo = new();
private readonly Dictionary<ENode, bool> _picks = new();
private StreamWriter? _dumpWriter;
public Extractor(EGraphCostModel costModel)
{
_costModel = costModel;
}
public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode, bool> picks)
{
_dumpWriter = DumpScope.Current.IsEnabled(DumpFlags.EGraphCost)
? new StreamWriter(DumpScope.Current.OpenFile($"{nameof(Extractor)}_Class_{root.Id}.txt"))
: null;
try
{
Visit(root);
}
finally
{
_dumpWriter?.Dispose();
}
foreach (var enode in eGraph.Nodes)
{
if (!_picks.ContainsKey(enode))
{
_picks[enode] = false;
}
}
picks = _picks;
return _eclassMemo[root];
}
private void Visit(EClass eclass)
{
var stack = new Stack<(EClass, ENode)>();
stack.Push((eclass, eclass.MinByWithMarker(_costModel)));
var markerEclassSet = new HashSet<EClass>();
while (stack.Any())
{
(eclass, var minCostEnode) = stack.Peek();
if (_eclassMemo.ContainsKey(eclass))
{
stack.Pop();
continue;
}
Expr? expr = null;
switch (minCostEnode.Expr)
{
case Var or TensorConst or TupleConst or Op or Fusion or None:
expr = minCostEnode.Expr;
break;
case Function or Call or IR.Tuple or Marker or IR.If:
var childrenExprs = new List<Expr>();
foreach (var child in minCostEnode.Children)
{
if (!_eclassMemo.TryGetValue(child, out var childExpr))
{
if (minCostEnode.Expr is Marker && child == eclass)
{
if (!_markerEclassMemo.TryGetValue(eclass, out var markerInputExpr))
{
markerEclassSet.Add(eclass);
stack.Push((eclass, eclass.MinByWithOutMarker(_costModel)));
}
else
{
childrenExprs.Add(markerInputExpr);
}
}
else
{
stack.Push((child, child.MinByWithMarker(_costModel)));
}
}
else
{
childrenExprs.Add(childExpr);
}
}
if (childrenExprs.Count != minCostEnode.Children.Count)
{
break;
}
expr = minCostEnode.Expr switch
{
Function function => Visit(minCostEnode, function, new(childrenExprs)),
Call call => Visit(minCostEnode, call, new(childrenExprs)),
IR.Tuple tuple => Visit(minCostEnode, tuple, new(childrenExprs)),
Marker marker => Visit(minCostEnode, marker, new(childrenExprs)),
IR.If @if => Visit(minCostEnode, @if, new(childrenExprs)),
_ => throw new ArgumentException("Unsupported expression type."),
};
break;
default:
throw new ArgumentException("Unsupported expression type.");
}
if (expr is null)
{
continue;
}
if (markerEclassSet.Contains(eclass) && minCostEnode.Expr is not Marker)
{
_markerEclassMemo.Add(eclass, expr);
}
else
{
_eclassMemo.Add(eclass, expr);
}
_picks[minCostEnode] = true;
stack.Pop();
}
}
private Marker Visit(ENode enode, Marker marker, IRArray<Expr> children)
{
var target = children[0];
var attr = children[1];
return marker.With(target: target, attribute: attr);
}
private Function Visit(ENode enode, Function func, IRArray<Expr> children)
{
if (children.Count == 0)
{
return func;
}
var body = children[0];
return func.With(body: body);
}
private IR.Tuple Visit(ENode enode, IR.Tuple tuple, IRArray<Expr> children)
{
return tuple.With(fields: children.ToArray());
}
private IR.If Visit(ENode enode, IR.If @if, IRArray<Expr> children)
{
return @if.With(condition: children[^3], then: children[^2], @else: children[^1], paramList: children[..^3].ToArray());
}
private Call Visit(ENode enode, Call call, IRArray<Expr> children)
{
var target = children[0];
var arguments = children.Skip(1);
// for mix quant debug.
if (call.EnodeQuantConfigWithCosine != null && _dumpWriter != null)
{
_dumpWriter.WriteLine(call + " " + call.CheckedType);
for (int i = 0; i < call.EnodeQuantConfigWithCosine.Count; i++)
{
for (int j = 0; j < call.EnodeQuantConfigWithCosine[i].Item1.Count; j++)
{
_dumpWriter.Write(call.EnodeQuantConfigWithCosine[i].Item1[j] + " ");
}
_dumpWriter.WriteLine(call.EnodeQuantConfigWithCosine[i].Item3);
}
}
return call.With(target: target, arguments: arguments.ToArray(), call.Metadata);
}
}

View File

@ -11,20 +11,18 @@ using Nncase.CostModel;
using Nncase.Diagnostics;
using Nncase.IR;
namespace Nncase.Passes;
namespace Nncase.Passes.EGraphExtractors;
public delegate void EGraphExtractConstrains(CpModel model, IReadOnlyDictionary<ENode, BoolVar> vars);
internal class EGraphExtractor
internal class SatExtractor : IExtractor
{
private readonly EGraphCostModel _costModel;
public EGraphExtractor(EGraphCostModel costModel)
public SatExtractor(EGraphCostModel costModel)
{
_costModel = costModel;
}
public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] constrains)
public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode, bool> picks)
{
var cpmodel = new CpModel();
@ -70,11 +68,6 @@ internal class EGraphExtractor
EliminateAllCycles(root, new(), new(), visited, cpmodel, vars);
}
foreach (var constrain in constrains)
{
constrain(cpmodel, vars);
}
// 3. add pick weights for all enode.
cpmodel.Minimize(LinearExpr.WeightedSum(eGraph.Nodes.Select(n => vars[n]), eGraph.Nodes.Select(n => checked((long)_costModel[n].Score))));
@ -128,7 +121,7 @@ internal class EGraphExtractor
throw new InvalidProgramException("SatExtract Failed!");
}
var picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e]));
picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e]));
using (var dumpStream = enableDump ? DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null)
{
EGraphPrinter.DumpEgraphAsDot(eGraph, _costModel, picks, root.Find(), dumpStream);

View File

@ -36,7 +36,7 @@ internal class EGraphRewriteProvider : IEGraphRewriteProvider
var graph = new EGraph(expr);
ERewrite(graph, rules, options);
var post = graph.Extract(graph.Root!, null, Array.Empty<EGraphExtractConstrains>());
var post = graph.Extract(graph.Root!, null, out _);
return post;
}

View File

@ -3,34 +3,54 @@
namespace Nncase.Passes.BufferSchedule;
public sealed class Interval
internal sealed class TimeInterval
{
public Interval(int start, int end)
public TimeInterval(int start, int end)
{
Brith = start;
Death = end;
}
public int Brith { get; set; }
public int Death { get; set; }
public int Size => Death - Brith;
public override string ToString()
{
return $"TimeInterval({Brith}, {Death})";
}
}
internal sealed class MemSpan
{
public MemSpan(int start, int end)
{
Start = start;
Stop = end;
End = end;
}
public int Start { get; set; }
public int Stop { get; set; }
public int End { get; set; }
public int Size => Stop - Start;
public int Size => End - Start;
public override string ToString()
{
return $"Interval({Start}, {Stop})";
return $"MemSpan({Start}, {End})";
}
}
public class ScheduleBuffer
internal class ScheduleBuffer
{
public ScheduleBuffer(string name, int number, Interval timeInterval, Interval memInterval, int[] shape, int[] strides, bool inplace)
public ScheduleBuffer(string name, int number, TimeInterval interval, MemSpan span, int[] shape, int[] strides, bool inplace)
{
Name = name;
Number = number;
TimeInterval = timeInterval;
MemInterval = memInterval;
Interval = interval;
Span = span;
Shape = shape;
Strides = strides;
Inplace = inplace;
@ -40,9 +60,9 @@ public class ScheduleBuffer
public int Number { get; }
public Interval TimeInterval { get; }
public TimeInterval Interval { get; }
public Interval MemInterval { get; }
public MemSpan Span { get; }
public int[] Shape { get; }
@ -52,6 +72,6 @@ public class ScheduleBuffer
public override string ToString()
{
return $"ScheduledBuffer('{Name}', {Number}, {TimeInterval}, {MemInterval}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})";
return $"ScheduledBuffer('{Name}', {Number}, {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})";
}
}

View File

@ -13,10 +13,49 @@ using Nncase.IR;
namespace Nncase.Passes.BufferSchedule;
public class BufferScheduler
internal sealed class BufferScheduler
{
public virtual void ExternalConstrains(CpModel model, IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap, IReadOnlyDictionary<Expr, (IntervalVar X, IntervalVar Y)> boxs)
public IReadOnlyDictionary<Expr, ScheduleBuffer> CollectLifeTime(Function func)
{
var c = new LifeTimeCollector();
return c.Collect(func);
}
public void Schedule(IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap)
{
var model = new CpModel();
var noOverlap = model.AddNoOverlap2D();
var boxs = new Dictionary<Expr, (IntervalVar X, IntervalVar Y)>(ReferenceEqualityComparer.Instance);
var timeMap = new Dictionary<int, List<Expr>>();
var yStarts = new List<IntVar>();
foreach (var (expr, item) in bufferMap)
{
var xInterval = model.NewIntervalVar(model.NewConstant(item.Interval.Brith), model.NewConstant(item.Interval.Size), model.NewConstant(item.Interval.Death), item.Name + $"{item.Number}_x");
var upbound = 2147483648 - item.Span.End;
if (upbound <= 0)
{
throw new System.NotSupportedException();
}
var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_{item.Number}_y_start");
var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.Span.End, $"{item.Name}_{item.Number}_y");
noOverlap.AddRectangle(xInterval, yInterval);
yStarts.Add(memStartVar);
boxs.Add(expr, (xInterval, yInterval));
for (int time = item.Interval.Brith; time < item.Interval.Death; time++)
{
if (!timeMap.TryGetValue(time, out var timelist))
{
timelist = new();
timeMap.Add(time, timelist);
}
timelist.Add(expr);
}
}
foreach (var (expr, item) in bufferMap)
{
if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple tuple)
@ -26,7 +65,7 @@ public class BufferScheduler
for (int i = 0; i < tuple.Fields.Length; i++)
{
model.Add((boxs[concatCall].Y.StartExpr() + offset) == boxs[tuple.Fields[i]].Y.StartExpr());
offset += bufferMap[tuple.Fields[i]].MemInterval.Size;
offset += bufferMap[tuple.Fields[i]].Span.Size;
}
}
else if (expr is Call { Target: IR.Tensors.Split } splitCall)
@ -40,7 +79,7 @@ public class BufferScheduler
foreach (var user in users.OrderBy(e => ((Call)e).Arguments[1].Evaluate().AsTensor().ToScalar<int>()))
{
model.Add((boxs[splitCall].Y.StartExpr() + offset) == boxs[user].Y.StartExpr());
offset += bufferMap[user].MemInterval.Size;
offset += bufferMap[user].Span.Size;
}
}
else if (expr is Call { Target: IR.Tensors.Reshape } reshapCall)
@ -49,44 +88,6 @@ public class BufferScheduler
model.Add(boxs[reshapCall].Y.StartExpr() == boxs[reshapCall.Arguments[0]].Y.StartExpr());
}
}
}
public void Schedule(IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap)
{
var model = new CpModel();
var noOverlap = model.AddNoOverlap2D();
var boxs = new Dictionary<Expr, (IntervalVar X, IntervalVar Y)>(ReferenceEqualityComparer.Instance);
var timeMap = new Dictionary<int, List<Expr>>();
var yStarts = new List<IntVar>();
foreach (var (expr, item) in bufferMap)
{
var xInterval = model.NewIntervalVar(model.NewConstant(item.TimeInterval.Start), model.NewConstant(item.TimeInterval.Size), model.NewConstant(item.TimeInterval.Stop), item.Name + $"{item.Number}_x");
var upbound = 2147483648 - item.MemInterval.Stop;
if (upbound <= 0)
{
throw new System.NotSupportedException();
}
var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_{item.Number}_y_start");
var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.MemInterval.Stop, $"{item.Name}_{item.Number}_y");
noOverlap.AddRectangle(xInterval, yInterval);
yStarts.Add(memStartVar);
boxs.Add(expr, (xInterval, yInterval));
for (int time = item.TimeInterval.Start; time < item.TimeInterval.Stop; time++)
{
if (!timeMap.TryGetValue(time, out var timelist))
{
timelist = new();
timeMap.Add(time, timelist);
}
timelist.Add(expr);
}
}
ExternalConstrains(model, bufferMap, boxs);
model.Minimize(LinearExpr.Sum(yStarts));
@ -98,10 +99,10 @@ public class BufferScheduler
throw new System.NotSupportedException();
}
foreach (var (k, _) in bufferMap)
foreach (var (k, v) in bufferMap)
{
bufferMap[k].MemInterval.Start = checked((int)solver.Value(boxs[k].Y.StartExpr()));
bufferMap[k].MemInterval.Stop = checked((int)solver.Value(boxs[k].Y.EndExpr()));
bufferMap[k].Span.Start = checked((int)solver.Value(boxs[k].Y.StartExpr()));
bufferMap[k].Span.End = checked((int)solver.Value(boxs[k].Y.EndExpr()));
}
}
@ -118,11 +119,18 @@ from enum import Enum
from typing import List
@dataclass
class Interval():
class TimeInterval():
start: int
end: int
def __str__(self) -> str:
return f'(start: {self.start}, end {self.end}, size {self.end - self.start})'
return f'(start: {self.start}, end {self.end})'
@dataclass
class MemSpan():
depth_start: int
depth_end: int
def __str__(self) -> str:
return f'(start: {self.depth_start}, size {self.depth_end - self.depth_start})'
class ConstraintsMode(Enum):
No = 0
@ -132,8 +140,8 @@ class ConstraintsMode(Enum):
class ScheduledBuffer():
name: str
number: int
time_interval: Interval
mem_interval: Interval
interval: TimeInterval
location: MemSpan
constraints: ConstraintsMode
shape: List[int]
stride: List[int]
@ -158,8 +166,8 @@ source = {
'height': [],
'alpha': [],
'color': [],
'mem_interval': [],
'time_interval': [],
'location': [],
'interval': [],
'shape': [],
'stride': [],
}
@ -169,10 +177,10 @@ x_range_max = 0
color_dict = {}
for buffer in buffers:
source['name'].append(buffer.name)
width = buffer.time_interval.end - buffer.time_interval.start
x = buffer.time_interval.start + (width / 2)
height = buffer.mem_interval.end - buffer.mem_interval.start
y = buffer.mem_interval.start + (height / 2)
width = buffer.interval.end - buffer.interval.start
x = buffer.interval.start + (width / 2)
height = buffer.location.depth_end - buffer.location.depth_start
y = buffer.location.depth_start + (height / 2)
y_range_max = max(y_range_max, y)
x_range_max = max(x_range_max, buffer.interval.end)
source['x'].append(x)
@ -185,13 +193,13 @@ for buffer in buffers:
color_dict[buffer.name] = color
source['color'].append(color)
source['alpha'].append(0.2 if buffer.inplace else 1.0)
source['time_interval'].append(str(buffer.time_interval))
source['mem_interval'].append(str(buffer.mem_interval))
source['interval'].append(str(buffer.interval))
source['location'].append(str(buffer.location))
source['shape'].append(','.join([str(s) for s in buffer.shape]))
source['stride'].append(','.join([str(s) for s in buffer.stride]))
source = ColumnDataSource(source)
hover = HoverTool(tooltips=[('name', '@name'), ('time_interval', '@time_interval'), ('mem_interval', '@mem_interval'),
hover = HoverTool(tooltips=[('name', '@name'), ('interval', '@interval'), ('location', '@location'),
('shape', '@shape'), ('stride', '@stride')])
p = figure(tools=[hover, WheelPanTool(), SaveTool(), WheelZoomTool(), ResetTool()], width=1280, height=720,

View File

@ -10,16 +10,16 @@ using Nncase.IR;
namespace Nncase.Passes.BufferSchedule;
public class LifeTimeCollector : ExprVisitor<Unit, Unit>
internal sealed class LifeTimeCollector : ExprVisitor<Unit, Unit>
{
public int TimeStamp { get; private set; }
public Dictionary<Expr, Interval> LifenessMap { get; } = new(ReferenceEqualityComparer.Instance);
public Dictionary<Expr, TimeInterval> LifenessMap { get; } = new(ReferenceEqualityComparer.Instance);
public IReadOnlyDictionary<Expr, ScheduleBuffer> Collect(Expr expr)
public IReadOnlyDictionary<Expr, ScheduleBuffer> Collect(Function entry)
{
Visit(expr);
Update(expr); // avoid final call time interval size == 1.
Visit(entry.Body);
Update(entry.Body); // avoid final call time interval size == 1.
Alias();
var d = new Dictionary<Expr, ScheduleBuffer>(ReferenceEqualityComparer.Instance);
@ -32,7 +32,8 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
Var va => va.Name,
_ => k.GetType().Name,
};
var size = ComputeBufferSize(k.CheckedType, out var shape, out var stride);
var size = GetSize(k.CheckedType, out var shape, out var stride);
d.Add(k, new(name, count++, v, new(0, size), shape, stride, false));
}
@ -61,29 +62,6 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
return Unit.Default;
}
protected virtual int ComputeBufferSize(IRType type, out int[] shape, out int[] stride)
{
shape = Array.Empty<int>();
stride = Array.Empty<int>();
var size = 0;
if (type is TensorType tensorType)
{
shape = tensorType.Shape.ToValueArray();
stride = TensorUtilities.GetStrides(shape);
size = TensorUtilities.GetSize(shape, stride, tensorType.DType.SizeInBytes);
}
else if (type is TupleType tupleType)
{
size = 0;
foreach (var item in tupleType)
{
size += ComputeBufferSize(item, out _, out _);
}
}
return size;
}
private void Update(Expr expr)
{
if (expr is Const or None)
@ -107,7 +85,7 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
}
else
{
interval.Stop = TimeStamp + 1;
interval.Death = TimeStamp + 1;
}
LifenessMap[expr] = interval;
@ -145,12 +123,12 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
} while (changed);
}
private bool AliasTime(Call call, Interval interval)
private bool AliasTime(Call call, TimeInterval interval)
{
var brith = call.GetArguments().Select(arg => LifenessMap[arg].Stop).Concat(new[] { interval.Start }).Max();
var death = call.GetUsers().Select(usr => LifenessMap[usr].Start).Concat(new[] { interval.Stop }).Min();
var brith = call.GetArguments().Select(arg => LifenessMap[arg].Death).Concat(new[] { interval.Brith }).Max();
var death = call.GetUsers().Select(usr => LifenessMap[usr].Brith).Concat(new[] { interval.Death }).Min();
if (brith == interval.Start && death == interval.Stop)
if (brith == interval.Brith && death == interval.Death)
{
return false;
}
@ -160,8 +138,31 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
throw new InvalidOperationException();
}
interval.Start = brith;
interval.Stop = death;
interval.Brith = brith;
interval.Death = death;
return true;
}
private int GetSize(IRType type, out int[] shape, out int[] stride)
{
shape = Array.Empty<int>();
stride = Array.Empty<int>();
var size = 0;
if (type is TensorType tensorType)
{
shape = tensorType.Shape.ToValueArray();
stride = TensorUtilities.GetStrides(shape);
size = TensorUtilities.GetSize(shape, stride, tensorType.DType.SizeInBytes);
}
else if (type is TupleType tupleType)
{
size = 0;
foreach (var item in tupleType)
{
size += GetSize(item, out _, out _);
}
}
return size;
}
}

View File

@ -46,8 +46,7 @@ public sealed class DDrBufferSchdeulePass : ModulePass
if (module.Entry is Function { ModuleKind: Callable.StackVMModuleKind, Body: Expr body } func && IsFixedType(body.CheckedType))
{
var sch = new BufferSchedule.BufferScheduler();
var c = new BufferSchedule.LifeTimeCollector();
var buffers = c.Collect(func.Body);
var buffers = sch.CollectLifeTime(func);
sch.Schedule(buffers);
using (var fs = Diagnostics.DumpScope.Current.OpenFile("draw_buffers.py"))
{

View File

@ -24,7 +24,7 @@ public sealed class EGraphExtractPass : Pass<IEGraph, BaseFunction>
protected override Task<BaseFunction> RunCoreAsync(IEGraph input, RunPassContext context)
{
var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, Array.Empty<EGraphExtractConstrains>());
var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, out _);
IRHelpers.DCE(post);
return Task.FromResult(post);
}

View File

@ -916,8 +916,7 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
"Nncase.IO": "[1.0.0, )"
}
},
"nncase.compiler": {
@ -1200,12 +1199,6 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",

View File

@ -65,8 +65,7 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
"Nncase.IO": "[1.0.0, )"
}
},
"nncase.core": {
@ -154,12 +153,6 @@
"System.Runtime.CompilerServices.Unsafe": "5.0.0"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"System.CommandLine": {
"type": "CentralTransitive",
"requested": "[2.0.0-beta4.22272.1, )",

View File

@ -718,8 +718,7 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
"Nncase.IO": "[1.0.0, )"
}
},
"nncase.compiler": {
@ -1002,12 +1001,6 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",

View File

@ -47,10 +47,10 @@ public sealed class UnitTestEGraphCostModel
},
};
Assert.IsType<TensorConst>(list.OrderBy(e => e, ENodeTypeComparer.Instance).First());
Assert.IsType<TensorConst>(list.OrderBy(e => e, EGraphExtractExtensions.ENodeTypeComparer.Instance).First());
Assert.True(cost[b] < cost[c]);
Assert.IsType<TensorConst>(list.OrderBy(e => e, ENodeTypeComparer.Instance).MinBy(e => cost[e]));
Assert.IsType<TensorConst>(list.OrderBy(e => e, EGraphExtractExtensions.ENodeTypeComparer.Instance).MinBy(e => cost[e]));
}
}

View File

@ -839,8 +839,7 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
"Nncase.IO": "[1.0.0, )"
}
},
"nncase.compiler": {
@ -1095,12 +1094,6 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",