mirror of https://github.com/kendryte/nncase.git
parent
2498b1ba0c
commit
bb47ea5803
|
@ -2,11 +2,13 @@
|
|||
<configuration>
|
||||
<packageSources>
|
||||
<clear />
|
||||
<add key="nuget.cnblogs.com" value="https://nuget.cnblogs.com/v3/index.json" protocolVersion="3" />
|
||||
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
|
||||
<add key="design-packages" value="tools/design-packages" />
|
||||
<add key="sunnycase" value="https://nuget.sunnycase.moe/v3/index.json" />
|
||||
</packageSources>
|
||||
<activePackageSource>
|
||||
<add key="nuget.cnblogs.com" value="https://nuget.cnblogs.com/v3/index.json" protocolVersion="3" />
|
||||
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
|
||||
<add key="Nncase.Libs" value="https://www.myget.org/F/magicallibs/api/v3/index.json" protocolVersion="3" />
|
||||
<add key="myget-xunit" value="https://www.myget.org/F/xunit/api/v3/index.json" />
|
||||
|
|
|
@ -121,8 +121,7 @@
|
|||
"dependencies": {
|
||||
"Extension.Mathematics": "[1.2.12, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.IO": "[1.0.0, )",
|
||||
"Razor.Templating.Core": "[1.9.0, )"
|
||||
"Nncase.IO": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.core": {
|
||||
|
@ -267,12 +266,6 @@
|
|||
"libortki": "0.0.2"
|
||||
}
|
||||
},
|
||||
"Razor.Templating.Core": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.9.0, )",
|
||||
"resolved": "1.9.0",
|
||||
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
|
||||
},
|
||||
"Singulink.Collections.Weak": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.0.2, )",
|
||||
|
|
|
@ -179,7 +179,7 @@ internal partial class Program
|
|||
private static void ConfigureAppConfiguration(HostBuilderContext context, IConfigurationBuilder builder)
|
||||
{
|
||||
var baseDirectory = Path.GetDirectoryName(typeof(Program).Assembly.Location);
|
||||
builder.SetBasePath(baseDirectory!)
|
||||
builder.SetBasePath(baseDirectory)
|
||||
.AddJsonFile("config.json", true, false);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -664,8 +664,7 @@
|
|||
"dependencies": {
|
||||
"Extension.Mathematics": "[1.2.12, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.IO": "[1.0.0, )",
|
||||
"Razor.Templating.Core": "[1.9.0, )"
|
||||
"Nncase.IO": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.compiler": {
|
||||
|
@ -933,12 +932,6 @@
|
|||
"libortki": "0.0.2"
|
||||
}
|
||||
},
|
||||
"Razor.Templating.Core": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.9.0, )",
|
||||
"resolved": "1.9.0",
|
||||
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
|
||||
},
|
||||
"Singulink.Collections.Weak": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.0.2, )",
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Extension.Mathematics" />
|
||||
<PackageReference Include="Razor.Templating.Core" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
|
|
@ -8,12 +8,6 @@
|
|||
"resolved": "1.2.12",
|
||||
"contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg=="
|
||||
},
|
||||
"Razor.Templating.Core": {
|
||||
"type": "Direct",
|
||||
"requested": "[1.9.0, )",
|
||||
"resolved": "1.9.0",
|
||||
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
|
||||
},
|
||||
"StyleCop.Analyzers": {
|
||||
"type": "Direct",
|
||||
"requested": "[1.2.0-beta.435, )",
|
||||
|
|
|
@ -100,6 +100,7 @@ internal class Compiler : ICompiler
|
|||
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
|
||||
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
|
||||
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
|
||||
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
|
||||
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
|
||||
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
|
||||
|
@ -140,7 +141,6 @@ internal class Compiler : ICompiler
|
|||
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
|
||||
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
|
||||
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
|
||||
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
|
||||
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
|
||||
});
|
||||
|
||||
|
|
|
@ -661,8 +661,7 @@
|
|||
"dependencies": {
|
||||
"Extension.Mathematics": "[1.2.12, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.IO": "[1.0.0, )",
|
||||
"Razor.Templating.Core": "[1.9.0, )"
|
||||
"Nncase.IO": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.core": {
|
||||
|
@ -881,12 +880,6 @@
|
|||
"libortki": "0.0.2"
|
||||
}
|
||||
},
|
||||
"Razor.Templating.Core": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.9.0, )",
|
||||
"resolved": "1.9.0",
|
||||
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
|
||||
},
|
||||
"Singulink.Collections.Weak": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.0.2, )",
|
||||
|
|
|
@ -71,21 +71,4 @@ public static class TIRUtilities
|
|||
IR.F.Math.Max(0, t.First.Start),
|
||||
IR.F.Math.Min(t.Second.FixedValue, t.First.Stop),
|
||||
t.First.Step)).ToArray();
|
||||
|
||||
public static bool TryGetFixedRegions(TIR.BufferRegion region, out (int Start, int Stop, int Step)[] slice)
|
||||
{
|
||||
slice = new (int Start, int Stop, int Step)[region.Region.Length];
|
||||
for (int i = 0; i < region.Region.Length; i++)
|
||||
{
|
||||
var rg = region.Region[i];
|
||||
if (rg is not Range { Start: IR.TensorConst start, Stop: IR.TensorConst stop, Step: IR.TensorConst step })
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
slice[i] = (start.Value.ToScalar<int>(), stop.Value.ToScalar<int>(), step.Value.ToScalar<int>());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using GiGraph.Dot.Output.Writers.Edges;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.Tensors;
|
||||
|
|
|
@ -32,25 +32,6 @@ public partial class EGraphPrinter
|
|||
return printer.SaveToStream(file);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// find the minCostEnode in eclass.
|
||||
/// <remarks>
|
||||
/// the marker first.
|
||||
/// </remarks>
|
||||
/// </summary>
|
||||
internal static ENode MinByWithMarker(EClass eClass, CostModel.EGraphCostModel costModel)
|
||||
{
|
||||
return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? CostModel.Cost.Zero : costModel[x])!;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// find the minCostEnode in eclass skip marker.
|
||||
/// </summary>
|
||||
internal static ENode MinByWithOutMarker(EClass eClass, CostModel.EGraphCostModel costModel)
|
||||
{
|
||||
return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!;
|
||||
}
|
||||
|
||||
private DotGraph AttachEGraphCost(CostModel.EGraphCostModel costModel, EClass entry)
|
||||
{
|
||||
// 1. display each enode costs.
|
||||
|
@ -91,12 +72,12 @@ public partial class EGraphPrinter
|
|||
continue;
|
||||
}
|
||||
|
||||
var minCostEnode = MinByWithMarker(parent, costModel);
|
||||
var minCostEnode = parent.MinByWithMarker(costModel);
|
||||
|
||||
// when this marker ecalss has been visited, skip it.
|
||||
if (markerEclassMemo.Contains(parent))
|
||||
{
|
||||
minCostEnode = MinByWithOutMarker(parent, costModel);
|
||||
minCostEnode = parent.MinByWithOutMarker(costModel);
|
||||
}
|
||||
|
||||
var (minCostDotnode, table) = NodesMap[minCostEnode];
|
||||
|
@ -112,7 +93,7 @@ public partial class EGraphPrinter
|
|||
if (minCostEnode.Expr is Marker && child == parent)
|
||||
{
|
||||
markerEclassMemo.Add(child);
|
||||
var otherminCostENode = MinByWithOutMarker(child, costModel);
|
||||
var otherminCostENode = child.MinByWithOutMarker(costModel);
|
||||
var (childDotNode, _) = NodesMap[otherminCostENode];
|
||||
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
|
||||
{
|
||||
|
@ -122,7 +103,7 @@ public partial class EGraphPrinter
|
|||
}
|
||||
else
|
||||
{
|
||||
var childEnode = MinByWithMarker(child.Find(), costModel);
|
||||
var childEnode = child.Find().MinByWithMarker(costModel);
|
||||
var (childDotNode, _) = NodesMap[childEnode];
|
||||
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
|
||||
{
|
||||
|
@ -145,23 +126,3 @@ public partial class EGraphPrinter
|
|||
return _dotGraph;
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed class ENodeTypeComparer : IComparer<Expr>
|
||||
{
|
||||
public static readonly ENodeTypeComparer Instance = new();
|
||||
|
||||
public int Compare(Expr? x, Expr? y) => (x, y) switch
|
||||
{
|
||||
(null, null) => 0,
|
||||
(Expr, null) => 1,
|
||||
(null, Expr) => -1,
|
||||
(Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)),
|
||||
};
|
||||
|
||||
private int GetPriority(Expr x) => x switch
|
||||
{
|
||||
Marker => 0,
|
||||
Const => 1,
|
||||
_ => 2,
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,50 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using Google.OrTools.Sat;
|
||||
using Nncase.CostModel;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.PatternMatch;
|
||||
using static Nncase.PatternMatch.F.Math;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
|
||||
namespace Nncase.Passes;
|
||||
|
||||
/// <summary>
|
||||
/// EGraph extract extensions.
|
||||
/// </summary>
|
||||
public static class EGraphExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Extract egraph.
|
||||
/// </summary>
|
||||
/// <param name="eGraph">egraph.</param>
|
||||
/// <param name="root">Root eclass.</param>
|
||||
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
|
||||
/// <param name="constrains">the cp model constrains.</param>
|
||||
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[] constrains)
|
||||
{
|
||||
// 1. set enode expr with more accuracy type.
|
||||
foreach (var eclass in eGraph.Classes)
|
||||
{
|
||||
foreach (var nodes in eclass.Nodes)
|
||||
{
|
||||
if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0)
|
||||
{
|
||||
nodes.Expr.CheckedType = eclass.CheckedType;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. start the cost evaluator
|
||||
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();
|
||||
|
||||
return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using Nncase.CostModel;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.PatternMatch;
|
||||
using static Nncase.PatternMatch.F.Math;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
|
||||
namespace Nncase.Passes;
|
||||
|
||||
/// <summary>
|
||||
/// EGraph extract extensions.
|
||||
/// </summary>
|
||||
public static class EGraphExtractExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Extract egraph.
|
||||
/// </summary>
|
||||
/// <param name="eGraph">eGraph.</param>
|
||||
/// <param name="root">Root eclass.</param>
|
||||
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
|
||||
/// <param name="picks">the picks.</param>
|
||||
/// <returns>Extracted root expression.</returns>
|
||||
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, out IReadOnlyDictionary<ENode, bool> picks)
|
||||
{
|
||||
// 1. set enode expr with more accuracy type.
|
||||
foreach (var eclass in eGraph.Classes)
|
||||
{
|
||||
foreach (var nodes in eclass.Nodes)
|
||||
{
|
||||
if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0)
|
||||
{
|
||||
nodes.Expr.CheckedType = eclass.CheckedType;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. start the cost evaluator
|
||||
var costModel = new EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();
|
||||
|
||||
// if (DumpScope.Current.IsEnabled(DumpFlags.EGraphCost))
|
||||
// {
|
||||
// using var fs = DumpScope.Current.OpenFile(Path.Combine("Costs", $"V{eGraph.Version}.dot"));
|
||||
// EGraphPrinter.DumpEgraphAsDot(eGraph, costModel, root.Find(), fs);
|
||||
// }
|
||||
// return new EGraphExtractor(costModel).Extract(root.Find(), eGraph);
|
||||
return new EGraphExtractors.SatExtractor(costModel).Extract(root.Find(), eGraph, out picks);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// find the minCostEnode in eclass.
|
||||
/// <remarks>
|
||||
/// the marker first.
|
||||
/// </remarks>
|
||||
/// </summary>
|
||||
internal static ENode MinByWithMarker(this EClass eClass, CostModel.EGraphCostModel costModel)
|
||||
{
|
||||
return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? Cost.Zero : costModel[x])!;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// find the minCostEnode in eclass skip marker.
|
||||
/// </summary>
|
||||
internal static ENode MinByWithOutMarker(this EClass eClass, CostModel.EGraphCostModel costModel)
|
||||
{
|
||||
return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!;
|
||||
}
|
||||
|
||||
internal sealed class ENodeTypeComparer : IComparer<Expr>
|
||||
{
|
||||
public static readonly ENodeTypeComparer Instance = new();
|
||||
|
||||
public int Compare(Expr? x, Expr? y) => (x, y) switch
|
||||
{
|
||||
(null, null) => 0,
|
||||
(Expr, null) => 1,
|
||||
(null, Expr) => -1,
|
||||
(Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)),
|
||||
};
|
||||
|
||||
private int GetPriority(Expr x) => x switch
|
||||
{
|
||||
Marker => 0,
|
||||
Const => 1,
|
||||
_ => 2,
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,200 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using Nncase.CostModel;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.PatternMatch;
|
||||
using static Nncase.PatternMatch.F.Math;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
|
||||
namespace Nncase.Passes.EGraphExtractors;
|
||||
|
||||
internal interface IExtractor
|
||||
{
|
||||
Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode, bool> picks);
|
||||
}
|
||||
|
||||
internal class Extractor : IExtractor
|
||||
{
|
||||
private readonly EGraphCostModel _costModel;
|
||||
private readonly Dictionary<EClass, Expr> _eclassMemo = new();
|
||||
private readonly Dictionary<EClass, Expr> _markerEclassMemo = new();
|
||||
private readonly Dictionary<ENode, bool> _picks = new();
|
||||
private StreamWriter? _dumpWriter;
|
||||
|
||||
public Extractor(EGraphCostModel costModel)
|
||||
{
|
||||
_costModel = costModel;
|
||||
}
|
||||
|
||||
public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode, bool> picks)
|
||||
{
|
||||
_dumpWriter = DumpScope.Current.IsEnabled(DumpFlags.EGraphCost)
|
||||
? new StreamWriter(DumpScope.Current.OpenFile($"{nameof(Extractor)}_Class_{root.Id}.txt"))
|
||||
: null;
|
||||
try
|
||||
{
|
||||
Visit(root);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_dumpWriter?.Dispose();
|
||||
}
|
||||
|
||||
foreach (var enode in eGraph.Nodes)
|
||||
{
|
||||
if (!_picks.ContainsKey(enode))
|
||||
{
|
||||
_picks[enode] = false;
|
||||
}
|
||||
}
|
||||
|
||||
picks = _picks;
|
||||
return _eclassMemo[root];
|
||||
}
|
||||
|
||||
private void Visit(EClass eclass)
|
||||
{
|
||||
var stack = new Stack<(EClass, ENode)>();
|
||||
stack.Push((eclass, eclass.MinByWithMarker(_costModel)));
|
||||
var markerEclassSet = new HashSet<EClass>();
|
||||
while (stack.Any())
|
||||
{
|
||||
(eclass, var minCostEnode) = stack.Peek();
|
||||
if (_eclassMemo.ContainsKey(eclass))
|
||||
{
|
||||
stack.Pop();
|
||||
continue;
|
||||
}
|
||||
|
||||
Expr? expr = null;
|
||||
switch (minCostEnode.Expr)
|
||||
{
|
||||
case Var or TensorConst or TupleConst or Op or Fusion or None:
|
||||
expr = minCostEnode.Expr;
|
||||
break;
|
||||
case Function or Call or IR.Tuple or Marker or IR.If:
|
||||
var childrenExprs = new List<Expr>();
|
||||
foreach (var child in minCostEnode.Children)
|
||||
{
|
||||
if (!_eclassMemo.TryGetValue(child, out var childExpr))
|
||||
{
|
||||
if (minCostEnode.Expr is Marker && child == eclass)
|
||||
{
|
||||
if (!_markerEclassMemo.TryGetValue(eclass, out var markerInputExpr))
|
||||
{
|
||||
markerEclassSet.Add(eclass);
|
||||
stack.Push((eclass, eclass.MinByWithOutMarker(_costModel)));
|
||||
}
|
||||
else
|
||||
{
|
||||
childrenExprs.Add(markerInputExpr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
stack.Push((child, child.MinByWithMarker(_costModel)));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
childrenExprs.Add(childExpr);
|
||||
}
|
||||
}
|
||||
|
||||
if (childrenExprs.Count != minCostEnode.Children.Count)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
expr = minCostEnode.Expr switch
|
||||
{
|
||||
Function function => Visit(minCostEnode, function, new(childrenExprs)),
|
||||
Call call => Visit(minCostEnode, call, new(childrenExprs)),
|
||||
IR.Tuple tuple => Visit(minCostEnode, tuple, new(childrenExprs)),
|
||||
Marker marker => Visit(minCostEnode, marker, new(childrenExprs)),
|
||||
IR.If @if => Visit(minCostEnode, @if, new(childrenExprs)),
|
||||
_ => throw new ArgumentException("Unsupported expression type."),
|
||||
};
|
||||
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException("Unsupported expression type.");
|
||||
}
|
||||
|
||||
if (expr is null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (markerEclassSet.Contains(eclass) && minCostEnode.Expr is not Marker)
|
||||
{
|
||||
_markerEclassMemo.Add(eclass, expr);
|
||||
}
|
||||
else
|
||||
{
|
||||
_eclassMemo.Add(eclass, expr);
|
||||
}
|
||||
|
||||
_picks[minCostEnode] = true;
|
||||
stack.Pop();
|
||||
}
|
||||
}
|
||||
|
||||
private Marker Visit(ENode enode, Marker marker, IRArray<Expr> children)
|
||||
{
|
||||
var target = children[0];
|
||||
var attr = children[1];
|
||||
return marker.With(target: target, attribute: attr);
|
||||
}
|
||||
|
||||
private Function Visit(ENode enode, Function func, IRArray<Expr> children)
|
||||
{
|
||||
if (children.Count == 0)
|
||||
{
|
||||
return func;
|
||||
}
|
||||
|
||||
var body = children[0];
|
||||
return func.With(body: body);
|
||||
}
|
||||
|
||||
private IR.Tuple Visit(ENode enode, IR.Tuple tuple, IRArray<Expr> children)
|
||||
{
|
||||
return tuple.With(fields: children.ToArray());
|
||||
}
|
||||
|
||||
private IR.If Visit(ENode enode, IR.If @if, IRArray<Expr> children)
|
||||
{
|
||||
return @if.With(condition: children[^3], then: children[^2], @else: children[^1], paramList: children[..^3].ToArray());
|
||||
}
|
||||
|
||||
private Call Visit(ENode enode, Call call, IRArray<Expr> children)
|
||||
{
|
||||
var target = children[0];
|
||||
var arguments = children.Skip(1);
|
||||
|
||||
// for mix quant debug.
|
||||
if (call.EnodeQuantConfigWithCosine != null && _dumpWriter != null)
|
||||
{
|
||||
_dumpWriter.WriteLine(call + " " + call.CheckedType);
|
||||
for (int i = 0; i < call.EnodeQuantConfigWithCosine.Count; i++)
|
||||
{
|
||||
for (int j = 0; j < call.EnodeQuantConfigWithCosine[i].Item1.Count; j++)
|
||||
{
|
||||
_dumpWriter.Write(call.EnodeQuantConfigWithCosine[i].Item1[j] + " ");
|
||||
}
|
||||
|
||||
_dumpWriter.WriteLine(call.EnodeQuantConfigWithCosine[i].Item3);
|
||||
}
|
||||
}
|
||||
|
||||
return call.With(target: target, arguments: arguments.ToArray(), call.Metadata);
|
||||
}
|
||||
}
|
|
@ -11,20 +11,18 @@ using Nncase.CostModel;
|
|||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Passes;
|
||||
namespace Nncase.Passes.EGraphExtractors;
|
||||
|
||||
public delegate void EGraphExtractConstrains(CpModel model, IReadOnlyDictionary<ENode, BoolVar> vars);
|
||||
|
||||
internal class EGraphExtractor
|
||||
internal class SatExtractor : IExtractor
|
||||
{
|
||||
private readonly EGraphCostModel _costModel;
|
||||
|
||||
public EGraphExtractor(EGraphCostModel costModel)
|
||||
public SatExtractor(EGraphCostModel costModel)
|
||||
{
|
||||
_costModel = costModel;
|
||||
}
|
||||
|
||||
public Expr Extract(EClass root, IEGraph eGraph, EGraphExtractConstrains[] constrains)
|
||||
public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode, bool> picks)
|
||||
{
|
||||
var cpmodel = new CpModel();
|
||||
|
||||
|
@ -70,11 +68,6 @@ internal class EGraphExtractor
|
|||
EliminateAllCycles(root, new(), new(), visited, cpmodel, vars);
|
||||
}
|
||||
|
||||
foreach (var constrain in constrains)
|
||||
{
|
||||
constrain(cpmodel, vars);
|
||||
}
|
||||
|
||||
// 3. add pick weights for all enode.
|
||||
cpmodel.Minimize(LinearExpr.WeightedSum(eGraph.Nodes.Select(n => vars[n]), eGraph.Nodes.Select(n => checked((long)_costModel[n].Score))));
|
||||
|
||||
|
@ -128,7 +121,7 @@ internal class EGraphExtractor
|
|||
throw new InvalidProgramException("SatExtract Failed!");
|
||||
}
|
||||
|
||||
var picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e]));
|
||||
picks = eGraph.Nodes.ToDictionary(e => e, e => solver.BooleanValue(vars[e]));
|
||||
using (var dumpStream = enableDump ? DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null)
|
||||
{
|
||||
EGraphPrinter.DumpEgraphAsDot(eGraph, _costModel, picks, root.Find(), dumpStream);
|
|
@ -36,7 +36,7 @@ internal class EGraphRewriteProvider : IEGraphRewriteProvider
|
|||
|
||||
var graph = new EGraph(expr);
|
||||
ERewrite(graph, rules, options);
|
||||
var post = graph.Extract(graph.Root!, null, Array.Empty<EGraphExtractConstrains>());
|
||||
var post = graph.Extract(graph.Root!, null, out _);
|
||||
return post;
|
||||
}
|
||||
|
||||
|
|
|
@ -3,34 +3,54 @@
|
|||
|
||||
namespace Nncase.Passes.BufferSchedule;
|
||||
|
||||
public sealed class Interval
|
||||
internal sealed class TimeInterval
|
||||
{
|
||||
public Interval(int start, int end)
|
||||
public TimeInterval(int start, int end)
|
||||
{
|
||||
Brith = start;
|
||||
Death = end;
|
||||
}
|
||||
|
||||
public int Brith { get; set; }
|
||||
|
||||
public int Death { get; set; }
|
||||
|
||||
public int Size => Death - Brith;
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
return $"TimeInterval({Brith}, {Death})";
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed class MemSpan
|
||||
{
|
||||
public MemSpan(int start, int end)
|
||||
{
|
||||
Start = start;
|
||||
Stop = end;
|
||||
End = end;
|
||||
}
|
||||
|
||||
public int Start { get; set; }
|
||||
|
||||
public int Stop { get; set; }
|
||||
public int End { get; set; }
|
||||
|
||||
public int Size => Stop - Start;
|
||||
public int Size => End - Start;
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
return $"Interval({Start}, {Stop})";
|
||||
return $"MemSpan({Start}, {End})";
|
||||
}
|
||||
}
|
||||
|
||||
public class ScheduleBuffer
|
||||
internal class ScheduleBuffer
|
||||
{
|
||||
public ScheduleBuffer(string name, int number, Interval timeInterval, Interval memInterval, int[] shape, int[] strides, bool inplace)
|
||||
public ScheduleBuffer(string name, int number, TimeInterval interval, MemSpan span, int[] shape, int[] strides, bool inplace)
|
||||
{
|
||||
Name = name;
|
||||
Number = number;
|
||||
TimeInterval = timeInterval;
|
||||
MemInterval = memInterval;
|
||||
Interval = interval;
|
||||
Span = span;
|
||||
Shape = shape;
|
||||
Strides = strides;
|
||||
Inplace = inplace;
|
||||
|
@ -40,9 +60,9 @@ public class ScheduleBuffer
|
|||
|
||||
public int Number { get; }
|
||||
|
||||
public Interval TimeInterval { get; }
|
||||
public TimeInterval Interval { get; }
|
||||
|
||||
public Interval MemInterval { get; }
|
||||
public MemSpan Span { get; }
|
||||
|
||||
public int[] Shape { get; }
|
||||
|
||||
|
@ -52,6 +72,6 @@ public class ScheduleBuffer
|
|||
|
||||
public override string ToString()
|
||||
{
|
||||
return $"ScheduledBuffer('{Name}', {Number}, {TimeInterval}, {MemInterval}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})";
|
||||
return $"ScheduledBuffer('{Name}', {Number}, {Interval}, {Span}, ConstraintsMode.No, [{string.Join(",", Shape)}], [{string.Join(",", Strides)}], {Inplace})";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,10 +13,49 @@ using Nncase.IR;
|
|||
|
||||
namespace Nncase.Passes.BufferSchedule;
|
||||
|
||||
public class BufferScheduler
|
||||
internal sealed class BufferScheduler
|
||||
{
|
||||
public virtual void ExternalConstrains(CpModel model, IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap, IReadOnlyDictionary<Expr, (IntervalVar X, IntervalVar Y)> boxs)
|
||||
public IReadOnlyDictionary<Expr, ScheduleBuffer> CollectLifeTime(Function func)
|
||||
{
|
||||
var c = new LifeTimeCollector();
|
||||
return c.Collect(func);
|
||||
}
|
||||
|
||||
public void Schedule(IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap)
|
||||
{
|
||||
var model = new CpModel();
|
||||
var noOverlap = model.AddNoOverlap2D();
|
||||
var boxs = new Dictionary<Expr, (IntervalVar X, IntervalVar Y)>(ReferenceEqualityComparer.Instance);
|
||||
var timeMap = new Dictionary<int, List<Expr>>();
|
||||
var yStarts = new List<IntVar>();
|
||||
foreach (var (expr, item) in bufferMap)
|
||||
{
|
||||
var xInterval = model.NewIntervalVar(model.NewConstant(item.Interval.Brith), model.NewConstant(item.Interval.Size), model.NewConstant(item.Interval.Death), item.Name + $"{item.Number}_x");
|
||||
|
||||
var upbound = 2147483648 - item.Span.End;
|
||||
if (upbound <= 0)
|
||||
{
|
||||
throw new System.NotSupportedException();
|
||||
}
|
||||
|
||||
var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_{item.Number}_y_start");
|
||||
var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.Span.End, $"{item.Name}_{item.Number}_y");
|
||||
noOverlap.AddRectangle(xInterval, yInterval);
|
||||
yStarts.Add(memStartVar);
|
||||
boxs.Add(expr, (xInterval, yInterval));
|
||||
|
||||
for (int time = item.Interval.Brith; time < item.Interval.Death; time++)
|
||||
{
|
||||
if (!timeMap.TryGetValue(time, out var timelist))
|
||||
{
|
||||
timelist = new();
|
||||
timeMap.Add(time, timelist);
|
||||
}
|
||||
|
||||
timelist.Add(expr);
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var (expr, item) in bufferMap)
|
||||
{
|
||||
if (expr is Call { Target: IR.Tensors.Concat } concatCall && concatCall.Arguments[0] is IR.Tuple tuple)
|
||||
|
@ -26,7 +65,7 @@ public class BufferScheduler
|
|||
for (int i = 0; i < tuple.Fields.Length; i++)
|
||||
{
|
||||
model.Add((boxs[concatCall].Y.StartExpr() + offset) == boxs[tuple.Fields[i]].Y.StartExpr());
|
||||
offset += bufferMap[tuple.Fields[i]].MemInterval.Size;
|
||||
offset += bufferMap[tuple.Fields[i]].Span.Size;
|
||||
}
|
||||
}
|
||||
else if (expr is Call { Target: IR.Tensors.Split } splitCall)
|
||||
|
@ -40,7 +79,7 @@ public class BufferScheduler
|
|||
foreach (var user in users.OrderBy(e => ((Call)e).Arguments[1].Evaluate().AsTensor().ToScalar<int>()))
|
||||
{
|
||||
model.Add((boxs[splitCall].Y.StartExpr() + offset) == boxs[user].Y.StartExpr());
|
||||
offset += bufferMap[user].MemInterval.Size;
|
||||
offset += bufferMap[user].Span.Size;
|
||||
}
|
||||
}
|
||||
else if (expr is Call { Target: IR.Tensors.Reshape } reshapCall)
|
||||
|
@ -49,44 +88,6 @@ public class BufferScheduler
|
|||
model.Add(boxs[reshapCall].Y.StartExpr() == boxs[reshapCall.Arguments[0]].Y.StartExpr());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void Schedule(IReadOnlyDictionary<Expr, ScheduleBuffer> bufferMap)
|
||||
{
|
||||
var model = new CpModel();
|
||||
var noOverlap = model.AddNoOverlap2D();
|
||||
var boxs = new Dictionary<Expr, (IntervalVar X, IntervalVar Y)>(ReferenceEqualityComparer.Instance);
|
||||
var timeMap = new Dictionary<int, List<Expr>>();
|
||||
var yStarts = new List<IntVar>();
|
||||
foreach (var (expr, item) in bufferMap)
|
||||
{
|
||||
var xInterval = model.NewIntervalVar(model.NewConstant(item.TimeInterval.Start), model.NewConstant(item.TimeInterval.Size), model.NewConstant(item.TimeInterval.Stop), item.Name + $"{item.Number}_x");
|
||||
|
||||
var upbound = 2147483648 - item.MemInterval.Stop;
|
||||
if (upbound <= 0)
|
||||
{
|
||||
throw new System.NotSupportedException();
|
||||
}
|
||||
|
||||
var memStartVar = model.NewIntVar(0, upbound, $"{item.Name}_{item.Number}_y_start");
|
||||
var yInterval = model.NewFixedSizeIntervalVar(memStartVar, item.MemInterval.Stop, $"{item.Name}_{item.Number}_y");
|
||||
noOverlap.AddRectangle(xInterval, yInterval);
|
||||
yStarts.Add(memStartVar);
|
||||
boxs.Add(expr, (xInterval, yInterval));
|
||||
|
||||
for (int time = item.TimeInterval.Start; time < item.TimeInterval.Stop; time++)
|
||||
{
|
||||
if (!timeMap.TryGetValue(time, out var timelist))
|
||||
{
|
||||
timelist = new();
|
||||
timeMap.Add(time, timelist);
|
||||
}
|
||||
|
||||
timelist.Add(expr);
|
||||
}
|
||||
}
|
||||
|
||||
ExternalConstrains(model, bufferMap, boxs);
|
||||
|
||||
model.Minimize(LinearExpr.Sum(yStarts));
|
||||
|
||||
|
@ -98,10 +99,10 @@ public class BufferScheduler
|
|||
throw new System.NotSupportedException();
|
||||
}
|
||||
|
||||
foreach (var (k, _) in bufferMap)
|
||||
foreach (var (k, v) in bufferMap)
|
||||
{
|
||||
bufferMap[k].MemInterval.Start = checked((int)solver.Value(boxs[k].Y.StartExpr()));
|
||||
bufferMap[k].MemInterval.Stop = checked((int)solver.Value(boxs[k].Y.EndExpr()));
|
||||
bufferMap[k].Span.Start = checked((int)solver.Value(boxs[k].Y.StartExpr()));
|
||||
bufferMap[k].Span.End = checked((int)solver.Value(boxs[k].Y.EndExpr()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,11 +119,18 @@ from enum import Enum
|
|||
from typing import List
|
||||
|
||||
@dataclass
|
||||
class Interval():
|
||||
class TimeInterval():
|
||||
start: int
|
||||
end: int
|
||||
def __str__(self) -> str:
|
||||
return f'(start: {self.start}, end {self.end}, size {self.end - self.start})'
|
||||
return f'(start: {self.start}, end {self.end})'
|
||||
|
||||
@dataclass
|
||||
class MemSpan():
|
||||
depth_start: int
|
||||
depth_end: int
|
||||
def __str__(self) -> str:
|
||||
return f'(start: {self.depth_start}, size {self.depth_end - self.depth_start})'
|
||||
|
||||
class ConstraintsMode(Enum):
|
||||
No = 0
|
||||
|
@ -132,8 +140,8 @@ class ConstraintsMode(Enum):
|
|||
class ScheduledBuffer():
|
||||
name: str
|
||||
number: int
|
||||
time_interval: Interval
|
||||
mem_interval: Interval
|
||||
interval: TimeInterval
|
||||
location: MemSpan
|
||||
constraints: ConstraintsMode
|
||||
shape: List[int]
|
||||
stride: List[int]
|
||||
|
@ -158,8 +166,8 @@ source = {
|
|||
'height': [],
|
||||
'alpha': [],
|
||||
'color': [],
|
||||
'mem_interval': [],
|
||||
'time_interval': [],
|
||||
'location': [],
|
||||
'interval': [],
|
||||
'shape': [],
|
||||
'stride': [],
|
||||
}
|
||||
|
@ -169,10 +177,10 @@ x_range_max = 0
|
|||
color_dict = {}
|
||||
for buffer in buffers:
|
||||
source['name'].append(buffer.name)
|
||||
width = buffer.time_interval.end - buffer.time_interval.start
|
||||
x = buffer.time_interval.start + (width / 2)
|
||||
height = buffer.mem_interval.end - buffer.mem_interval.start
|
||||
y = buffer.mem_interval.start + (height / 2)
|
||||
width = buffer.interval.end - buffer.interval.start
|
||||
x = buffer.interval.start + (width / 2)
|
||||
height = buffer.location.depth_end - buffer.location.depth_start
|
||||
y = buffer.location.depth_start + (height / 2)
|
||||
y_range_max = max(y_range_max, y)
|
||||
x_range_max = max(x_range_max, buffer.interval.end)
|
||||
source['x'].append(x)
|
||||
|
@ -185,13 +193,13 @@ for buffer in buffers:
|
|||
color_dict[buffer.name] = color
|
||||
source['color'].append(color)
|
||||
source['alpha'].append(0.2 if buffer.inplace else 1.0)
|
||||
source['time_interval'].append(str(buffer.time_interval))
|
||||
source['mem_interval'].append(str(buffer.mem_interval))
|
||||
source['interval'].append(str(buffer.interval))
|
||||
source['location'].append(str(buffer.location))
|
||||
source['shape'].append(','.join([str(s) for s in buffer.shape]))
|
||||
source['stride'].append(','.join([str(s) for s in buffer.stride]))
|
||||
|
||||
source = ColumnDataSource(source)
|
||||
hover = HoverTool(tooltips=[('name', '@name'), ('time_interval', '@time_interval'), ('mem_interval', '@mem_interval'),
|
||||
hover = HoverTool(tooltips=[('name', '@name'), ('interval', '@interval'), ('location', '@location'),
|
||||
('shape', '@shape'), ('stride', '@stride')])
|
||||
|
||||
p = figure(tools=[hover, WheelPanTool(), SaveTool(), WheelZoomTool(), ResetTool()], width=1280, height=720,
|
||||
|
|
|
@ -10,16 +10,16 @@ using Nncase.IR;
|
|||
|
||||
namespace Nncase.Passes.BufferSchedule;
|
||||
|
||||
public class LifeTimeCollector : ExprVisitor<Unit, Unit>
|
||||
internal sealed class LifeTimeCollector : ExprVisitor<Unit, Unit>
|
||||
{
|
||||
public int TimeStamp { get; private set; }
|
||||
|
||||
public Dictionary<Expr, Interval> LifenessMap { get; } = new(ReferenceEqualityComparer.Instance);
|
||||
public Dictionary<Expr, TimeInterval> LifenessMap { get; } = new(ReferenceEqualityComparer.Instance);
|
||||
|
||||
public IReadOnlyDictionary<Expr, ScheduleBuffer> Collect(Expr expr)
|
||||
public IReadOnlyDictionary<Expr, ScheduleBuffer> Collect(Function entry)
|
||||
{
|
||||
Visit(expr);
|
||||
Update(expr); // avoid final call time interval size == 1.
|
||||
Visit(entry.Body);
|
||||
Update(entry.Body); // avoid final call time interval size == 1.
|
||||
Alias();
|
||||
|
||||
var d = new Dictionary<Expr, ScheduleBuffer>(ReferenceEqualityComparer.Instance);
|
||||
|
@ -32,7 +32,8 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
|
|||
Var va => va.Name,
|
||||
_ => k.GetType().Name,
|
||||
};
|
||||
var size = ComputeBufferSize(k.CheckedType, out var shape, out var stride);
|
||||
var size = GetSize(k.CheckedType, out var shape, out var stride);
|
||||
|
||||
d.Add(k, new(name, count++, v, new(0, size), shape, stride, false));
|
||||
}
|
||||
|
||||
|
@ -61,29 +62,6 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
|
|||
return Unit.Default;
|
||||
}
|
||||
|
||||
protected virtual int ComputeBufferSize(IRType type, out int[] shape, out int[] stride)
|
||||
{
|
||||
shape = Array.Empty<int>();
|
||||
stride = Array.Empty<int>();
|
||||
var size = 0;
|
||||
if (type is TensorType tensorType)
|
||||
{
|
||||
shape = tensorType.Shape.ToValueArray();
|
||||
stride = TensorUtilities.GetStrides(shape);
|
||||
size = TensorUtilities.GetSize(shape, stride, tensorType.DType.SizeInBytes);
|
||||
}
|
||||
else if (type is TupleType tupleType)
|
||||
{
|
||||
size = 0;
|
||||
foreach (var item in tupleType)
|
||||
{
|
||||
size += ComputeBufferSize(item, out _, out _);
|
||||
}
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
private void Update(Expr expr)
|
||||
{
|
||||
if (expr is Const or None)
|
||||
|
@ -107,7 +85,7 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
|
|||
}
|
||||
else
|
||||
{
|
||||
interval.Stop = TimeStamp + 1;
|
||||
interval.Death = TimeStamp + 1;
|
||||
}
|
||||
|
||||
LifenessMap[expr] = interval;
|
||||
|
@ -145,12 +123,12 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
|
|||
} while (changed);
|
||||
}
|
||||
|
||||
private bool AliasTime(Call call, Interval interval)
|
||||
private bool AliasTime(Call call, TimeInterval interval)
|
||||
{
|
||||
var brith = call.GetArguments().Select(arg => LifenessMap[arg].Stop).Concat(new[] { interval.Start }).Max();
|
||||
var death = call.GetUsers().Select(usr => LifenessMap[usr].Start).Concat(new[] { interval.Stop }).Min();
|
||||
var brith = call.GetArguments().Select(arg => LifenessMap[arg].Death).Concat(new[] { interval.Brith }).Max();
|
||||
var death = call.GetUsers().Select(usr => LifenessMap[usr].Brith).Concat(new[] { interval.Death }).Min();
|
||||
|
||||
if (brith == interval.Start && death == interval.Stop)
|
||||
if (brith == interval.Brith && death == interval.Death)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
@ -160,8 +138,31 @@ public class LifeTimeCollector : ExprVisitor<Unit, Unit>
|
|||
throw new InvalidOperationException();
|
||||
}
|
||||
|
||||
interval.Start = brith;
|
||||
interval.Stop = death;
|
||||
interval.Brith = brith;
|
||||
interval.Death = death;
|
||||
return true;
|
||||
}
|
||||
|
||||
private int GetSize(IRType type, out int[] shape, out int[] stride)
|
||||
{
|
||||
shape = Array.Empty<int>();
|
||||
stride = Array.Empty<int>();
|
||||
var size = 0;
|
||||
if (type is TensorType tensorType)
|
||||
{
|
||||
shape = tensorType.Shape.ToValueArray();
|
||||
stride = TensorUtilities.GetStrides(shape);
|
||||
size = TensorUtilities.GetSize(shape, stride, tensorType.DType.SizeInBytes);
|
||||
}
|
||||
else if (type is TupleType tupleType)
|
||||
{
|
||||
size = 0;
|
||||
foreach (var item in tupleType)
|
||||
{
|
||||
size += GetSize(item, out _, out _);
|
||||
}
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,8 +46,7 @@ public sealed class DDrBufferSchdeulePass : ModulePass
|
|||
if (module.Entry is Function { ModuleKind: Callable.StackVMModuleKind, Body: Expr body } func && IsFixedType(body.CheckedType))
|
||||
{
|
||||
var sch = new BufferSchedule.BufferScheduler();
|
||||
var c = new BufferSchedule.LifeTimeCollector();
|
||||
var buffers = c.Collect(func.Body);
|
||||
var buffers = sch.CollectLifeTime(func);
|
||||
sch.Schedule(buffers);
|
||||
using (var fs = Diagnostics.DumpScope.Current.OpenFile("draw_buffers.py"))
|
||||
{
|
||||
|
|
|
@ -24,7 +24,7 @@ public sealed class EGraphExtractPass : Pass<IEGraph, BaseFunction>
|
|||
|
||||
protected override Task<BaseFunction> RunCoreAsync(IEGraph input, RunPassContext context)
|
||||
{
|
||||
var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, Array.Empty<EGraphExtractConstrains>());
|
||||
var post = (BaseFunction)input.Extract(input.Root!, _costEvaluator, out _);
|
||||
IRHelpers.DCE(post);
|
||||
return Task.FromResult(post);
|
||||
}
|
||||
|
|
|
@ -916,8 +916,7 @@
|
|||
"dependencies": {
|
||||
"Extension.Mathematics": "[1.2.12, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.IO": "[1.0.0, )",
|
||||
"Razor.Templating.Core": "[1.9.0, )"
|
||||
"Nncase.IO": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.compiler": {
|
||||
|
@ -1200,12 +1199,6 @@
|
|||
"libortki": "0.0.2"
|
||||
}
|
||||
},
|
||||
"Razor.Templating.Core": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.9.0, )",
|
||||
"resolved": "1.9.0",
|
||||
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
|
||||
},
|
||||
"Singulink.Collections.Weak": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.0.2, )",
|
||||
|
|
|
@ -65,8 +65,7 @@
|
|||
"dependencies": {
|
||||
"Extension.Mathematics": "[1.2.12, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.IO": "[1.0.0, )",
|
||||
"Razor.Templating.Core": "[1.9.0, )"
|
||||
"Nncase.IO": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.core": {
|
||||
|
@ -154,12 +153,6 @@
|
|||
"System.Runtime.CompilerServices.Unsafe": "5.0.0"
|
||||
}
|
||||
},
|
||||
"Razor.Templating.Core": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.9.0, )",
|
||||
"resolved": "1.9.0",
|
||||
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
|
||||
},
|
||||
"System.CommandLine": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[2.0.0-beta4.22272.1, )",
|
||||
|
|
|
@ -718,8 +718,7 @@
|
|||
"dependencies": {
|
||||
"Extension.Mathematics": "[1.2.12, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.IO": "[1.0.0, )",
|
||||
"Razor.Templating.Core": "[1.9.0, )"
|
||||
"Nncase.IO": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.compiler": {
|
||||
|
@ -1002,12 +1001,6 @@
|
|||
"libortki": "0.0.2"
|
||||
}
|
||||
},
|
||||
"Razor.Templating.Core": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.9.0, )",
|
||||
"resolved": "1.9.0",
|
||||
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
|
||||
},
|
||||
"Singulink.Collections.Weak": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.0.2, )",
|
||||
|
|
|
@ -47,10 +47,10 @@ public sealed class UnitTestEGraphCostModel
|
|||
},
|
||||
};
|
||||
|
||||
Assert.IsType<TensorConst>(list.OrderBy(e => e, ENodeTypeComparer.Instance).First());
|
||||
Assert.IsType<TensorConst>(list.OrderBy(e => e, EGraphExtractExtensions.ENodeTypeComparer.Instance).First());
|
||||
|
||||
Assert.True(cost[b] < cost[c]);
|
||||
|
||||
Assert.IsType<TensorConst>(list.OrderBy(e => e, ENodeTypeComparer.Instance).MinBy(e => cost[e]));
|
||||
Assert.IsType<TensorConst>(list.OrderBy(e => e, EGraphExtractExtensions.ENodeTypeComparer.Instance).MinBy(e => cost[e]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -839,8 +839,7 @@
|
|||
"dependencies": {
|
||||
"Extension.Mathematics": "[1.2.12, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.IO": "[1.0.0, )",
|
||||
"Razor.Templating.Core": "[1.9.0, )"
|
||||
"Nncase.IO": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.compiler": {
|
||||
|
@ -1095,12 +1094,6 @@
|
|||
"libortki": "0.0.2"
|
||||
}
|
||||
},
|
||||
"Razor.Templating.Core": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.9.0, )",
|
||||
"resolved": "1.9.0",
|
||||
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
|
||||
},
|
||||
"Singulink.Collections.Weak": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.0.2, )",
|
||||
|
|
Loading…
Reference in New Issue