mirror of https://github.com/kendryte/nncase.git
Add expr user analysis (#827)
* Make IR mutable * Add expr user analysis * fix il dot dump && fusion merge condition * Merge from rebuild-ir * Fix fusion tests * Fix warnings * Fix sg reference * Fix merge * Apply code-format changes * Fix TestNoneValue * Fix test dump * Update with k230 requires * Fix k230 tests * ExprRewriter: Use replace opr instead of rauw * Fix * Fix * Optimize performance * Disable CS0659;CS0661 * Fix invalidate hashcode cache * Fix new commits --------- Co-authored-by: 郑启航 <597323109@qq.com> Co-authored-by: sunnycase <sunnycase@users.noreply.github.com>pull/837/head
parent
00ee74c402
commit
e88d9c8d5f
|
@ -306,7 +306,7 @@ dotnet_diagnostic.IDE1006.severity = suggestion
|
|||
|
||||
dotnet_code_quality.copy_analysis = true
|
||||
dotnet_diagnostic.CA1000.severity = none
|
||||
dotnet_diagnostic.CA1001.severity = warning
|
||||
dotnet_diagnostic.CA1001.severity = suggestion
|
||||
dotnet_diagnostic.CA1018.severity = warning
|
||||
dotnet_diagnostic.CA1019.severity = warning
|
||||
dotnet_diagnostic.CA1036.severity = warning
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
<CodeAnalysisRuleSet>$(MSBuildThisFileDirectory)/tools/StyleCopAnalyzers.ruleset</CodeAnalysisRuleSet>
|
||||
<GenerateDocumentationFile>true</GenerateDocumentationFile>
|
||||
<Nullable>enable</Nullable>
|
||||
<NoWarn>$(NoWarn);MSB3270</NoWarn>
|
||||
<NoWarn>$(NoWarn);MSB3270;CS0659;CS0661</NoWarn>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)' == 'Release'">
|
||||
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
|
||||
|
@ -23,7 +23,7 @@
|
|||
<PackageVersion Include="Google.Protobuf" Version="3.19.1" />
|
||||
<PackageVersion Include="Grpc.Tools" Version="2.42.0" />
|
||||
<PackageVersion Include="Humanizer.Core" Version="2.14.1" />
|
||||
<PackageVersion Include="LanguageExt.Core" Version="4.0.3" />
|
||||
<PackageVersion Include="LanguageExt.Core" Version="4.4.0" />
|
||||
<PackageVersion Include="libtorch-cpu-linux-x64" Version="1.13.0.1" />
|
||||
<PackageVersion Include="libtorch-cpu-osx-x64" Version="1.13.0.1" />
|
||||
<PackageVersion Include="libtorch-cpu-win-x64" Version="1.13.0.1" />
|
||||
|
@ -55,6 +55,7 @@
|
|||
<PackageVersion Include="StyleCop.Analyzers" Version="1.2.0-beta.435" />
|
||||
<PackageVersion Include="System.CommandLine.Hosting" Version="0.3.0-alpha.21216.1" />
|
||||
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
|
||||
<PackageVersion Include="System.Reactive" Version="5.0.0" />
|
||||
<PackageVersion Include="Tomlyn.Extensions.Configuration" Version="1.0.5" />
|
||||
<PackageVersion Include="TorchSharp" Version="0.99.0" />
|
||||
<PackageVersion Include="xunit" Version="2.4.1" />
|
||||
|
|
|
@ -6,12 +6,14 @@
|
|||
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
|
||||
<add key="Nncase.Libs" value="https://www.myget.org/F/magicallibs/api/v3/index.json" protocolVersion="3"/>
|
||||
<add key="myget-xunit" value="https://www.myget.org/F/xunit/api/v3/index.json" />
|
||||
<add key="design-packages" value="tools/design-packages" />
|
||||
</packageSources>
|
||||
<activePackageSource>
|
||||
<add key="nuget.cnblogs.com" value="https://nuget.cnblogs.com/v3/index.json" protocolVersion="3" />
|
||||
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
|
||||
<add key="Nncase.Libs" value="https://www.myget.org/F/magicallibs/api/v3/index.json" protocolVersion="3"/>
|
||||
<add key="myget-xunit" value="https://www.myget.org/F/xunit/api/v3/index.json" />
|
||||
<add key="design-packages" value="tools/design-packages" />
|
||||
</activePackageSource>
|
||||
<packageSourceMapping>
|
||||
<!-- key value for <packageSource> should match key values from <packageSources> element -->
|
||||
|
@ -26,5 +28,8 @@
|
|||
<packageSource key="myget-xunit">
|
||||
<package pattern="xunit.v3.assert" />
|
||||
</packageSource>
|
||||
<packageSource key="design-packages">
|
||||
<package pattern="Nncase.SourceGenerator" />
|
||||
</packageSource>
|
||||
</packageSourceMapping>
|
||||
</configuration>
|
||||
|
|
|
@ -11,10 +11,10 @@ using Nncase.CodeGen;
|
|||
using Nncase.CodeGen.K210;
|
||||
using Nncase.CodeGen.StackVM;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes;
|
||||
using Nncase.Passes.Rules.K210;
|
||||
using Nncase.Quantization;
|
||||
using Nncase.Runtime.K210;
|
||||
using Nncase.Transform;
|
||||
using Nncase.Transform.Rules.K210;
|
||||
|
||||
namespace Nncase.Targets;
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ using static Nncase.IR.TypePatternUtility;
|
|||
using static Nncase.PatternMatch.F.NN;
|
||||
using static Nncase.PatternMatch.Utility;
|
||||
|
||||
namespace Nncase.Transform.Rules.K210;
|
||||
namespace Nncase.Passes.Rules.K210;
|
||||
|
||||
/// <summary>
|
||||
/// Lower <see cref="IR.NN.Conv2D"/> to <see cref="IR.K210.FakeKPUConv2D"/>.
|
||||
|
|
|
@ -21,7 +21,7 @@ using static Nncase.PatternMatch.Utility;
|
|||
using static Nncase.Quantization.Utility;
|
||||
using Math = Nncase.IR.F.Math;
|
||||
|
||||
namespace Nncase.Transform.Rules.K210;
|
||||
namespace Nncase.Passes.Rules.K210;
|
||||
|
||||
/// <summary>
|
||||
/// Lower <see cref="FakeDequantize"/> to <see cref="Dequantize"/>.
|
||||
|
|
|
@ -22,7 +22,7 @@ using static Nncase.PatternMatch.F.NN;
|
|||
using static Nncase.PatternMatch.Utility;
|
||||
using Math = System.Math;
|
||||
|
||||
namespace Nncase.Transform.Rules.K210;
|
||||
namespace Nncase.Passes.Rules.K210;
|
||||
|
||||
/// <summary>
|
||||
/// Lower <see cref="IR.K210.FakeKPUConv2D"/> to <see cref="IR.K210.KPUConv2D"/>.
|
||||
|
|
|
@ -20,7 +20,7 @@ using static Nncase.PatternMatch.F.NN;
|
|||
using static Nncase.PatternMatch.Utility;
|
||||
using Math = Nncase.IR.F.Math;
|
||||
|
||||
namespace Nncase.Transform.Rules.K210;
|
||||
namespace Nncase.Passes.Rules.K210;
|
||||
|
||||
/// <summary>
|
||||
/// Lower <see cref="IR.K210.FakeKPUDownload"/> to <see cref="IR.K210.KPUDownload"/>.
|
||||
|
|
|
@ -21,7 +21,7 @@ using static Nncase.PatternMatch.F.NN;
|
|||
using static Nncase.PatternMatch.Utility;
|
||||
using Math = Nncase.IR.F.Math;
|
||||
|
||||
namespace Nncase.Transform.Rules.K210;
|
||||
namespace Nncase.Passes.Rules.K210;
|
||||
|
||||
/// <summary>
|
||||
/// Lower <see cref="IR.K210.FakeKPUUpload"/> to <see cref="IR.K210.KPUUpload"/>.
|
||||
|
|
|
@ -21,7 +21,7 @@ using static Nncase.PatternMatch.Utility;
|
|||
using static Nncase.Quantization.Utility;
|
||||
using Math = Nncase.IR.F.Math;
|
||||
|
||||
namespace Nncase.Transform.Rules.K210;
|
||||
namespace Nncase.Passes.Rules.K210;
|
||||
|
||||
/// <summary>
|
||||
/// Lower <see cref="FakeQuantize"/> to <see cref="Quantize"/>.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/3/3 下午1:12:08 +08:00. */
|
||||
|
||||
|
@ -188,7 +188,7 @@ internal partial class CodeGenVisitor
|
|||
Emitter.T.GetItem();
|
||||
break;
|
||||
case IR.Tensors.LSTM top:
|
||||
Emitter.T.LSTM(top.Direction, top.Layout, top.Activations);
|
||||
Emitter.T.LSTM(top.Direction, top.Layout, top.Activations.ToArray());
|
||||
break;
|
||||
case IR.Tensors.Prod top:
|
||||
Emitter.T.Prod();
|
||||
|
|
|
@ -4,8 +4,10 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.CodeGen.StackVM;
|
||||
|
@ -105,7 +107,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
|
||||
private StackVMEmitter Emitter => CurrentTextSnippet.Emitter;
|
||||
|
||||
public override TextSnippet VisitLeaf(Const expr)
|
||||
protected override TextSnippet VisitLeafConst(Const expr)
|
||||
{
|
||||
if (expr is TensorConst tc)
|
||||
{
|
||||
|
@ -113,11 +115,11 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
}
|
||||
else
|
||||
{
|
||||
return Visit(new IR.Tuple(((TupleConst)expr).Fields));
|
||||
return Visit(new IR.Tuple(((TupleConst)expr).Value.Select(x => Const.FromValue(x)).ToArray()));
|
||||
}
|
||||
}
|
||||
|
||||
public override TextSnippet VisitLeaf(Var expr)
|
||||
protected override TextSnippet VisitLeafVar(Var expr)
|
||||
{
|
||||
var snippet = BeginTextSnippet(expr);
|
||||
var varIndex = ((Function)_function).Parameters.IndexOf(expr);
|
||||
|
@ -130,13 +132,13 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
return snippet;
|
||||
}
|
||||
|
||||
public override TextSnippet VisitLeaf(IR.Tuple expr)
|
||||
protected override TextSnippet VisitLeafTuple(IR.Tuple expr)
|
||||
{
|
||||
var snippet = BeginTextSnippet(expr);
|
||||
foreach (var field in expr.Fields.Reverse())
|
||||
foreach (var field in expr.Fields.ToArray().Reverse())
|
||||
{
|
||||
var inputSnippet = Visit(field);
|
||||
inputSnippet.MaxUserParameters = Math.Max(inputSnippet.MaxUserParameters, expr.Fields.Count);
|
||||
inputSnippet.MaxUserParameters = Math.Max(inputSnippet.MaxUserParameters, expr.Fields.Length);
|
||||
snippet.AddInput(inputSnippet);
|
||||
}
|
||||
|
||||
|
@ -145,7 +147,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
return snippet;
|
||||
}
|
||||
|
||||
public override TextSnippet Visit(Function expr)
|
||||
protected override TextSnippet VisitFunction(Function expr)
|
||||
{
|
||||
if (ReferenceEquals(expr, _function))
|
||||
{
|
||||
|
@ -157,7 +159,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
}
|
||||
}
|
||||
|
||||
public override TextSnippet Visit(PrimFunctionWrapper expr)
|
||||
protected override TextSnippet VisitPrimFunctionWrapper(PrimFunctionWrapper expr)
|
||||
{
|
||||
if (ReferenceEquals(expr, _function))
|
||||
{
|
||||
|
@ -177,25 +179,25 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
}
|
||||
}
|
||||
|
||||
public override TextSnippet VisitLeaf(Op expr)
|
||||
protected override TextSnippet VisitLeafOp(Op expr)
|
||||
{
|
||||
return null!;
|
||||
}
|
||||
|
||||
public override TextSnippet VisitLeaf(Call expr)
|
||||
protected override TextSnippet VisitLeafCall(Call expr)
|
||||
{
|
||||
var snippet = BeginTextSnippet(expr);
|
||||
foreach (var param in expr.Parameters.Reverse())
|
||||
foreach (var param in expr.Arguments.ToArray().Reverse())
|
||||
{
|
||||
var paramSnippet = Visit(param);
|
||||
paramSnippet.MaxUserParameters = Math.Max(paramSnippet.MaxUserParameters, expr.Parameters.Count);
|
||||
paramSnippet.MaxUserParameters = Math.Max(paramSnippet.MaxUserParameters, expr.Arguments.Length);
|
||||
snippet.AddInput(paramSnippet);
|
||||
}
|
||||
|
||||
if (expr.Target is CustomOp custom_op)
|
||||
{
|
||||
_context.AddCustomCallModule(custom_op.ModuleType);
|
||||
Emitter.CusCall(custom_op.RegisteredName, custom_op.SerializeFields(), checked((ushort)expr.Parameters.Count));
|
||||
Emitter.CusCall(custom_op.RegisteredName, custom_op.SerializeFields(), checked((ushort)expr.Arguments.Length));
|
||||
}
|
||||
else if (expr.Target is Op op)
|
||||
{
|
||||
|
@ -209,7 +211,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
else if (expr.Target is Function func)
|
||||
{
|
||||
LdFunctionId(func);
|
||||
Emitter.ExtCall(checked((ushort)func.Parameters.Count), false);
|
||||
Emitter.ExtCall(checked((ushort)func.Parameters.Length), false);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -219,15 +221,9 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
return snippet;
|
||||
}
|
||||
|
||||
public override TextSnippet Visit(If expr)
|
||||
protected override TextSnippet VisitIf(If expr)
|
||||
{
|
||||
if (!ExpressionMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
result = VisitLeaf(expr);
|
||||
ExpressionMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
return VisitLeafIf(expr);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -245,7 +241,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
|
|||
/// </summary>
|
||||
/// <param name="if">If expr.</param>
|
||||
/// <returns>TextSnippet.</returns>
|
||||
public override TextSnippet VisitLeaf(If @if)
|
||||
protected override TextSnippet VisitLeafIf(If @if)
|
||||
{
|
||||
var condSnippet = Visit(@if.Condition);
|
||||
condSnippet.Emitter.LdScalar();
|
||||
|
|
|
@ -11,8 +11,8 @@ using Microsoft.Extensions.Options;
|
|||
using Nncase.CodeGen;
|
||||
using Nncase.CodeGen.StackVM;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes;
|
||||
using Nncase.Quantization;
|
||||
using Nncase.Transform;
|
||||
|
||||
namespace Nncase.Targets;
|
||||
|
||||
|
|
24
nncase.sln
24
nncase.sln
|
@ -14,6 +14,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{E11A07F5
|
|||
EndProjectSection
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Core", "src\Nncase.Core\Nncase.Core.csproj", "{2978CFB1-1530-4EC3-BB5F-130B6F606F85}"
|
||||
ProjectSection(ProjectDependencies) = postProject
|
||||
{24DF3895-9473-4C98-A494-8275456D02CC} = {24DF3895-9473-4C98-A494-8275456D02CC}
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Importer", "src\Nncase.Importer\Nncase.Importer.csproj", "{37410602-6F3A-46F4-8CFD-9B8742BB98F4}"
|
||||
EndProject
|
||||
|
@ -47,8 +50,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Targets", "src\Nncas
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Simulator", "src\Nncase.Simulator\Nncase.Simulator.csproj", "{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Transform", "src\Nncase.Transform\Nncase.Transform.csproj", "{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.EGraph", "src\Nncase.EGraph\Nncase.EGraph.csproj", "{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Compiler", "src\Nncase.Compiler\Nncase.Compiler.csproj", "{54B4C0B9-8BF0-4157-AB04-A55BA268444C}"
|
||||
|
@ -63,8 +64,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{9859
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Modules.StackVM", "modules\Nncase.Modules.StackVM\Nncase.Modules.StackVM.csproj", "{70D3FA34-B0B6-488F-812D-7E076CC0DFB3}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Modules.K210", "modules\Nncase.Modules.K210\Nncase.Modules.K210.csproj", "{7390DF41-E804-4F12-B441-6EB5119C81BF}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Quantization", "src\Nncase.Quantization\Nncase.Quantization.csproj", "{317C0D8F-75B3-4248-83E8-17AADDCF247A}"
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{EF1F8779-2B98-4E6F-A3DC-CA3FD2CADAD8}"
|
||||
|
@ -78,6 +77,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Diagnostics", "src\N
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Tests.TestFixture", "src\Nncase.Tests.TestFixture\Nncase.Tests.TestFixture.csproj", "{98A03405-CA53-4EC4-9B18-94D1C8DF9453}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Passes", "src\Nncase.Passes\Nncase.Passes.csproj", "{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
|
@ -136,10 +137,6 @@ Global
|
|||
{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
|
@ -160,10 +157,6 @@ Global
|
|||
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{7390DF41-E804-4F12-B441-6EB5119C81BF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{7390DF41-E804-4F12-B441-6EB5119C81BF}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{7390DF41-E804-4F12-B441-6EB5119C81BF}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{7390DF41-E804-4F12-B441-6EB5119C81BF}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
|
@ -176,6 +169,10 @@ Global
|
|||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
|
@ -196,17 +193,16 @@ Global
|
|||
{8E0E0672-0F96-4EF1-BDCD-D31F96A3DF73} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{56283378-06E3-4C6E-A8BF-7BD85C92D42C} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{54B4C0B9-8BF0-4157-AB04-A55BA268444C} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{24DF3895-9473-4C98-A494-8275456D02CC} = {E11A07F5-DDF9-4BCA-94E4-988913EF7F51}
|
||||
{A79769F9-D567-4236-8965-08CFB440783B} = {E11A07F5-DDF9-4BCA-94E4-988913EF7F51}
|
||||
{920E3584-7987-44ED-8794-F6929E812038} = {A79769F9-D567-4236-8965-08CFB440783B}
|
||||
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66}
|
||||
{7390DF41-E804-4F12-B441-6EB5119C81BF} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66}
|
||||
{317C0D8F-75B3-4248-83E8-17AADDCF247A} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453} = {E5A4516C-4080-4346-991D-57A7AA76ADA6}
|
||||
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
EndGlobalSection
|
||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||
SolutionGuid = {9492E141-292E-4D60-9C6E-3738AB234DB2}
|
||||
|
|
|
@ -14,8 +14,8 @@ using Nncase.CodeGen;
|
|||
using Nncase.Compiler;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes;
|
||||
using Nncase.Quantization;
|
||||
using Nncase.Transform;
|
||||
|
||||
namespace Nncase.Cli.Commands;
|
||||
|
||||
|
|
|
@ -15,12 +15,8 @@ public class LinkedFunction : ILinkedFunction
|
|||
public LinkedFunction(uint id, Callable sourceFunction, uint textBegin, uint textLength, IReadOnlyList<ILinkedSection> sections)
|
||||
{
|
||||
Id = id;
|
||||
if (sourceFunction.CheckedType is null)
|
||||
{
|
||||
CompilerServices.InferenceType(sourceFunction);
|
||||
}
|
||||
|
||||
ParameterTypes = ((CallableType)sourceFunction.CheckedType!).Parameters.ToArray();
|
||||
CompilerServices.InferenceType(sourceFunction);
|
||||
ParameterTypes = ((CallableType)sourceFunction.CheckedType).Parameters.ToArray();
|
||||
ReturnType = ((CallableType)sourceFunction.CheckedType).ReturnType;
|
||||
TextBegin = textBegin;
|
||||
TextLength = textLength;
|
||||
|
|
|
@ -10,10 +10,10 @@ using Nncase.Diagnostics;
|
|||
using Nncase.Evaluator;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes;
|
||||
using Nncase.Passes.Rules.Lower;
|
||||
using Nncase.Passes.Transforms;
|
||||
using Nncase.Quantization;
|
||||
using Nncase.Transform;
|
||||
using Nncase.Transform.Passes;
|
||||
using Nncase.Transform.Rules.Lower;
|
||||
using Nncase.Utilities;
|
||||
|
||||
namespace Nncase.Compiler;
|
||||
|
@ -61,43 +61,43 @@ internal class Compiler : ICompiler
|
|||
var quantMode = _compileSession.CompileOptions.QuantizeOptions.ModelQuantMode;
|
||||
if (quantMode == ModelQuantMode.UsePTQ)
|
||||
{
|
||||
passManager.AddWithName<EGraphPass>("NeutralOptimizeTranspose").Configure(p =>
|
||||
passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
|
||||
{
|
||||
p.Add<Transform.Rules.Neutral.FoldConstCall>();
|
||||
p.Add<Transform.Rules.Neutral.FoldNopTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.FoldTwoTransposes>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeUnary>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposePad>();
|
||||
p.Add<Transform.Rules.Neutral.CombinePadTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.CombineBinaryTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.CombineConstBinaryTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeConstBinary>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeReduce>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeActivations>();
|
||||
p.Add<Transform.Rules.Neutral.CombineActivationsTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeConcat>();
|
||||
p.Add<Transform.Rules.Neutral.FoldNopPad>();
|
||||
p.Add<Transform.Rules.Neutral.FoldConv2DPads>();
|
||||
p.Add<Transform.Rules.Neutral.FoldReduceWindow2DPads>();
|
||||
p.Add<Transform.Rules.Neutral.SqueezeToReshape>();
|
||||
p.Add<Transform.Rules.Neutral.UnSqueezeToReshape>();
|
||||
p.Add<Transform.Rules.Neutral.TransposeToReshape>();
|
||||
p.Add<Transform.Rules.Neutral.FoldNopReshape>();
|
||||
p.Add<Transform.Rules.Neutral.FoldTwoReshapes>();
|
||||
p.Add<Transform.Rules.Neutral.FoldLayerNormPattern1>();
|
||||
p.Add<Transform.Rules.Neutral.FoldLayerNormPattern2>();
|
||||
p.Add<Transform.Rules.Neutral.FoldLayerNormPattern3>();
|
||||
p.Add<Passes.Rules.Neutral.FoldConstCall>();
|
||||
p.Add<Passes.Rules.Neutral.FoldNopTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
|
||||
p.Add<Passes.Rules.Neutral.CombineTransposeUnary>();
|
||||
p.Add<Passes.Rules.Neutral.CombineTransposePad>();
|
||||
p.Add<Passes.Rules.Neutral.CombinePadTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.CombineBinaryTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.CombineConstBinaryTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.CombineTransposeConstBinary>();
|
||||
p.Add<Passes.Rules.Neutral.CombineTransposeReduce>();
|
||||
p.Add<Passes.Rules.Neutral.CombineTransposeActivations>();
|
||||
p.Add<Passes.Rules.Neutral.CombineActivationsTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.CombineTransposeConcat>();
|
||||
p.Add<Passes.Rules.Neutral.FoldNopPad>();
|
||||
p.Add<Passes.Rules.Neutral.FoldConv2DPads>();
|
||||
p.Add<Passes.Rules.Neutral.FoldReduceWindow2DPads>();
|
||||
p.Add<Passes.Rules.Neutral.SqueezeToReshape>();
|
||||
p.Add<Passes.Rules.Neutral.UnSqueezeToReshape>();
|
||||
p.Add<Passes.Rules.Neutral.TransposeToReshape>();
|
||||
p.Add<Passes.Rules.Neutral.FoldNopReshape>();
|
||||
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
|
||||
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern1>();
|
||||
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern2>();
|
||||
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern3>();
|
||||
});
|
||||
|
||||
// passManager.AddWithName<EGraphPass>("NeutralOptimizeClamp").Configure(p =>
|
||||
// {
|
||||
// p.Add<Transform.Rules.Neutral.FoldConstCall>();
|
||||
// p.Add<Transform.Rules.Neutral.FoldConv2DAddMul>();
|
||||
// p.Add<Transform.Rules.Neutral.ReluToClamp>();
|
||||
// p.Add<Transform.Rules.Neutral.Relu6ToClamp>();
|
||||
// p.Add<Transform.Rules.Neutral.CombineClampAdd>();
|
||||
// p.Add<Transform.Rules.Neutral.CombineClampMul>();
|
||||
// p.Add<Transform.Rules.Neutral.FoldNopClamp>();
|
||||
// p.Add<Passes.Rules.Neutral.FoldConstCall>();
|
||||
// p.Add<Passes.Rules.Neutral.FoldConv2DAddMul>();
|
||||
// p.Add<Passes.Rules.Neutral.ReluToClamp>();
|
||||
// p.Add<Passes.Rules.Neutral.Relu6ToClamp>();
|
||||
// p.Add<Passes.Rules.Neutral.CombineClampAdd>();
|
||||
// p.Add<Passes.Rules.Neutral.CombineClampMul>();
|
||||
// p.Add<Passes.Rules.Neutral.FoldNopClamp>();
|
||||
// });
|
||||
}
|
||||
|
||||
|
@ -107,7 +107,7 @@ internal class Compiler : ICompiler
|
|||
{
|
||||
passManager.AddWithName<DataflowPass>("AddRangeOfMarker").Configure(p =>
|
||||
{
|
||||
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarker>();
|
||||
p.Add<Passes.Rules.Neutral.AddRangeOfAndMarker>();
|
||||
});
|
||||
passManager.AddWithName<EGraphPassWithQuantize>("AssignRanges");
|
||||
}
|
||||
|
|
|
@ -52,8 +52,8 @@ public static class CompilerHostBuilderExtensions
|
|||
.AddGraph()
|
||||
.AddEGraph()
|
||||
.AddCodeGen()
|
||||
.AddStackVM()
|
||||
.AddK210();
|
||||
.AddPasses()
|
||||
.AddStackVM();
|
||||
}
|
||||
|
||||
private static void ConfigureServices(HostBuilderContext context, IServiceCollection services)
|
||||
|
|
|
@ -9,8 +9,6 @@ using System.Reflection;
|
|||
using System.Reflection.Metadata;
|
||||
using System.Reflection.PortableExecutable;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.Loader;
|
||||
using LanguageExt;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Nncase.Compiler;
|
||||
using Nncase.IR;
|
||||
|
|
|
@ -171,8 +171,8 @@ public static unsafe class CApi
|
|||
var samples = (dataset.Length == 0 ?
|
||||
Array.Empty<Dictionary<Var, IValue>>() :
|
||||
dataset.Chunk(dataset.Length / (int)samplesCount).Select(inputs => inputs.Zip(fnParams).ToDictionary(
|
||||
item => item.Item2,
|
||||
item => item.Item1.ToValue()))).ToAsyncEnumerable();
|
||||
item => item.Second,
|
||||
item => item.First.ToValue()))).ToAsyncEnumerable();
|
||||
return GCHandle.ToIntPtr(GCHandle.Alloc(new CCalibrationDatasetProvider(samples, (int)samplesCount)));
|
||||
}
|
||||
|
||||
|
@ -288,8 +288,8 @@ public static unsafe class CApi
|
|||
var fnParams = Get<Var[]>(fnParamsHandle);
|
||||
var inputs = Get<RTValue[]>(inputsHandle);
|
||||
var result = CompilerServices.Evaluate(expr, fnParams.Zip(inputs).ToDictionary(
|
||||
x => x.Item1,
|
||||
x => x.Item2.ToValue()));
|
||||
x => x.First,
|
||||
x => x.Second.ToValue()));
|
||||
var rtValue = RTValue.FromValue(result);
|
||||
return GCHandle.ToIntPtr(GCHandle.Alloc(rtValue));
|
||||
}
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\modules\Nncase.Modules.K210\Nncase.Modules.K210.csproj" />
|
||||
<ProjectReference Include="..\Nncase.CodeGen\Nncase.CodeGen.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Core\Nncase.Core.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Diagnostics\Nncase.Diagnostics.csproj" />
|
||||
|
@ -25,7 +24,7 @@
|
|||
<ProjectReference Include="..\Nncase.Importer\Nncase.Importer.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Simulator\Nncase.Simulator.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Quantization\Nncase.Quantization.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Transform\Nncase.Transform.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Passes\Nncase.Passes.csproj" />
|
||||
<ProjectReference Include="..\..\modules\Nncase.Modules.StackVM\Nncase.Modules.StackVM.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.Collections;
|
||||
|
||||
public static class CollectionsExtensions
|
||||
{
|
||||
public static void AddRange<T>(this List<T> list, ReadOnlySpan<T> items)
|
||||
{
|
||||
list.Capacity += items.Length;
|
||||
foreach (var item in items)
|
||||
{
|
||||
list.Add(item);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,197 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.Collections;
|
||||
|
||||
/// <summary>
|
||||
/// Represents an ordered sequence of weak references.
|
||||
/// </summary>
|
||||
public sealed class WeakList<T> : IEnumerable<T>
|
||||
where T : class
|
||||
{
|
||||
private const int MinimalNonEmptySize = 4;
|
||||
|
||||
private WeakReference<T>[] _items;
|
||||
private int _size;
|
||||
|
||||
public WeakList()
|
||||
{
|
||||
_items = Array.Empty<WeakReference<T>>();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the number of weak references in this list.
|
||||
/// Note that some of them might not point to live objects anymore.
|
||||
/// </summary>
|
||||
public int WeakCount
|
||||
{
|
||||
get { return _size; }
|
||||
}
|
||||
|
||||
public WeakReference<T> GetWeakReference(int index)
|
||||
{
|
||||
if (index < 0 || index >= _size)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(index));
|
||||
}
|
||||
|
||||
return _items[index];
|
||||
}
|
||||
|
||||
public void Add(T item)
|
||||
{
|
||||
if (_size == _items.Length)
|
||||
{
|
||||
Resize();
|
||||
}
|
||||
|
||||
_items[_size++] = new WeakReference<T>(item);
|
||||
}
|
||||
|
||||
public IEnumerator<T> GetEnumerator()
|
||||
{
|
||||
int count = _size;
|
||||
int alive = _size;
|
||||
int firstDead = -1;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
T? item;
|
||||
if (_items[i].TryGetTarget(out item))
|
||||
{
|
||||
yield return item;
|
||||
}
|
||||
else
|
||||
{
|
||||
// object has been collected
|
||||
if (firstDead < 0)
|
||||
{
|
||||
firstDead = i;
|
||||
}
|
||||
|
||||
alive--;
|
||||
}
|
||||
}
|
||||
|
||||
if (alive == 0)
|
||||
{
|
||||
_items = Array.Empty<WeakReference<T>>();
|
||||
_size = 0;
|
||||
}
|
||||
else if (alive < _items.Length / 4)
|
||||
{
|
||||
// If we have just a few items left we shrink the array.
|
||||
Shrink(firstDead, alive);
|
||||
}
|
||||
}
|
||||
|
||||
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
|
||||
{
|
||||
return GetEnumerator();
|
||||
}
|
||||
|
||||
private static int GetExpandedSize(int baseSize)
|
||||
{
|
||||
return Math.Max((baseSize * 2) + 1, MinimalNonEmptySize);
|
||||
}
|
||||
|
||||
private void Resize()
|
||||
{
|
||||
int alive = _items.Length;
|
||||
int firstDead = -1;
|
||||
for (int i = 0; i < _items.Length; i++)
|
||||
{
|
||||
if (!_items[i].TryGetTarget(out _))
|
||||
{
|
||||
if (firstDead == -1)
|
||||
{
|
||||
firstDead = i;
|
||||
}
|
||||
|
||||
alive--;
|
||||
}
|
||||
}
|
||||
|
||||
if (alive < _items.Length / 4)
|
||||
{
|
||||
// If we have just a few items left we shrink the array.
|
||||
// We avoid expanding the array until the number of new items added exceeds half of its capacity.
|
||||
Shrink(firstDead, alive);
|
||||
}
|
||||
else if (alive >= 3 * _items.Length / 4)
|
||||
{
|
||||
// If we have a lot of items alive we expand the array since just compacting them
|
||||
// wouldn't free up much space (we would end up calling Resize again after adding a few more items).
|
||||
var newItems = new WeakReference<T>[GetExpandedSize(_items.Length)];
|
||||
|
||||
if (firstDead >= 0)
|
||||
{
|
||||
Compact(firstDead, newItems);
|
||||
}
|
||||
else
|
||||
{
|
||||
Array.Copy(_items, 0, newItems, 0, _items.Length);
|
||||
}
|
||||
|
||||
_items = newItems;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Compact in-place to make space for new items at the end.
|
||||
// We will free up to length/4 slots in the array.
|
||||
Compact(firstDead, _items);
|
||||
}
|
||||
|
||||
Debug.Assert(_items.Length > 0 && _size < 3 * _items.Length / 4, "length: " + _items.Length + " size: " + _size);
|
||||
}
|
||||
|
||||
private void Shrink(int firstDead, int alive)
|
||||
{
|
||||
int newSize = GetExpandedSize(alive);
|
||||
var newItems = (newSize == _items.Length) ? _items : new WeakReference<T>[newSize];
|
||||
Compact(firstDead, newItems);
|
||||
_items = newItems;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Copies all live references from <see cref="_items"/> to <paramref name="result"/>.
|
||||
/// Assumes that all references prior <paramref name="firstDead"/> are alive.
|
||||
/// </summary>
|
||||
private void Compact(int firstDead, WeakReference<T>[] result)
|
||||
{
|
||||
if (!ReferenceEquals(_items, result))
|
||||
{
|
||||
Array.Copy(_items, 0, result, 0, firstDead);
|
||||
}
|
||||
|
||||
int oldSize = _size;
|
||||
int j = firstDead;
|
||||
for (int i = firstDead + 1; i < oldSize; i++)
|
||||
{
|
||||
var item = _items[i];
|
||||
|
||||
if (item.TryGetTarget(out _))
|
||||
{
|
||||
result[j++] = item;
|
||||
}
|
||||
}
|
||||
|
||||
_size = j;
|
||||
|
||||
// free WeakReferences
|
||||
if (ReferenceEquals(_items, result))
|
||||
{
|
||||
while (j < oldSize)
|
||||
{
|
||||
_items[j++] = null!;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,7 +11,7 @@ using Microsoft.Extensions.DependencyInjection;
|
|||
using Nncase.CodeGen;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.Transform;
|
||||
using Nncase.Passes;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
|
@ -79,9 +79,7 @@ public sealed class CompileSession : IServiceProvider, IDisposable
|
|||
/// <param name="name">Name.</param>
|
||||
/// <returns>Created pass manager.</returns>
|
||||
public IPassManager CreatePassManager(string name)
|
||||
{
|
||||
return new PassManager(name, this);
|
||||
}
|
||||
=> _serviceProvider.GetRequiredService<IPassManagerFactory>().Create(name, this);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void Dispose()
|
||||
|
|
|
@ -9,7 +9,7 @@ using System.Threading.Tasks;
|
|||
|
||||
namespace Nncase;
|
||||
|
||||
internal struct CompileSessionScope : IDisposable
|
||||
public struct CompileSessionScope : IDisposable
|
||||
{
|
||||
private static readonly AsyncLocal<CompileSession?> _compileSession = new AsyncLocal<CompileSession?>();
|
||||
|
||||
|
|
|
@ -14,9 +14,9 @@ using Microsoft.Extensions.DependencyInjection;
|
|||
using Nncase.CostModel;
|
||||
using Nncase.Evaluator;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes;
|
||||
using Nncase.PatternMatch;
|
||||
using Nncase.Targets;
|
||||
using Nncase.Transform;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
|
@ -108,9 +108,8 @@ public interface ICompilerServicesProvider
|
|||
/// Evaluate cost of the expression tree.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="varsValues">Optional vars' values.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null);
|
||||
Cost EvaluateCost(Expr expr);
|
||||
|
||||
/// <summary>
|
||||
/// Evaluate cost of operator.
|
||||
|
@ -118,7 +117,7 @@ public interface ICompilerServicesProvider
|
|||
/// <param name="op">Target operator.</param>
|
||||
/// <param name="context">Evaluate context.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
Cost? EvaluateOpCost(Op op, ICostEvaluateContext context);
|
||||
Cost EvaluateOpCost(Op op, ICostEvaluateContext context);
|
||||
|
||||
/// <summary>
|
||||
/// Match expression.
|
||||
|
@ -267,11 +266,10 @@ public static class CompilerServices
|
|||
/// Evaluate cost of the expression tree.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="varsValues">Optional vars' values.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
public static Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null)
|
||||
public static Cost EvaluateCost(Expr expr)
|
||||
{
|
||||
return Provider.EvaluateCost(expr, varsValues);
|
||||
return Provider.EvaluateCost(expr);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -280,7 +278,7 @@ public static class CompilerServices
|
|||
/// <param name="op">Target operator.</param>
|
||||
/// <param name="context">Evaluate context.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
public static Cost? EvaluateOpCost(Op op, ICostEvaluateContext context)
|
||||
public static Cost EvaluateOpCost(Op op, ICostEvaluateContext context)
|
||||
{
|
||||
return Provider.EvaluateOpCost(op, context);
|
||||
}
|
||||
|
@ -560,13 +558,13 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
|
|||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null)
|
||||
public Cost EvaluateCost(Expr expr)
|
||||
{
|
||||
return _costEvaluateProvider.EvaluateCost(expr, varsValues);
|
||||
return _costEvaluateProvider.EvaluateCost(expr);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Cost? EvaluateOpCost(Op op, ICostEvaluateContext context)
|
||||
public Cost EvaluateOpCost(Op op, ICostEvaluateContext context)
|
||||
{
|
||||
return _costEvaluateProvider.EvaluateOpCost(op, context);
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
|
@ -16,7 +15,6 @@ internal class CoreModule : IApplicationPart
|
|||
{
|
||||
registrator.RegisterManyInterface<CompilerServicesProvider>(reuse: Reuse.Singleton);
|
||||
registrator.Register<IDataTypeServiceProvider, DataTypeServiceProvider>(reuse: Reuse.Singleton);
|
||||
registrator.Register<IIRPrinterProvider, IRPrinterProvider>(reuse: Reuse.Singleton);
|
||||
|
||||
// Prim types
|
||||
registrator.Register<PrimType, BooleanType>(reuse: Reuse.Singleton);
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
using System;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.IR;
|
||||
using static NetFabric.Hyperlinq.ArrayExtensions;
|
||||
|
||||
namespace Nncase.CostModel;
|
||||
|
||||
|
@ -162,9 +163,26 @@ public static class CostExtensions
|
|||
/// </summary>
|
||||
/// <param name="costs">Source.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public static Cost Sum(this IEnumerable<Cost?> costs)
|
||||
public static Cost Sum(this IEnumerable<Cost> costs)
|
||||
{
|
||||
return costs.Aggregate((Cost?)Cost.Zero, (x, y) => y == null ? null : x! + y)!;
|
||||
return costs.Aggregate(Cost.Zero, (x, y) => x + y)!;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sum all costs.
|
||||
/// </summary>
|
||||
/// <param name="costs">Source.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public static Cost Sum<TSource, TSelector>(this in SpanSelectEnumerable<TSource, Cost, TSelector> costs)
|
||||
where TSelector : struct, IFunction<TSource, Cost>
|
||||
{
|
||||
var sum = Cost.Zero;
|
||||
foreach (var cost in costs)
|
||||
{
|
||||
sum += cost;
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
internal enum EitherState
|
||||
{
|
||||
None,
|
||||
T1,
|
||||
T2,
|
||||
}
|
||||
|
||||
public struct Either<T1, T2>
|
||||
{
|
||||
private readonly EitherState _state;
|
||||
|
||||
[AllowNull]
|
||||
private readonly T1 _t1;
|
||||
|
||||
[AllowNull]
|
||||
private readonly T2 _t2;
|
||||
|
||||
private Either(T1 value)
|
||||
{
|
||||
_state = EitherState.T1;
|
||||
_t1 = value;
|
||||
_t2 = default;
|
||||
}
|
||||
|
||||
private Either(T2 value)
|
||||
{
|
||||
_state = EitherState.T2;
|
||||
_t2 = value;
|
||||
_t1 = default;
|
||||
}
|
||||
|
||||
public T1 Value1 => _state == EitherState.T1 ? _t1 : throw new InvalidOperationException($"Value is not a {typeof(T1)}.");
|
||||
|
||||
public T2 Value2 => _state == EitherState.T2 ? _t2 : throw new InvalidOperationException($"Value is not a {typeof(T2)}.");
|
||||
|
||||
public static implicit operator Either<T1, T2>(T1 value) => new(value);
|
||||
|
||||
public static implicit operator Either<T1, T2>(T2 value) => new(value);
|
||||
|
||||
public static explicit operator T1(Either<T1, T2> value) => value.Value1;
|
||||
|
||||
public static explicit operator T2(Either<T1, T2> value) => value.Value2;
|
||||
|
||||
public static Either<T1, T2> From(T1 value) => new(value);
|
||||
|
||||
public static Either<T1, T2> From(T2 value) => new(value);
|
||||
|
||||
public bool Is<T>()
|
||||
{
|
||||
if (typeof(T) == typeof(T1))
|
||||
{
|
||||
return _state == EitherState.T1;
|
||||
}
|
||||
|
||||
if (typeof(T) == typeof(T2))
|
||||
{
|
||||
return _state == EitherState.T2;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public T Match<T>(Func<T1, T> t1Selector, Func<T2, T> t2Selector)
|
||||
=> Is<T1>() ? t1Selector(_t1) : t2Selector(_t2);
|
||||
|
||||
public object Match(Func<T1, object> t1Selector, Func<T2, object> t2Selector)
|
||||
=> Is<T1>() ? t1Selector(_t1) : t2Selector(_t2);
|
||||
}
|
|
@ -255,7 +255,7 @@ namespace Nncase.TIR
|
|||
IsEntryFunc = 549755813888,
|
||||
|
||||
/// <summary>
|
||||
/// Parameters used in the module that should be linked by the codegen.
|
||||
/// Arguments used in the module that should be linked by the codegen.
|
||||
/// </summary>
|
||||
LinkedParams = 1099511627776,
|
||||
|
||||
|
|
|
@ -20,9 +20,8 @@ public interface ICostEvaluateProvider
|
|||
/// Evaluate cost of the expression tree.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="varsValues">Optional vars' values.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null);
|
||||
Cost EvaluateCost(Expr expr);
|
||||
|
||||
/// <summary>
|
||||
/// Evaluate cost of operator.
|
||||
|
@ -30,5 +29,5 @@ public interface ICostEvaluateProvider
|
|||
/// <param name="op">Target operator.</param>
|
||||
/// <param name="context">Evaluate context.</param>
|
||||
/// <returns>Evaluate result.</returns>
|
||||
Cost? EvaluateOpCost(Op op, ICostEvaluateContext context);
|
||||
Cost EvaluateOpCost(Op op, ICostEvaluateContext context);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// Base function.
|
||||
/// </summary>
|
||||
public abstract class BaseFunction : Callable
|
||||
{
|
||||
public BaseFunction(string name, string moduleKind, Expr[] operands)
|
||||
: base(name, moduleKind, operands)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets sched result.
|
||||
/// </summary>
|
||||
public Schedule.SchedFunctionResult SchedResult { get; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Gets parameter types.
|
||||
/// </summary>
|
||||
public abstract IEnumerable<IRType?> ParameterTypes { get; }
|
||||
}
|
|
@ -6,11 +6,12 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.IR.Buffer;
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// get the buffer basement.
|
||||
/// </summary>
|
||||
public record Allocate(TensorType ElemType) : Op
|
||||
public sealed partial class Allocate : Op
|
||||
{
|
||||
public TensorType ElemType { get; }
|
||||
}
|
|
@ -4,12 +4,12 @@ using Nncase.IR.Tensors;
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Buffer;
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// get the buffer basement.
|
||||
/// </summary>
|
||||
public record BaseMentOf() : Op
|
||||
public sealed partial class BaseMentOf : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Get the input parameter.
|
|
@ -4,18 +4,20 @@ using Nncase.IR.Tensors;
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Buffer;
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// get the buffer from the input.
|
||||
/// </summary>
|
||||
public record BufferOf(Schedule.MemoryLocation MemoryLocation) : Op
|
||||
public sealed partial class BufferOf : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Get the input parameter.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(BufferOf), 0, "input", IsTensor());
|
||||
|
||||
public Schedule.MemoryLocation MemoryLocation { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"Schedule.MemoryLocation.{MemoryLocation}";
|
||||
}
|
|
@ -5,13 +5,13 @@ using Nncase.IR.Tensors;
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Buffer;
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// DDrOf expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public record DDrOf() : Op
|
||||
public sealed partial class DDrOf : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Get the input parameter.
|
|
@ -6,7 +6,7 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR.Buffer;
|
||||
using Nncase.IR.Buffers;
|
||||
|
||||
namespace Nncase.IR.F;
|
||||
|
|
@ -10,13 +10,13 @@ using System.Threading.Tasks;
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Buffer;
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// Shape expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record StrideOf() : Op
|
||||
public sealed partial class StrideOf : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
|
@ -4,19 +4,23 @@
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Buffer;
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Uninitialized(DataType DType, Schedule.MemoryLocation MemoryLocation) : Op
|
||||
public sealed partial class Uninitialized : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// the shape.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Shape = new(typeof(Uninitialized), 0, "shape", IsIntegral() & IsTensor() & HasRank(1));
|
||||
|
||||
public DataType DType { get; }
|
||||
|
||||
public Schedule.MemoryLocation MemoryLocation { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool CanFoldConstCall => false;
|
||||
|
|
@ -7,6 +7,7 @@ using System.Collections.Immutable;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.Utilities;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
|
@ -24,18 +25,32 @@ public interface IParameterList<T>
|
|||
/// <summary>
|
||||
/// Call expression.
|
||||
/// </summary>
|
||||
public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParameterList<Expr>
|
||||
public sealed class Call : Expr, IParameterList<Expr>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Call"/> class.
|
||||
/// </summary>
|
||||
/// <param name="target">Call target.</param>
|
||||
/// <param name="parameters">Parameters.</param>
|
||||
public Call(Expr target, params Expr[] parameters)
|
||||
: this(target, new IRArray<Expr>(parameters.ToImmutableArray()))
|
||||
/// <param name="arguments">Arguments.</param>
|
||||
public Call(Expr target, ReadOnlySpan<Expr> arguments)
|
||||
: base(ArrayUtility.Concat(target, arguments))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Call"/> class.
|
||||
/// </summary>
|
||||
/// <param name="target">Call target.</param>
|
||||
/// <param name="arguments">Arguments.</param>
|
||||
public Call(Expr target, params Expr[] arguments)
|
||||
: this(target, (ReadOnlySpan<Expr>)arguments)
|
||||
{
|
||||
}
|
||||
|
||||
public Expr Target => Operands[0];
|
||||
|
||||
public ReadOnlySpan<Expr> Arguments => Operands[1..];
|
||||
|
||||
// /// <summary>
|
||||
// /// used by fake ir, represents that whether this op permit int 16 quant.
|
||||
// /// </summary>
|
||||
|
@ -62,7 +77,7 @@ public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParame
|
|||
var type = Target.GetType();
|
||||
if (type == parameter.OwnerType)
|
||||
{
|
||||
return Parameters[parameter.Index];
|
||||
return Arguments[parameter.Index];
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -74,9 +89,16 @@ public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParame
|
|||
public void ParametersForeach(Action<Expr, ParameterInfo> f)
|
||||
{
|
||||
var parameterInfos = ((Op)Target).Parameters.ToArray();
|
||||
for (int i = 0; i < Parameters.Count; i++)
|
||||
for (int i = 0; i < Arguments.Length; i++)
|
||||
{
|
||||
f(Parameters[i], parameterInfos[i]);
|
||||
f(Arguments[i], parameterInfos[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitCall(this, context);
|
||||
|
||||
public Call With(Expr? target = null, Expr[]? arguments = null)
|
||||
=> new Call(target ?? Target, arguments ?? Arguments);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// the Callable Expr.
|
||||
/// </summary>
|
||||
public abstract class Callable : Expr
|
||||
{
|
||||
/// <summary>
|
||||
/// StackVM module kind.
|
||||
/// </summary>
|
||||
public static readonly string StackVMModuleKind = "stackvm";
|
||||
|
||||
public Callable(string name, string moduleKind, Expr[] operands)
|
||||
: base(operands)
|
||||
{
|
||||
Name = name;
|
||||
ModuleKind = moduleKind;
|
||||
}
|
||||
|
||||
public string Name { get; }
|
||||
|
||||
public string ModuleKind { get; }
|
||||
}
|
|
@ -12,8 +12,20 @@ namespace Nncase.IR;
|
|||
/// <summary>
|
||||
/// Constant expression.
|
||||
/// </summary>
|
||||
public abstract record Const(IRType ValueType) : Expr
|
||||
public abstract class Const : Expr
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Const"/> class.
|
||||
/// </summary>
|
||||
/// <param name="valueType">Type of value.</param>
|
||||
public Const(IRType valueType)
|
||||
: base(Array.Empty<Expr>())
|
||||
{
|
||||
ValueType = valueType;
|
||||
}
|
||||
|
||||
public IRType ValueType { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Create constant from a <see cref="byte"/>.
|
||||
/// </summary>
|
||||
|
@ -125,7 +137,7 @@ public abstract record Const(IRType ValueType) : Expr
|
|||
else
|
||||
{
|
||||
var tpv = (TupleValue)value;
|
||||
return new TupleConst(tpv.Select(x => FromValue(x)).ToArray());
|
||||
return new TupleConst(tpv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,149 +7,148 @@ using System.Linq;
|
|||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// Conversion of expression.
|
||||
/// </summary>
|
||||
public abstract partial class Expr
|
||||
{
|
||||
/// <summary>
|
||||
/// Conversion of expression.
|
||||
/// Create <see cref="Expr"/> from a <see cref="byte"/>.
|
||||
/// </summary>
|
||||
public abstract partial record Expr
|
||||
{
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="byte"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(byte value) => (Const)value;
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(byte value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="ushort"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(ushort value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="ushort"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(ushort value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="uint"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(uint value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="uint"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(uint value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="ulong"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(ulong value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="ulong"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(ulong value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="sbyte"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(sbyte value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="sbyte"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(sbyte value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="short"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(short value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="short"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(short value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(int value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(int value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="long"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(long value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="long"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(long value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="Half"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(Half value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="Half"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(Half value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="float"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(float value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="float"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(float value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="double"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(double value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="double"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(double value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="BFloat16"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(BFloat16 value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="BFloat16"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(BFloat16 value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="bool"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(bool value) => (Const)value;
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="bool"/>.
|
||||
/// </summary>
|
||||
/// <param name="value">Value.</param>
|
||||
public static implicit operator Expr(bool value) => (Const)value;
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="Shape"/>.
|
||||
/// </summary>
|
||||
/// <param name="shape">Shape.</param>
|
||||
public static implicit operator Expr(Shape shape) => Const.FromShape(shape);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="Shape"/>.
|
||||
/// </summary>
|
||||
/// <param name="shape">Shape.</param>
|
||||
public static implicit operator Expr(Shape shape) => Const.FromShape(shape);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="array">Array.</param>
|
||||
public static implicit operator Expr(int[] array) => Tensor.From<int>(array);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="array">Array.</param>
|
||||
public static implicit operator Expr(int[] array) => Tensor.From<int>(array);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="array">Array.</param>
|
||||
public static implicit operator Expr(long[] array) => Tensor.From<long>(array);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="array">Array.</param>
|
||||
public static implicit operator Expr(long[] array) => Tensor.From<long>(array);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from an array of<see cref="float"/>.
|
||||
/// </summary>
|
||||
/// <param name="array">Array.</param>
|
||||
public static implicit operator Expr(float[] array) => Tensor.From<float>(array);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from an array of<see cref="float"/>.
|
||||
/// </summary>
|
||||
/// <param name="array">Array.</param>
|
||||
public static implicit operator Expr(float[] array) => Tensor.From<float>(array);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="array">Array.</param>
|
||||
public static implicit operator Expr(Array array) => Tensor.FromArray(array);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="array">Array.</param>
|
||||
public static implicit operator Expr(Array array) => Tensor.FromArray(array);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a memory of<see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="memory">Span.</param>
|
||||
public static implicit operator Expr(Memory<int> memory) => Tensor.From<int>(memory);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a memory of<see cref="int"/>.
|
||||
/// </summary>
|
||||
/// <param name="memory">Span.</param>
|
||||
public static implicit operator Expr(Memory<int> memory) => Tensor.From<int>(memory);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a memory of<see cref="long"/>.
|
||||
/// </summary>
|
||||
/// <param name="memory">Span.</param>
|
||||
public static implicit operator Expr(Memory<long> memory) => Tensor.From<long>(memory);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a memory of<see cref="long"/>.
|
||||
/// </summary>
|
||||
/// <param name="memory">Span.</param>
|
||||
public static implicit operator Expr(Memory<long> memory) => Tensor.From<long>(memory);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a memory of<see cref="float"/>.
|
||||
/// </summary>
|
||||
/// <param name="memory">Span.</param>
|
||||
public static implicit operator Expr(Memory<float> memory) => Tensor.From<float>(memory);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a memory of<see cref="float"/>.
|
||||
/// </summary>
|
||||
/// <param name="memory">Span.</param>
|
||||
public static implicit operator Expr(Memory<float> memory) => Tensor.From<float>(memory);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="Tensor"/>.
|
||||
/// </summary>
|
||||
/// <param name="tensor">Tensor.</param>
|
||||
public static implicit operator Expr(Tensor tensor) => Const.FromTensor(tensor);
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="Tensor"/>.
|
||||
/// </summary>
|
||||
/// <param name="tensor">Tensor.</param>
|
||||
public static implicit operator Expr(Tensor tensor) => Const.FromTensor(tensor);
|
||||
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="QuantParam"/>.
|
||||
/// </summary>
|
||||
/// <param name="quantParam">QuantParam.</param>
|
||||
public static implicit operator Expr(QuantParam quantParam) => Tensor.FromScalar(quantParam);
|
||||
}
|
||||
/// <summary>
|
||||
/// Create <see cref="Expr"/> from a <see cref="QuantParam"/>.
|
||||
/// </summary>
|
||||
/// <param name="quantParam">QuantParam.</param>
|
||||
public static implicit operator Expr(QuantParam quantParam) => Tensor.FromScalar(quantParam);
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ namespace Nncase.IR;
|
|||
/// <summary>
|
||||
/// Math operators for <see cref="Expr"/>.
|
||||
/// </summary>
|
||||
public partial record Expr
|
||||
public partial class Expr
|
||||
{
|
||||
/// <summary>
|
||||
/// get the item from the expr.
|
|
@ -4,27 +4,74 @@
|
|||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Toolkit.HighPerformance.Helpers;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// Expression.
|
||||
/// </summary>
|
||||
public abstract partial record Expr
|
||||
public abstract partial class Expr : IDisposable
|
||||
{
|
||||
private readonly Expr[] _operands;
|
||||
private readonly HashSet<Expr> _users = new(ReferenceEqualityComparer.Instance);
|
||||
|
||||
private IRType? _checkedType;
|
||||
private int? _hashCodeCache;
|
||||
private bool _disposedValue;
|
||||
|
||||
internal Expr(IEnumerable<Expr> operands)
|
||||
{
|
||||
_operands = operands.ToArray();
|
||||
foreach (var operand in _operands)
|
||||
{
|
||||
operand.AddUser(this);
|
||||
}
|
||||
}
|
||||
|
||||
internal Expr(Expr[] operands)
|
||||
{
|
||||
_operands = operands;
|
||||
foreach (var operand in _operands)
|
||||
{
|
||||
operand.AddUser(this);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets checked type.
|
||||
/// </summary>
|
||||
public IRType? CheckedType { get; set; }
|
||||
public IRType CheckedType
|
||||
{
|
||||
get
|
||||
{
|
||||
if (_checkedType == null)
|
||||
{
|
||||
CompilerServices.InferenceType(this);
|
||||
}
|
||||
|
||||
Trace.Assert(_checkedType is not null);
|
||||
return _checkedType!;
|
||||
}
|
||||
|
||||
set
|
||||
{
|
||||
if (_checkedType != value)
|
||||
{
|
||||
_checkedType = value;
|
||||
InvalidateUsersTypeInference();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets checked shape.
|
||||
/// </summary>
|
||||
public Shape CheckedShape => (CheckedType ?? ((Const)this).ValueType) switch
|
||||
public Shape CheckedShape => CheckedType switch
|
||||
{
|
||||
TensorType type => type.Shape,
|
||||
_ => throw new InvalidOperationException("Only The Expr Have CheckedType Can Get It's Shape"),
|
||||
|
@ -35,30 +82,220 @@ public abstract partial record Expr
|
|||
/// </summary>
|
||||
public DataType CheckedDataType => CheckedType switch
|
||||
{
|
||||
// todo:more info
|
||||
TensorType type => type.DType,
|
||||
_ => throw new InvalidOperationException("Expr don't have a valid tensor type"),
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets hash code cache.
|
||||
/// Gets users.
|
||||
/// </summary>
|
||||
protected int? HashCodeCache { get; set; }
|
||||
public IReadOnlyCollection<Expr> Users => EnsureAlive()._users;
|
||||
|
||||
/// <summary>
|
||||
/// Gets operands.
|
||||
/// </summary>
|
||||
public ReadOnlySpan<Expr> Operands => EnsureAlive()._operands;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets raw checked type.
|
||||
/// </summary>
|
||||
internal IRType? RawCheckedType
|
||||
{
|
||||
get => _checkedType;
|
||||
set => _checkedType = value;
|
||||
}
|
||||
|
||||
public static bool operator ==(Expr? left, Expr? right) => EqualityComparer<Expr>.Default.Equals(left, right);
|
||||
|
||||
public static bool operator !=(Expr? left, Expr? right) => !(left == right);
|
||||
|
||||
/// <summary>
|
||||
/// Accept a <see cref="ExprFunctor{TExprResult, TTypeResult, TContext}"/>.
|
||||
/// </summary>
|
||||
/// <typeparam name="TExprResult">Result type of visiting expressions.</typeparam>
|
||||
/// <typeparam name="TTypeResult">Result type of visiting types.</typeparam>
|
||||
/// <typeparam name="TContext">Visit context.</typeparam>
|
||||
/// <param name="functor">Expression functor.</param>
|
||||
/// <param name="context">Context.</param>
|
||||
/// <returns>Visit result.</returns>
|
||||
public abstract TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public virtual bool Equals(Expr? other)
|
||||
public override string ToString()
|
||||
{
|
||||
return !(other is null) && EqualityContract == other.EqualityContract;
|
||||
return GetType().ToString();
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override int GetHashCode()
|
||||
public override bool Equals(object? obj)
|
||||
{
|
||||
return HashCodeCache ??= EqualityComparer<Type>.Default.GetHashCode(EqualityContract);
|
||||
if (ReferenceEquals(this, obj))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return obj is Expr other && GetHashCode() == other.GetHashCode() && Operands.SequenceEqual(other.Operands);
|
||||
}
|
||||
|
||||
protected virtual bool PrintMembers(StringBuilder builder)
|
||||
/// <inheritdoc/>
|
||||
public sealed override int GetHashCode() => _hashCodeCache ??= GetHashCodeCore();
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
Dispose(disposing: true);
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
|
||||
public void DisposeIfNoUsers()
|
||||
{
|
||||
if (_users.Count == 0)
|
||||
{
|
||||
Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
internal void AddUser(Expr user)
|
||||
{
|
||||
EnsureAlive();
|
||||
Trace.Assert(!ReferenceEquals(this, user));
|
||||
_users.Add(user.EnsureAlive());
|
||||
}
|
||||
|
||||
internal void RemoveUser(Expr user)
|
||||
{
|
||||
_users.Remove(user);
|
||||
}
|
||||
|
||||
internal void ReplaceOperand(int index, Expr newOperand)
|
||||
{
|
||||
ref var operand = ref _operands[index];
|
||||
if (!ReferenceEquals(operand, newOperand))
|
||||
{
|
||||
newOperand.AddUser(this);
|
||||
operand.RemoveUser(this);
|
||||
operand = newOperand;
|
||||
OnOperandsReplaced();
|
||||
}
|
||||
}
|
||||
|
||||
internal void ReplaceAllUsesWith(Expr newOperand)
|
||||
=> ReplaceScopedUsesWith(newOperand, null);
|
||||
|
||||
internal void ReplaceScopedUsesWith(Expr newOperand, IReadOnlySet<Expr>? scope)
|
||||
{
|
||||
EnsureAlive();
|
||||
if (!ReferenceEquals(this, newOperand))
|
||||
{
|
||||
foreach (var user in Users.ToArray())
|
||||
{
|
||||
if ((scope is null || scope.Contains(user))
|
||||
&& !newOperand.IsDescendantOf(this))
|
||||
{
|
||||
newOperand.AddUser(user);
|
||||
var operands = user._operands;
|
||||
for (int i = 0; i < operands.Length; i++)
|
||||
{
|
||||
ref var operand = ref operands[i];
|
||||
if (ReferenceEquals(operand, this))
|
||||
{
|
||||
operand = newOperand;
|
||||
}
|
||||
}
|
||||
|
||||
user.OnOperandsReplaced();
|
||||
RemoveUser(user);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected virtual int GetHashCodeCore()
|
||||
{
|
||||
return HashCode.Combine(GetType(), HashCode<Expr>.Combine(Operands));
|
||||
}
|
||||
|
||||
protected virtual void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposedValue)
|
||||
{
|
||||
foreach (var operand in _operands)
|
||||
{
|
||||
operand.RemoveUser(this);
|
||||
operand.DisposeIfNoUsers();
|
||||
}
|
||||
|
||||
_disposedValue = true;
|
||||
}
|
||||
}
|
||||
|
||||
private bool IsDescendantOf(Expr other)
|
||||
{
|
||||
foreach (var operand in _operands)
|
||||
{
|
||||
if (ReferenceEquals(operand, other))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var operand in _operands)
|
||||
{
|
||||
if (operand.IsDescendantOf(other))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private void OnOperandsReplaced()
|
||||
{
|
||||
InvalidateTypeInference();
|
||||
InvalidateHashCodeCache();
|
||||
}
|
||||
|
||||
private void InvalidateTypeInference()
|
||||
{
|
||||
if (_checkedType != null)
|
||||
{
|
||||
_checkedType = null;
|
||||
InvalidateUsersTypeInference();
|
||||
}
|
||||
}
|
||||
|
||||
private void InvalidateUsersTypeInference()
|
||||
{
|
||||
foreach (var user in Users)
|
||||
{
|
||||
user.InvalidateTypeInference();
|
||||
}
|
||||
}
|
||||
|
||||
private void InvalidateHashCodeCache()
|
||||
{
|
||||
if (_hashCodeCache != null)
|
||||
{
|
||||
_hashCodeCache = null;
|
||||
InvalidateUsersHashCodeCache();
|
||||
}
|
||||
}
|
||||
|
||||
private void InvalidateUsersHashCodeCache()
|
||||
{
|
||||
foreach (var user in Users)
|
||||
{
|
||||
user.InvalidateHashCodeCache();
|
||||
}
|
||||
}
|
||||
|
||||
private Expr EnsureAlive()
|
||||
{
|
||||
if (_disposedValue)
|
||||
{
|
||||
throw new ObjectDisposedException(null);
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.TIR;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public static class ExprClonerExtensions
|
||||
{
|
||||
public static T Clone<T>(this T expr, bool cloneOtherFunctions = false)
|
||||
where T : Expr
|
||||
{
|
||||
return new ExprCloner<Unit>(cloneOtherFunctions).Clone(expr, default);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Expression cloner.
|
||||
/// </summary>
|
||||
/// <typeparam name="TContext">Clone context.</typeparam>
|
||||
public partial class ExprCloner<TContext> : ExprVisitor<Expr, IRType, TContext>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ExprCloner{TContext}"/> class.
|
||||
/// </summary>
|
||||
/// <param name="cloneOtherFunctions">Clone other functions.</param>
|
||||
public ExprCloner(bool cloneOtherFunctions = false)
|
||||
: base(cloneOtherFunctions)
|
||||
{
|
||||
}
|
||||
|
||||
public T Clone<T>(T expr, TContext context)
|
||||
where T : Expr
|
||||
=> (T)Visit(expr, context);
|
||||
|
||||
protected T[] CloneArray<T>(ReadOnlySpan<T> values, TContext context)
|
||||
where T : Expr
|
||||
{
|
||||
var array = new T[values.Length];
|
||||
for (int i = 0; i < values.Length; i++)
|
||||
{
|
||||
array[i] = Clone(values[i], context);
|
||||
}
|
||||
|
||||
return array;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,258 @@
|
|||
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
|
||||
// </auto-generated>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public partial class ExprCloner<TContext>
|
||||
{
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafCall(Call expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
target: Clone(expr.Target, context),
|
||||
arguments: CloneArray(expr.Arguments, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafFunction(Function expr, TContext context)
|
||||
{
|
||||
if (!CanVisitFunctionBody(expr))
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr.With(
|
||||
parameters: CloneArray(expr.Parameters, context),
|
||||
body: Clone(expr.Body, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafFusion(Fusion expr, TContext context)
|
||||
{
|
||||
if (!CanVisitFunctionBody(expr))
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr.With(
|
||||
parameters: CloneArray(expr.Parameters, context),
|
||||
body: Clone(expr.Body, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafIf(If expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
condition: Clone(expr.Condition, context),
|
||||
then: Clone(expr.Then, context),
|
||||
@else: Clone(expr.Else, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafMarker(Marker expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
target: Clone(expr.Target, context),
|
||||
attribute: Clone(expr.Attribute, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafNone(None expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafOp(Op expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context)
|
||||
{
|
||||
if (!CanVisitFunctionBody(expr))
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr.With(
|
||||
target: Clone(expr.Target, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafTensorConst(TensorConst expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafTuple(IR.Tuple expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
fields: CloneArray(expr.Fields, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafTupleConst(TupleConst expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafVar(Var expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafBlock(TIR.Block expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
body: Clone(expr.Body, context),
|
||||
initBody: Clone(expr.InitBody, context),
|
||||
iterVars: CloneArray(expr.IterVars, context),
|
||||
reads: CloneArray(expr.Reads, context),
|
||||
writes: CloneArray(expr.Writes, context),
|
||||
allocBuffers: CloneArray(expr.AllocBuffers, context),
|
||||
predicate: Clone(expr.Predicate, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
dimensions: CloneArray(expr.Dimensions, context),
|
||||
strides: CloneArray(expr.Strides, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
buffer: Clone(expr.Buffer, context),
|
||||
indices: CloneArray(expr.Indices, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
buffer: Clone(expr.Buffer, context),
|
||||
region: CloneArray(expr.Region, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
buffer: Clone(expr.Buffer, context),
|
||||
indices: CloneArray(expr.Indices, context),
|
||||
value: Clone(expr.Value, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafFor(TIR.For expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
loopVar: Clone(expr.LoopVar, context),
|
||||
domain: Clone(expr.Domain, context),
|
||||
body: Clone(expr.Body, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafIfThenElse(TIR.IfThenElse expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
condition: Clone(expr.Condition, context),
|
||||
then: Clone(expr.Then, context),
|
||||
@else: Clone(expr.Else, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafLet(TIR.Let expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
var: Clone(expr.Var, context),
|
||||
expression: Clone(expr.Expression, context),
|
||||
body: Clone(expr.Body, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafPrimFunction(TIR.PrimFunction expr, TContext context)
|
||||
{
|
||||
if (!CanVisitFunctionBody(expr))
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr.With(
|
||||
parameters: CloneArray(expr.Parameters, context),
|
||||
body: Clone(expr.Body, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafSequential(TIR.Sequential expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
fields: CloneArray(expr.Fields, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafRange(TIR.Range expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
start: Clone(expr.Start, context),
|
||||
stop: Clone(expr.Stop, context),
|
||||
step: Clone(expr.Step, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafIterVar(TIR.IterVar expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
value: Clone(expr.Value, context),
|
||||
dom: Clone(expr.Dom, context)
|
||||
);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
<#@ template debug="false" hostspecific="false" language="C#" #>
|
||||
<#@ assembly name="System.Core" #>
|
||||
<#@ import namespace="System.IO" #>
|
||||
<#@ import namespace="System.Linq" #>
|
||||
<#@ import namespace="System.Text" #>
|
||||
<#@ import namespace="System.Collections.Generic" #>
|
||||
<#@ output extension=".cs" #>
|
||||
<#@ include file="IRListParser.tt"#>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
|
||||
// </auto-generated>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public partial class ExprCloner<TContext>
|
||||
{
|
||||
<#
|
||||
foreach (var ir in irs.Where(x => x.IsDerived))
|
||||
{
|
||||
#>
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context)
|
||||
{
|
||||
<#
|
||||
if (ir.IsFunction)
|
||||
{
|
||||
#>
|
||||
if (!CanVisitFunctionBody(expr))
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
return expr.With(
|
||||
<#
|
||||
for (int i = 0; i < ir.Fields.Length; i++)
|
||||
{
|
||||
var field = ir.Fields[i];
|
||||
var func = field.StartsWith("@") ? "CloneArray" : "Clone";
|
||||
var fieldName = field.TrimStart('@');
|
||||
var paramName = $"{char.ToLowerInvariant(fieldName[0])}{fieldName.Substring(1)}";
|
||||
if (paramName == "else")
|
||||
{
|
||||
paramName = "@" + paramName;
|
||||
}
|
||||
#>
|
||||
<#=paramName#>: <#=func#>(expr.<#=fieldName#>, context)<#=i == ir.Fields.Length - 1 ? string.Empty : ","#>
|
||||
<#
|
||||
}
|
||||
#>
|
||||
);
|
||||
}
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.TIR;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public sealed class ExprCollector : ExprWalker<List<Expr>>
|
||||
{
|
||||
private ExprCollector()
|
||||
{
|
||||
}
|
||||
|
||||
public static IReadOnlyList<Expr> Collect(Expr expr)
|
||||
{
|
||||
var exprs = new List<Expr>();
|
||||
new ExprCollector().Visit(expr, exprs);
|
||||
return exprs;
|
||||
}
|
||||
|
||||
protected override Unit DefaultVisitLeaf(Expr expr, List<Expr> context)
|
||||
{
|
||||
context.Add(expr);
|
||||
return default;
|
||||
}
|
||||
}
|
|
@ -3,238 +3,159 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// Expression functor.
|
||||
/// </summary>
|
||||
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
|
||||
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
|
||||
/// <typeparam name="TContext">Visit context.</typeparam>
|
||||
public abstract partial class ExprFunctor<TExprResult, TTypeResult, TContext> : TypeFunctor<TTypeResult, TContext>
|
||||
{
|
||||
/// <summary>
|
||||
/// Expression functor.
|
||||
/// Gets visit root.
|
||||
/// </summary>
|
||||
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
|
||||
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
|
||||
public abstract class ExprFunctor<TExprResult, TTypeResult> : TypeFunctor<TTypeResult>
|
||||
protected Expr? VisitRoot { get; private set; }
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Expr"/>.
|
||||
/// </summary>
|
||||
public TExprResult Visit(Expr expr, TContext context)
|
||||
{
|
||||
/// <summary>
|
||||
/// Visit expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(Expr expr)
|
||||
{
|
||||
return expr switch
|
||||
{
|
||||
Var var => Visit(var),
|
||||
Const con => Visit(con),
|
||||
Call call => Visit(call),
|
||||
If @if => Visit(@if),
|
||||
Tuple tuple => Visit(tuple),
|
||||
Op op => Visit(op),
|
||||
None none => Visit(none),
|
||||
Marker marker => Visit(marker),
|
||||
BaseFunction basefunc => Visit(basefunc),
|
||||
TIR.IterVar itvar => Visit(itvar),
|
||||
TIR.Sequential seq => Visit(seq),
|
||||
TIR.For @for => Visit(@for),
|
||||
TIR.Block block => Visit(block),
|
||||
TIR.BufferLoad bload => Visit(bload),
|
||||
TIR.BufferStore bstore => Visit(bstore),
|
||||
TIR.IfThenElse ift => Visit(ift),
|
||||
TIR.Let let => Visit(let),
|
||||
TIR.Buffer buffer => Visit(buffer),
|
||||
TIR.BufferRegion region => Visit(region),
|
||||
_ => DefaultVisit(expr),
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Visit Basefunction expression.
|
||||
/// </summary>
|
||||
/// <param name="baseFunction"> base function. </param>
|
||||
public virtual TExprResult Visit(BaseFunction baseFunction) => baseFunction switch
|
||||
{
|
||||
Function func => Visit(func),
|
||||
Fusion fusion => Visit(fusion),
|
||||
PrimFunctionWrapper wrapper => Visit(wrapper),
|
||||
TIR.PrimFunction primfunc => Visit(primfunc),
|
||||
_ => DefaultVisit(baseFunction),
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Visit variable expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Variable expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(Var expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit constant expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Constant expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(Const expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit function expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Variable expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(Function expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit fusion expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Fusion Expression.</param>
|
||||
public virtual TExprResult Visit(Fusion expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit prim function wrapper expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">PrimFunctionWrapper expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(PrimFunctionWrapper expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit prim function expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Variable expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.PrimFunction expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit call expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Call expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(Call expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit if expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">If expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(If expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tuple expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Variable expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(Tuple expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit operator expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Operator expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(Op expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit None expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">None expr.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(None expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit marker expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Marker expr.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(Marker expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit IterVar expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">IterVar expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.IterVar expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit Sequential expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Sequential expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.Sequential expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit For expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">For expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.For expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit block expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">block expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.Block expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit BufferLoad expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">BufferLoad expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.BufferLoad expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit BufferStore expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">BufferStore expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.BufferStore expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit IfThenElse expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">IfThenElse expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.IfThenElse expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit Let expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Let expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.Let expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit MemRef expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">MemRef expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.Buffer expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit buffer region expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">buffer region expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult Visit(TIR.BufferRegion expr) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit visitable.
|
||||
/// </summary>
|
||||
public virtual object Visit(IVisitable visitable) => DefaultVisit(visitable);
|
||||
|
||||
/// <summary>
|
||||
/// Default Visit IVisitable.
|
||||
/// </summary>
|
||||
public virtual object DefaultVisit(IVisitable visitable)
|
||||
{
|
||||
return visitable.Visit<TExprResult, TTypeResult>(this);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default visit routine.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult DefaultVisit(Expr expr)
|
||||
{
|
||||
throw new NotImplementedException($"Unhandled visit routine for {expr.GetType()}.");
|
||||
}
|
||||
VisitRoot ??= expr;
|
||||
return DispatchVisit(expr, context);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Clear functor states.
|
||||
/// </summary>
|
||||
public virtual void Clear()
|
||||
{
|
||||
VisitRoot = null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default visit routine.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <param name="context">Context.</param>
|
||||
/// <returns>Result.</returns>
|
||||
protected internal virtual TExprResult DefaultVisit(Expr expr, TContext context)
|
||||
{
|
||||
throw new NotImplementedException($"Unhandled visit routine for {expr.GetType()}.");
|
||||
}
|
||||
|
||||
protected virtual TExprResult DispatchVisit(Expr expr, TContext context) => expr.Accept(this, context);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Expression functor.
|
||||
/// </summary>
|
||||
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
|
||||
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
|
||||
public partial class ExprFunctor<TExprResult, TTypeResult> : ExprFunctor<TExprResult, TTypeResult, Unit>
|
||||
{
|
||||
/// <summary>
|
||||
/// Visit <see cref="Expr"/>.
|
||||
/// </summary>
|
||||
public TExprResult Visit(Expr expr) => Visit(expr, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit type.
|
||||
/// </summary>
|
||||
/// <param name="type">Type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public TTypeResult VisitType(IRType type) => VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit any type.
|
||||
/// </summary>
|
||||
/// <param name="type">Any type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(AnyType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit None type.
|
||||
/// </summary>
|
||||
/// <param name="type">None type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(NoneType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit invalid type.
|
||||
/// </summary>
|
||||
/// <param name="type">Invalid type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(InvalidType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tensor type.
|
||||
/// </summary>
|
||||
/// <param name="type">Tensor type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(TensorType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tuple type.
|
||||
/// </summary>
|
||||
/// <param name="type">Tuple type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(TupleType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit callable type.
|
||||
/// </summary>
|
||||
/// <param name="type">Callable type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(CallableType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Default visit routine.
|
||||
/// </summary>
|
||||
/// <param name="type">Type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult DefaultVisitType(IRType type) => base.DefaultVisitType(type, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(AnyType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(NoneType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(InvalidType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(TensorType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(TupleType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(CallableType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult DefaultVisitType(IRType type, Unit context) => DefaultVisitType(type);
|
||||
|
||||
/// <summary>
|
||||
/// Default visit routine.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
protected internal virtual TExprResult DefaultVisit(Expr expr) => base.DefaultVisit(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected internal sealed override TExprResult DefaultVisit(Expr expr, Unit context) => DefaultVisit(expr);
|
||||
|
||||
protected virtual TExprResult DispatchVisit(Expr expr) => base.DispatchVisit(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult DispatchVisit(Expr expr, Unit context) => DispatchVisit(expr);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,357 @@
|
|||
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
|
||||
// </auto-generated>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
|
||||
{
|
||||
/// <summary>
|
||||
/// Visit <see cref="BaseFunction"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBaseFunction(BaseFunction expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Call"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitCall(Call expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Const"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitConst(Const expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Function"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFunction(Function expr, TContext context) => VisitBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFusion(Fusion expr, TContext context) => VisitBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="If"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIf(If expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitMarker(Marker expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="None"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitNone(None expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Op"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitOp(Op expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => VisitBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTensorConst(TensorConst expr, TContext context) => VisitConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTuple(IR.Tuple expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTupleConst(TupleConst expr, TContext context) => VisitConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Var"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitVar(Var expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBlock(TIR.Block expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Buffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFor(TIR.For expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLet(TIR.Let expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitRange(TIR.Range expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
}
|
||||
|
||||
public partial class ExprFunctor<TExprResult, TTypeResult>
|
||||
{
|
||||
/// <summary>
|
||||
/// Visit <see cref="BaseFunction"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBaseFunction(BaseFunction expr) => base.VisitBaseFunction(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBaseFunction(BaseFunction expr, Unit context) => VisitBaseFunction(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Call"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Const"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitConst(Const expr) => base.VisitConst(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitConst(Const expr, Unit context) => VisitConst(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Function"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="If"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="None"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Op"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Var"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Buffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr) => base.VisitBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr);
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
<#@ template debug="false" hostspecific="false" language="C#" #>
|
||||
<#@ assembly name="System.Core" #>
|
||||
<#@ import namespace="System.IO" #>
|
||||
<#@ import namespace="System.Linq" #>
|
||||
<#@ import namespace="System.Text" #>
|
||||
<#@ import namespace="System.Collections.Generic" #>
|
||||
<#@ output extension=".cs" #>
|
||||
<#@ include file="IRListParser.tt"#>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
|
||||
// </auto-generated>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
|
||||
{
|
||||
<#
|
||||
foreach (var ir in irs)
|
||||
{
|
||||
var func = ir.VisitBase == "Default" ? "DefaultVisit" : $"Visit{ir.VisitBase}";
|
||||
#>
|
||||
/// <summary>
|
||||
/// Visit <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context) => <#=func#>(expr, context);
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
}
|
||||
|
||||
public partial class ExprFunctor<TExprResult, TTypeResult>
|
||||
{
|
||||
<#
|
||||
foreach (var ir in irs)
|
||||
{
|
||||
#>
|
||||
/// <summary>
|
||||
/// Visit <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr) => base.Visit<#=ir.Name#>(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, Unit context) => Visit<#=ir.Name#>(expr);
|
||||
<#
|
||||
}
|
||||
#>
|
||||
}
|
|
@ -1,781 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// IVisitable interface for the custom class visit leaf.
|
||||
/// </summary>
|
||||
public interface IVisitable
|
||||
{
|
||||
/// <summary>
|
||||
/// accept the visit.
|
||||
/// </summary>
|
||||
object Visit<TExprResult, TTypeResult>(ExprFunctor<TExprResult, TTypeResult> functor);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// IMutatable Define.
|
||||
/// </summary>
|
||||
public interface IMutatable : IVisitable
|
||||
{
|
||||
#if false
|
||||
/// <summary>
|
||||
/// mutate the current object.
|
||||
/// NOTE In order to ensure the consistency of coding, please return a new object.
|
||||
/// </summary>
|
||||
/// <param name="mutator">ExprMutator.</param>
|
||||
/// <returns> new instance. </returns>
|
||||
// object MutateLeaf(ExprMutator mutator);
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
/// recursive build new object.
|
||||
/// </summary>
|
||||
/// <param name="mutator">ExprMutator.</param>
|
||||
object WithNew(ExprVisitor<Expr, IRType> mutator);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Deep Expression matutor.
|
||||
/// </summary>
|
||||
public abstract class DeepExprMutator : ExprVisitor<Expr, IRType>
|
||||
{
|
||||
/// <summary>
|
||||
/// The Struct Equal Memo folding the const/op.
|
||||
/// </summary>
|
||||
private readonly Dictionary<Expr, Expr> _exprSEqualMemo = new();
|
||||
|
||||
/// <summary>
|
||||
/// Gets the Struct Equal Memo.
|
||||
/// </summary>
|
||||
public Dictionary<Expr, Expr> ExpressionStructMemo => _exprSEqualMemo;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a value indicating whether for speedup the Mutator, If is Mutated we need MutateLeaf recursive.
|
||||
/// </summary>
|
||||
public bool IsMutated { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(Call expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Target = Visit(expr.Target),
|
||||
Parameters = MutateArray(expr.Parameters, Visit),
|
||||
};
|
||||
}
|
||||
|
||||
public override Expr VisitLeaf(If expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Condition = Visit(expr.Condition),
|
||||
Then = Visit(expr.Then),
|
||||
Else = Visit(expr.Else),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(Const expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return StructEqualFolding(expr);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(Function expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Body = Visit(expr.Body),
|
||||
Parameters = new(expr.Parameters.Select(x => (Var)Visit(x))),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(Fusion expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Body = Visit(expr.Body),
|
||||
Parameters = new(expr.Parameters.Select(x => (Var)Visit(x))),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(PrimFunctionWrapper expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Target = (TIR.PrimFunction)Visit((IR.BaseFunction)expr.Target),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.PrimFunction expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Body = (TIR.Sequential)Visit(expr.Body),
|
||||
Parameters = new(expr.Parameters.Select(x => (TIR.PhysicalBuffer)Visit(x))),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(Op expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(Tuple expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Fields = MutateArray(expr.Fields, Visit),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(Var expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(None expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(Marker expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Target = Visit(expr.Target),
|
||||
Attribute = Visit(expr.Attribute),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.IterVar expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Dom = (TIR.Range)Visit(expr.Dom),
|
||||
Value = (Var)Visit(expr.Value),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.Sequential expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Fields = MutateArray(expr.Fields, Visit),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.For expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
LoopVar = (Var)Visit(expr.LoopVar),
|
||||
Domain = (TIR.Range)Visit(expr.Domain),
|
||||
Body = (TIR.Sequential)Visit(expr.Body),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.IfThenElse expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Condition = Visit(expr.Condition),
|
||||
Then = (TIR.Sequential)Visit(expr.Then),
|
||||
Else = (TIR.Sequential)Visit(expr.Else),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.Block expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
// the block realize
|
||||
InitBody = expr.InitBody.Fields.IsDefaultOrEmpty ? expr.InitBody : (TIR.Sequential)Visit(expr.InitBody),
|
||||
Predicate = Visit(expr.Predicate),
|
||||
IterVars = expr.IterVars.IsDefaultOrEmpty ? expr.IterVars : MutateArray(expr.IterVars, x => (TIR.IterVar)Visit(x)),
|
||||
|
||||
// the block internal.
|
||||
Body = expr.Body.Fields.IsDefaultOrEmpty ? expr.Body : (TIR.Sequential)Visit(expr.Body),
|
||||
Reads = expr.Reads.IsDefaultOrEmpty ? expr.Reads : MutateArray(expr.Reads, b => (TIR.BufferRegion)Visit(b)),
|
||||
Writes = expr.Writes.IsDefaultOrEmpty ? expr.Writes : MutateArray(expr.Writes, b => (TIR.BufferRegion)Visit(b)),
|
||||
AllocBuffers = expr.AllocBuffers.IsDefaultOrEmpty ? expr.AllocBuffers : MutateArray(expr.AllocBuffers, b => (TIR.Buffer)Visit(b)),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.BufferStore expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Value = Visit(expr.Value),
|
||||
Indices = MutateArray(expr.Indices, Visit),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.BufferLoad expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Indices = MutateArray(expr.Indices, Visit),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.Let expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Var = (Var)Visit(expr.Var),
|
||||
Expression = Visit(expr.Expression),
|
||||
Body = (TIR.Sequential)Visit(expr.Body),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.Buffer expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr VisitLeaf(TIR.BufferRegion expr)
|
||||
{
|
||||
var nexpr = MutateLeaf(expr);
|
||||
if (!object.ReferenceEquals(expr, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
|
||||
return expr with
|
||||
{
|
||||
Buffer = (TIR.Buffer)Visit(expr.Buffer),
|
||||
Region = MutateArray(expr.Region, rg => (TIR.Range)Visit(rg)),
|
||||
};
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override object VisitLeaf(IVisitable visitable)
|
||||
{
|
||||
if (visitable is IMutatable mutatable)
|
||||
{
|
||||
var nexpr = MutateLeaf(mutatable);
|
||||
if (!object.ReferenceEquals(mutatable, nexpr))
|
||||
{
|
||||
IsMutated = true;
|
||||
return nexpr;
|
||||
}
|
||||
|
||||
if (!IsMutated)
|
||||
{
|
||||
return mutatable;
|
||||
}
|
||||
|
||||
return mutatable.WithNew(this);
|
||||
}
|
||||
|
||||
throw new NotSupportedException($"IVisitable {visitable.GetType().Name} Is Not IMutatable!");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// defulat mutate leaf is not mutate.
|
||||
/// </summary>
|
||||
public virtual Expr DefaultMutateLeaf(Expr expr) => expr;
|
||||
|
||||
/// <summary>
|
||||
/// default mutate leaf is not mutate.
|
||||
/// </summary>
|
||||
public virtual IMutatable DefaultMutateLeaf(IMutatable mutatable) => mutatable;
|
||||
|
||||
/// <summary>
|
||||
/// mutate the call.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(Call expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the if.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(If expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the const.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(Const expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the function.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(Function expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the fusion.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(Fusion expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the prim function wrapper.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(PrimFunctionWrapper expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the prim function.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(TIR.PrimFunction expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the op.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(Op expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the tuple.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(Tuple expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the var.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(Var expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the var.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(None expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the marker.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(Marker expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the itervar.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(TIR.IterVar expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the sequential.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(TIR.Sequential expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the for.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(TIR.For expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the for.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(TIR.IfThenElse expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the block.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(TIR.Block expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the bufferstore.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(TIR.BufferStore expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the buffer load.
|
||||
/// </summary>
|
||||
public virtual Expr MutateLeaf(TIR.BufferLoad expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the let.
|
||||
/// </summary>
|
||||
/// <param name="expr">let expr.</param>
|
||||
/// <returns>new expr.</returns>
|
||||
public virtual Expr MutateLeaf(TIR.Let expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the memref.
|
||||
/// </summary>
|
||||
/// <param name="expr">new memref.</param>
|
||||
/// <returns>new expr.</returns>
|
||||
public virtual Expr MutateLeaf(TIR.Buffer expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the buffer region.
|
||||
/// </summary>
|
||||
/// <param name="expr">new memref.</param>
|
||||
/// <returns>new expr.</returns>
|
||||
public virtual Expr MutateLeaf(TIR.BufferRegion expr) => DefaultMutateLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// mutate the imutatable.
|
||||
/// </summary>
|
||||
/// <param name="mutatable">IMutatable instance.</param>
|
||||
/// <returns>new expr.</returns>
|
||||
public virtual IMutatable MutateLeaf(IMutatable mutatable) => DefaultMutateLeaf(mutatable);
|
||||
|
||||
/// <summary>
|
||||
/// Mutate IRArray.
|
||||
/// </summary>
|
||||
public virtual IRArray<TResult> MutateArray<TInput, TResult>(IRArray<TInput> array, Func<TInput, TResult> visitor)
|
||||
{
|
||||
return new(array.Select(visitor));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// fold the expr by struct comparer.
|
||||
/// </summary>
|
||||
public virtual Expr StructEqualFolding(Expr expr)
|
||||
{
|
||||
if (!_exprSEqualMemo.TryGetValue(expr, out var folded))
|
||||
{
|
||||
folded = expr;
|
||||
_exprSEqualMemo.Add(expr, folded);
|
||||
}
|
||||
|
||||
return folded;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// NOTE the mutator only visit the only one basefunction and skip other basefunction.
|
||||
/// </summary>
|
||||
public abstract class ExprMutator : DeepExprMutator
|
||||
{
|
||||
private BaseFunction? _entryBaseFunc;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr Visit(BaseFunction baseFunction)
|
||||
{
|
||||
if (_entryBaseFunc is null)
|
||||
{
|
||||
_entryBaseFunc = baseFunction;
|
||||
}
|
||||
|
||||
return base.Visit(baseFunction);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr Visit(Fusion expr)
|
||||
{
|
||||
if (_entryBaseFunc is null)
|
||||
{
|
||||
_entryBaseFunc = expr;
|
||||
return base.Visit(expr);
|
||||
}
|
||||
|
||||
return object.ReferenceEquals(_entryBaseFunc, expr) ? base.Visit(expr) : expr;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr Visit(Function expr)
|
||||
{
|
||||
if (_entryBaseFunc is null)
|
||||
{
|
||||
_entryBaseFunc = expr;
|
||||
return base.Visit(expr);
|
||||
}
|
||||
|
||||
return object.ReferenceEquals(_entryBaseFunc, expr) ? base.Visit(expr) : expr;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr Visit(PrimFunctionWrapper expr)
|
||||
{
|
||||
if (_entryBaseFunc is null)
|
||||
{
|
||||
_entryBaseFunc = expr;
|
||||
return base.Visit(expr);
|
||||
}
|
||||
|
||||
return object.ReferenceEquals(_entryBaseFunc, expr) ? base.Visit(expr) : expr;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override Expr Visit(TIR.PrimFunction expr)
|
||||
{
|
||||
if (_entryBaseFunc is null)
|
||||
{
|
||||
_entryBaseFunc = expr;
|
||||
return base.Visit(expr);
|
||||
}
|
||||
|
||||
return object.ReferenceEquals(_entryBaseFunc, expr) ? base.Visit(expr) : expr;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public sealed class ExprPinner : IDisposable
|
||||
{
|
||||
private static readonly ExprUser _user = new();
|
||||
|
||||
private readonly Expr[] _exprs;
|
||||
private bool _disposed;
|
||||
|
||||
public ExprPinner(params Expr[] exprs)
|
||||
{
|
||||
_exprs = exprs;
|
||||
|
||||
foreach (var expr in _exprs)
|
||||
{
|
||||
expr.AddUser(_user);
|
||||
}
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (!_disposed)
|
||||
{
|
||||
foreach (var expr in _exprs)
|
||||
{
|
||||
expr.RemoveUser(_user);
|
||||
}
|
||||
|
||||
_disposed = true;
|
||||
}
|
||||
}
|
||||
|
||||
private sealed class ExprUser : Expr
|
||||
{
|
||||
public ExprUser()
|
||||
: base(Array.Empty<Expr>())
|
||||
{
|
||||
}
|
||||
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context) => throw new NotSupportedException();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using TorchSharp.Modules;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// Expression rewriter.
|
||||
/// </summary>
|
||||
/// <typeparam name="TContext">Rewrite context.</typeparam>
|
||||
public abstract partial class ExprRewriter<TContext> : ExprVisitor<Expr, IRType, TContext>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ExprRewriter{TContext}"/> class.
|
||||
/// </summary>
|
||||
/// <param name="visitOtherFunctions">Vist other functions.</param>
|
||||
public ExprRewriter(bool visitOtherFunctions = false)
|
||||
: base(visitOtherFunctions)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether expression is mutated.
|
||||
/// </summary>
|
||||
public bool IsMutated { get; private set; }
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression to rewrite.</param>
|
||||
/// <param name="context">Context.</param>
|
||||
/// <returns>Rewritten expression.</returns>
|
||||
public Expr Rewrite(Expr expr, TContext context)
|
||||
{
|
||||
var newExpr = Visit(expr, context);
|
||||
DCE(newExpr);
|
||||
return newExpr;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default rewrite leaf routine.
|
||||
/// </summary>
|
||||
protected virtual Expr DefaultRewriteLeaf(Expr expr, TContext context) => expr;
|
||||
|
||||
protected void SetMutated() => IsMutated = true;
|
||||
|
||||
protected override void VisitOperands(Expr expr, TContext context)
|
||||
{
|
||||
var operands = expr.Operands;
|
||||
for (int i = 0; i < operands.Length; i++)
|
||||
{
|
||||
var operand = operands[i];
|
||||
var newOperand = Visit(operand, context);
|
||||
if (!ReferenceEquals(operand, newOperand))
|
||||
{
|
||||
expr.ReplaceOperand(i, newOperand);
|
||||
SetMutated();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void DCE(Expr root)
|
||||
{
|
||||
using var exprPin = new ExprPinner(root);
|
||||
foreach (var expr in ExprMemo)
|
||||
{
|
||||
expr.Key.DisposeIfNoUsers();
|
||||
expr.Value.DisposeIfNoUsers();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Expression rewriter.
|
||||
/// </summary>
|
||||
public abstract partial class ExprRewriter : ExprRewriter<Unit>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ExprRewriter"/> class.
|
||||
/// </summary>
|
||||
/// <param name="visitOtherFunctions">Vist other functions.</param>
|
||||
protected ExprRewriter(bool visitOtherFunctions = false)
|
||||
: base(visitOtherFunctions)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression to rewrite.</param>
|
||||
/// <returns>Rewritten expression.</returns>
|
||||
public Expr Rewrite(Expr expr) => Rewrite(expr, default);
|
||||
|
||||
/// <summary>
|
||||
/// Default rewrite leaf routine.
|
||||
/// </summary>
|
||||
protected virtual Expr DefaultRewriteLeaf(Expr expr) => base.DefaultRewriteLeaf(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr DefaultRewriteLeaf(Expr expr, Unit context) => DefaultRewriteLeaf(expr);
|
||||
}
|
|
@ -0,0 +1,553 @@
|
|||
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
|
||||
// </auto-generated>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public partial class ExprRewriter<TContext>
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBaseFunction(BaseFunction expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBaseFunction(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafCall(Call expr, TContext context)
|
||||
{
|
||||
return RewriteLeafCall(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafConst(Const expr, TContext context)
|
||||
{
|
||||
return RewriteLeafConst(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafFunction(Function expr, TContext context)
|
||||
{
|
||||
return RewriteLeafFunction(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafFusion(Fusion expr, TContext context)
|
||||
{
|
||||
return RewriteLeafFusion(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafIf(If expr, TContext context)
|
||||
{
|
||||
return RewriteLeafIf(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafMarker(Marker expr, TContext context)
|
||||
{
|
||||
return RewriteLeafMarker(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafNone(None expr, TContext context)
|
||||
{
|
||||
return RewriteLeafNone(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafOp(Op expr, TContext context)
|
||||
{
|
||||
return RewriteLeafOp(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context)
|
||||
{
|
||||
return RewriteLeafPrimFunctionWrapper(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafTensorConst(TensorConst expr, TContext context)
|
||||
{
|
||||
return RewriteLeafTensorConst(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafTuple(IR.Tuple expr, TContext context)
|
||||
{
|
||||
return RewriteLeafTuple(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafTupleConst(TupleConst expr, TContext context)
|
||||
{
|
||||
return RewriteLeafTupleConst(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafVar(Var expr, TContext context)
|
||||
{
|
||||
return RewriteLeafVar(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBlock(TIR.Block expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBlock(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBuffer(TIR.Buffer expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
|
||||
{
|
||||
return RewriteLeafLogicalBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
|
||||
{
|
||||
return RewriteLeafPhysicalBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBufferLoad(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBufferRegion(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBufferStore(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafFor(TIR.For expr, TContext context)
|
||||
{
|
||||
return RewriteLeafFor(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafIfThenElse(TIR.IfThenElse expr, TContext context)
|
||||
{
|
||||
return RewriteLeafIfThenElse(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafLet(TIR.Let expr, TContext context)
|
||||
{
|
||||
return RewriteLeafLet(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafPrimFunction(TIR.PrimFunction expr, TContext context)
|
||||
{
|
||||
return RewriteLeafPrimFunction(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafSequential(TIR.Sequential expr, TContext context)
|
||||
{
|
||||
return RewriteLeafSequential(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafRange(TIR.Range expr, TContext context)
|
||||
{
|
||||
return RewriteLeafRange(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext context)
|
||||
{
|
||||
return RewriteLeafIterVar(expr, context);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="BaseFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBaseFunction(BaseFunction expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Call"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafCall(Call expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Const"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafConst(Const expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Function"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafFunction(Function expr, TContext context) => RewriteLeafBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafFusion(Fusion expr, TContext context) => RewriteLeafBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="If"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafIf(If expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafMarker(Marker expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="None"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafNone(None expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Op"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafOp(Op expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => RewriteLeafBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafTensorConst(TensorConst expr, TContext context) => RewriteLeafConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafTuple(IR.Tuple expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafTupleConst(TupleConst expr, TContext context) => RewriteLeafConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Var"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafVar(Var expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBlock(TIR.Block expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Buffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafFor(TIR.For expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafIfThenElse(TIR.IfThenElse expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafLet(TIR.Let expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafPrimFunction(TIR.PrimFunction expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafSequential(TIR.Sequential expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafRange(TIR.Range expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafIterVar(TIR.IterVar expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
}
|
||||
|
||||
public partial class ExprRewriter
|
||||
{
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="BaseFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBaseFunction(BaseFunction expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBaseFunction(BaseFunction expr, Unit context) => RewriteLeafBaseFunction(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Call"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafCall(Call expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafCall(Call expr, Unit context) => RewriteLeafCall(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Const"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafConst(Const expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafConst(Const expr, Unit context) => RewriteLeafConst(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Function"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafFunction(Function expr) => RewriteLeafBaseFunction(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafFunction(Function expr, Unit context) => RewriteLeafFunction(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafFusion(Fusion expr) => RewriteLeafBaseFunction(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafFusion(Fusion expr, Unit context) => RewriteLeafFusion(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="If"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafIf(If expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafIf(If expr, Unit context) => RewriteLeafIf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafMarker(Marker expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafMarker(Marker expr, Unit context) => RewriteLeafMarker(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="None"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafNone(None expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafNone(None expr, Unit context) => RewriteLeafNone(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Op"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafOp(Op expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafOp(Op expr, Unit context) => RewriteLeafOp(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => RewriteLeafBaseFunction(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => RewriteLeafPrimFunctionWrapper(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafTensorConst(TensorConst expr) => RewriteLeafConst(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafTensorConst(TensorConst expr, Unit context) => RewriteLeafTensorConst(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafTuple(IR.Tuple expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafTuple(IR.Tuple expr, Unit context) => RewriteLeafTuple(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafTupleConst(TupleConst expr) => RewriteLeafConst(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafTupleConst(TupleConst expr, Unit context) => RewriteLeafTupleConst(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Var"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafVar(Var expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafVar(Var expr, Unit context) => RewriteLeafVar(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBlock(TIR.Block expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBlock(TIR.Block expr, Unit context) => RewriteLeafBlock(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Buffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBuffer(TIR.Buffer expr, Unit context) => RewriteLeafBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr) => RewriteLeafBuffer(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => RewriteLeafLogicalBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => RewriteLeafBuffer(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => RewriteLeafPhysicalBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, Unit context) => RewriteLeafBufferLoad(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, Unit context) => RewriteLeafBufferRegion(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBufferStore(TIR.BufferStore expr, Unit context) => RewriteLeafBufferStore(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafFor(TIR.For expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafFor(TIR.For expr, Unit context) => RewriteLeafFor(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafIfThenElse(TIR.IfThenElse expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafIfThenElse(TIR.IfThenElse expr, Unit context) => RewriteLeafIfThenElse(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafLet(TIR.Let expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafLet(TIR.Let expr, Unit context) => RewriteLeafLet(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafPrimFunction(TIR.PrimFunction expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafPrimFunction(TIR.PrimFunction expr, Unit context) => RewriteLeafPrimFunction(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafSequential(TIR.Sequential expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafSequential(TIR.Sequential expr, Unit context) => RewriteLeafSequential(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafRange(TIR.Range expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafRange(TIR.Range expr, Unit context) => RewriteLeafRange(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafIterVar(TIR.IterVar expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafIterVar(TIR.IterVar expr, Unit context) => RewriteLeafIterVar(expr);
|
||||
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
<#@ template debug="false" hostspecific="false" language="C#" #>
|
||||
<#@ assembly name="System.Core" #>
|
||||
<#@ import namespace="System.IO" #>
|
||||
<#@ import namespace="System.Linq" #>
|
||||
<#@ import namespace="System.Text" #>
|
||||
<#@ import namespace="System.Collections.Generic" #>
|
||||
<#@ output extension=".cs" #>
|
||||
<#@ include file="IRListParser.tt"#>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
|
||||
// </auto-generated>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public partial class ExprRewriter<TContext>
|
||||
{
|
||||
<#
|
||||
foreach (var ir in irs)
|
||||
{
|
||||
#>
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context)
|
||||
{
|
||||
return RewriteLeaf<#=ir.Name#>(expr, context);
|
||||
}
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
<#
|
||||
foreach (var ir in irs)
|
||||
{
|
||||
var func = ir.VisitBase == "Default" ? "DefaultRewriteLeaf" : $"RewriteLeaf{ir.VisitBase}";
|
||||
#>
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context) => <#=func#>(expr, context);
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
}
|
||||
|
||||
public partial class ExprRewriter
|
||||
{
|
||||
<#
|
||||
foreach (var ir in irs)
|
||||
{
|
||||
var func = ir.VisitBase == "Default" ? "DefaultRewriteLeaf" : $"RewriteLeaf{ir.VisitBase}";
|
||||
#>
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr) => <#=func#>(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, Unit context) => RewriteLeaf<#=ir.Name#>(expr);
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
}
|
|
@ -3,765 +3,262 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// Expression visitor.
|
||||
/// </summary>
|
||||
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
|
||||
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
|
||||
/// <typeparam name="TContext">Visit context.</typeparam>
|
||||
public abstract partial class ExprVisitor<TExprResult, TTypeResult, TContext> : ExprFunctor<TExprResult, TTypeResult, TContext>
|
||||
{
|
||||
private readonly bool _visitOtherFunctions;
|
||||
|
||||
/// <summary>
|
||||
/// Expression visitor.
|
||||
/// Initializes a new instance of the <see cref="ExprVisitor{TExprResult, TTypeResult, TContext}"/> class.
|
||||
/// </summary>
|
||||
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
|
||||
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
|
||||
public abstract class ExprVisitor<TExprResult, TTypeResult> : ExprFunctor<TExprResult, TTypeResult>
|
||||
/// <param name="visitOtherFunctions">Vist other functions.</param>
|
||||
public ExprVisitor(bool visitOtherFunctions = false)
|
||||
{
|
||||
private readonly Dictionary<Expr, TExprResult> _exprMemo = new Dictionary<Expr, TExprResult>(ReferenceEqualityComparer.Instance);
|
||||
private readonly Dictionary<IVisitable, object> _visitableMemo = new Dictionary<IVisitable, object>(ReferenceEqualityComparer.Instance);
|
||||
private readonly Dictionary<IRType, TTypeResult> _typeMemo = new Dictionary<IRType, TTypeResult>();
|
||||
private readonly Dictionary<string, Action<Expr>> _callbacksAfterCall = new();
|
||||
private readonly Dictionary<string, Action<Expr>> _callbacksBeforeCall = new();
|
||||
_visitOtherFunctions = visitOtherFunctions;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets expression visit result memo.
|
||||
/// </summary>
|
||||
public Dictionary<Expr, TExprResult> ExpressionMemo => _exprMemo;
|
||||
/// <summary>
|
||||
/// Gets expression memo.
|
||||
/// </summary>
|
||||
public Dictionary<Expr, TExprResult> ExprMemo { get; } = new(ReferenceEqualityComparer.Instance);
|
||||
|
||||
/// <summary>
|
||||
/// Gets visitable visit result memo.
|
||||
/// </summary>
|
||||
public Dictionary<IVisitable, object> VisitAbleMemo => _visitableMemo;
|
||||
/// <summary>
|
||||
/// Gets type memo.
|
||||
/// </summary>
|
||||
public Dictionary<IRType, TTypeResult> TypeMemo { get; } = new(ReferenceEqualityComparer.Instance);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(Call expr)
|
||||
/// <inheritdoc/>
|
||||
public override TTypeResult VisitType(AnyType type, TContext context)
|
||||
{
|
||||
if (HasVisited(type, out var result))
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Target);
|
||||
foreach (var param in expr.Parameters)
|
||||
{
|
||||
Visit(param);
|
||||
}
|
||||
|
||||
CallbacksBeforeCall(expr);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
CallbacksAfterCall(expr);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override TExprResult Visit(If expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Condition);
|
||||
Visit(expr.Then);
|
||||
Visit(expr.Else);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
return MarkVisited(type, VisitTypeLeaf(type, context));
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TTypeResult VisitType(CallableType type, TContext context)
|
||||
{
|
||||
if (HasVisited(type, out var result))
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(Const expr)
|
||||
foreach (var param in type.Parameters)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
VisitType(param, context);
|
||||
}
|
||||
|
||||
VisitType(type.ReturnType, context);
|
||||
return MarkVisited(type, VisitTypeLeaf(type, context));
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TTypeResult VisitType(InvalidType type, TContext context)
|
||||
{
|
||||
if (HasVisited(type, out var result))
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(Function expr)
|
||||
return MarkVisited(type, VisitTypeLeaf(type, context));
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TTypeResult VisitType(TensorType type, TContext context)
|
||||
{
|
||||
if (HasVisited(type, out var result))
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
foreach (var param in expr.Parameters)
|
||||
{
|
||||
Visit(param);
|
||||
}
|
||||
|
||||
Visit(expr.Body);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(Fusion expr)
|
||||
return MarkVisited(type, VisitTypeLeaf(type, context));
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TTypeResult VisitType(TupleType type, TContext context)
|
||||
{
|
||||
if (HasVisited(type, out var result))
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
foreach (var param in expr.Parameters)
|
||||
{
|
||||
Visit(param);
|
||||
}
|
||||
|
||||
Visit(expr.Body);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(PrimFunctionWrapper expr)
|
||||
foreach (var field in type.Fields)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Target);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
VisitType(field, context);
|
||||
}
|
||||
|
||||
return MarkVisited(type, VisitTypeLeaf(type, context));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Visit any type leaf.
|
||||
/// </summary>
|
||||
public virtual TTypeResult VisitTypeLeaf(AnyType type, TContext context) => DefaultVisitTypeLeaf(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit invalid type leaf.
|
||||
/// </summary>
|
||||
public virtual TTypeResult VisitTypeLeaf(InvalidType type, TContext context) => DefaultVisitTypeLeaf(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tensor type leaf.
|
||||
/// </summary>
|
||||
public virtual TTypeResult VisitTypeLeaf(TensorType type, TContext context) => DefaultVisitTypeLeaf(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tuple type leaf.
|
||||
/// </summary>
|
||||
public virtual TTypeResult VisitTypeLeaf(TupleType type, TContext context) => DefaultVisitTypeLeaf(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tuple type leaf.
|
||||
/// </summary>
|
||||
public virtual TTypeResult VisitTypeLeaf(CallableType type, TContext context) => DefaultVisitTypeLeaf(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Default visit leaf routine.
|
||||
/// </summary>
|
||||
public virtual TTypeResult DefaultVisitTypeLeaf(IRType type, TContext context)
|
||||
{
|
||||
throw new NotImplementedException($"Unhandled visit leaf routine for {type.GetType()}.");
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override void Clear()
|
||||
{
|
||||
ExprMemo.Clear();
|
||||
TypeMemo.Clear();
|
||||
base.Clear();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Whether this expression is not visited before.
|
||||
/// </summary>
|
||||
protected bool HasVisited(Expr expr, [MaybeNullWhen(false)] out TExprResult result)
|
||||
=> ExprMemo.TryGetValue(expr, out result);
|
||||
|
||||
/// <summary>
|
||||
/// Whether this type is not visited before.
|
||||
/// </summary>
|
||||
protected bool HasVisited(IRType type, [MaybeNullWhen(false)] out TTypeResult result)
|
||||
=> TypeMemo.TryGetValue(type, out result);
|
||||
|
||||
/// <summary>
|
||||
/// Mark expression is visited.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression to visit.</param>
|
||||
/// <param name="result">Visit result.</param>
|
||||
protected TExprResult MarkVisited(Expr expr, TExprResult result)
|
||||
{
|
||||
ExprMemo[expr] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Mark type is visited.
|
||||
/// </summary>
|
||||
/// <param name="type">Type to visit.</param>
|
||||
/// <param name="result">Visit result.</param>
|
||||
protected TTypeResult MarkVisited(IRType type, TTypeResult result)
|
||||
{
|
||||
TypeMemo[type] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
protected bool CanVisitFunctionBody(BaseFunction baseFunction)
|
||||
{
|
||||
if (_visitOtherFunctions)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return ReferenceEquals(baseFunction, VisitRoot);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default leaf visit routine.
|
||||
/// </summary>
|
||||
protected virtual TExprResult DefaultVisitLeaf(Expr expr, TContext context)
|
||||
{
|
||||
throw new NotImplementedException($"Unhandled visit leaf routine for {expr.GetType()}.");
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override TExprResult DispatchVisit(Expr expr, TContext context)
|
||||
{
|
||||
if (HasVisited(expr, out var result))
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(TIR.PrimFunction expr)
|
||||
return MarkVisited(expr, base.DispatchVisit(expr, context));
|
||||
}
|
||||
|
||||
protected virtual void VisitOperands(Expr expr, TContext context)
|
||||
{
|
||||
foreach (var operand in expr.Operands)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
foreach (var param in expr.Parameters)
|
||||
{
|
||||
Visit(param);
|
||||
}
|
||||
|
||||
Visit(expr.Body);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(Op expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(Tuple expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
foreach (var field in expr.Fields)
|
||||
{
|
||||
Visit(field);
|
||||
}
|
||||
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(Var expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(None expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(Marker expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Target);
|
||||
Visit(expr.Attribute);
|
||||
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(TIR.IterVar expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Value);
|
||||
Visit(expr.Dom);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(TIR.Sequential expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
foreach (var item in expr.Fields)
|
||||
{
|
||||
Visit(item);
|
||||
}
|
||||
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(TIR.For expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.LoopVar);
|
||||
Visit(expr.Domain);
|
||||
Visit(expr.Body);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(TIR.Block expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Body);
|
||||
Visit(expr.InitBody);
|
||||
foreach (var iterVar in expr.IterVars)
|
||||
{
|
||||
Visit(iterVar);
|
||||
}
|
||||
|
||||
foreach (var reads in expr.Reads)
|
||||
{
|
||||
Visit(reads);
|
||||
}
|
||||
|
||||
foreach (var writes in expr.Writes)
|
||||
{
|
||||
Visit(writes);
|
||||
}
|
||||
|
||||
foreach (var buffer in expr.AllocBuffers)
|
||||
{
|
||||
Visit(buffer);
|
||||
}
|
||||
|
||||
Visit(expr.Predicate);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(TIR.BufferLoad expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
foreach (var index in expr.Indices)
|
||||
{
|
||||
Visit(index);
|
||||
}
|
||||
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(TIR.BufferStore expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Buffer);
|
||||
foreach (var index in expr.Indices)
|
||||
{
|
||||
Visit(index);
|
||||
}
|
||||
|
||||
Visit(expr.Value);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TExprResult Visit(TIR.IfThenElse expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Condition);
|
||||
Visit(expr.Then);
|
||||
Visit(expr.Else);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(TIR.Let expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Var);
|
||||
Visit(expr.Expression);
|
||||
Visit(expr.Body);
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(TIR.Buffer expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Visit(TIR.BufferRegion expr)
|
||||
{
|
||||
if (!_exprMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
Visit(expr.Buffer);
|
||||
foreach (var param in expr.Region)
|
||||
{
|
||||
Visit(param);
|
||||
}
|
||||
|
||||
result = VisitLeaf(expr);
|
||||
_exprMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// visit ivisitable.
|
||||
/// </summary>
|
||||
public override object Visit(IVisitable visitable)
|
||||
{
|
||||
if (!_visitableMemo.TryGetValue(visitable, out var result))
|
||||
{
|
||||
visitable.Visit<TExprResult, TTypeResult>(this);
|
||||
result = VisitLeaf(visitable);
|
||||
_visitableMemo.Add(visitable, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Visit expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Expr expr)
|
||||
{
|
||||
return expr switch
|
||||
{
|
||||
Var var => VisitLeaf(var),
|
||||
Const con => VisitLeaf(con),
|
||||
Function func => VisitLeaf(func),
|
||||
Fusion fusion => VisitLeaf(fusion),
|
||||
Call call => VisitLeaf(call),
|
||||
If @if => VisitLeaf(@if),
|
||||
Tuple tuple => VisitLeaf(tuple),
|
||||
Op op => VisitLeaf(op),
|
||||
None none => VisitLeaf(none),
|
||||
TIR.Sequential seq => VisitLeaf(seq),
|
||||
TIR.For @for => VisitLeaf(@for),
|
||||
TIR.Block block => VisitLeaf(block),
|
||||
TIR.BufferLoad bufload => VisitLeaf(bufload),
|
||||
TIR.BufferStore bufstore => VisitLeaf(bufstore),
|
||||
TIR.IfThenElse ift => VisitLeaf(ift),
|
||||
TIR.Let let => VisitLeaf(let),
|
||||
TIR.Buffer buffer => VisitLeaf(buffer),
|
||||
_ => DefaultVisitLeaf(expr),
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf variable expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Variable expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Var expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf constant expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Constant expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Const expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf function expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Function expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Function expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf fusion expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Fusion expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Fusion expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf prim function wrapper expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Function expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(PrimFunctionWrapper expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf prim function expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">PrimFunction expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.PrimFunction expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf call expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Call expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Call expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf if expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">If expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(If expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf tuple expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Variable expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Tuple expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf operator expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Operator expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Op expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf None expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">None expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(None expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf marker expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">None expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(Marker expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf IterVar expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">IterVar expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.IterVar expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf sequential expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">sequential expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.Sequential expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf For expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">For expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.For expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf Block expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Block expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.Block expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf BufferLoad expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">BufferLoad expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.BufferLoad expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf BufferRead expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">BufferRead expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.BufferStore expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf IfThenElse expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">IfThenElse expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.IfThenElse expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf Let expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">Let expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.Let expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf MemRef expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">MemRef expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.Buffer expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf buffer region expression.
|
||||
/// </summary>
|
||||
/// <param name="expr">buffer region expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult VisitLeaf(TIR.BufferRegion expr) => DefaultVisitLeaf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf ifunctable.
|
||||
/// </summary>
|
||||
public virtual object VisitLeaf(IVisitable visitable) => DefaultVisitLeaf(visitable);
|
||||
|
||||
/// <summary>
|
||||
/// Default leaf visit routine.
|
||||
/// </summary>
|
||||
public virtual object DefaultVisitLeaf(IVisitable visitable)
|
||||
{
|
||||
throw new NotImplementedException($"Unhandled visit leaf routine for {visitable.GetType()}.");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Default leaf visit routine.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TExprResult DefaultVisitLeaf(Expr expr)
|
||||
{
|
||||
throw new NotImplementedException($"Unhandled visit leaf routine for {expr.GetType()}.");
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(AnyType type)
|
||||
{
|
||||
if (!_typeMemo.TryGetValue(type, out var result))
|
||||
{
|
||||
result = VisitTypeLeaf(type);
|
||||
_typeMemo.Add(type, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(CallableType type)
|
||||
{
|
||||
if (!_typeMemo.TryGetValue(type, out var result))
|
||||
{
|
||||
foreach (var param in type.Parameters)
|
||||
{
|
||||
VisitType(param);
|
||||
}
|
||||
|
||||
VisitType(type.ReturnType);
|
||||
result = VisitTypeLeaf(type);
|
||||
_typeMemo.Add(type, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(InvalidType type)
|
||||
{
|
||||
if (!_typeMemo.TryGetValue(type, out var result))
|
||||
{
|
||||
result = VisitTypeLeaf(type);
|
||||
_typeMemo.Add(type, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(TensorType type)
|
||||
{
|
||||
if (!_typeMemo.TryGetValue(type, out var result))
|
||||
{
|
||||
result = VisitTypeLeaf(type);
|
||||
_typeMemo.Add(type, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(TupleType type)
|
||||
{
|
||||
if (!_typeMemo.TryGetValue(type, out var result))
|
||||
{
|
||||
foreach (var field in type.Fields)
|
||||
{
|
||||
VisitType(field);
|
||||
}
|
||||
|
||||
result = VisitTypeLeaf(type);
|
||||
_typeMemo.Add(type, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Visit any type leaf.
|
||||
/// </summary>
|
||||
/// <param name="type">Any type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitTypeLeaf(AnyType type) => DefaultVisitTypeLeaf(type);
|
||||
|
||||
/// <summary>
|
||||
/// Visit invalid type leaf.
|
||||
/// </summary>
|
||||
/// <param name="type">Invalid type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitTypeLeaf(InvalidType type) => DefaultVisitTypeLeaf(type);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tensor type leaf.
|
||||
/// </summary>
|
||||
/// <param name="type">Tensor type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitTypeLeaf(TensorType type) => DefaultVisitTypeLeaf(type);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tuple type leaf.
|
||||
/// </summary>
|
||||
/// <param name="type">Tuple type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitTypeLeaf(TupleType type) => DefaultVisitTypeLeaf(type);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tuple type leaf.
|
||||
/// </summary>
|
||||
/// <param name="type">Callable type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitTypeLeaf(CallableType type) => DefaultVisitTypeLeaf(type);
|
||||
|
||||
/// <summary>
|
||||
/// Default visit leaf routine.
|
||||
/// </summary>
|
||||
/// <param name="type">Type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult DefaultVisitTypeLeaf(IRType type)
|
||||
{
|
||||
throw new NotImplementedException($"Unhandled visit leaf routine for {type.GetType()}.");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// clear the Memo!.
|
||||
/// </summary>
|
||||
public virtual void Clear()
|
||||
{
|
||||
_exprMemo.Clear();
|
||||
_typeMemo.Clear();
|
||||
}
|
||||
|
||||
protected void RegisterAfterCallback(string name, Action<Expr> callback)
|
||||
{
|
||||
_callbacksAfterCall[name] = callback;
|
||||
}
|
||||
|
||||
protected void RegisterBeforeCallback(string name, Action<Expr> callback)
|
||||
{
|
||||
_callbacksBeforeCall[name] = callback;
|
||||
}
|
||||
|
||||
private void CallbacksBeforeCall(Expr expr)
|
||||
{
|
||||
foreach (var (name, callback) in _callbacksBeforeCall)
|
||||
{
|
||||
callback(expr);
|
||||
}
|
||||
}
|
||||
|
||||
private void CallbacksAfterCall(Expr expr)
|
||||
{
|
||||
foreach (var (name, callback) in _callbacksAfterCall)
|
||||
{
|
||||
callback(expr);
|
||||
}
|
||||
Visit(operand, context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Expression visitor.
|
||||
/// </summary>
|
||||
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
|
||||
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
|
||||
public abstract partial class ExprVisitor<TExprResult, TTypeResult> : ExprVisitor<TExprResult, TTypeResult, Unit>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ExprVisitor{TExprResult, TTypeResult}"/> class.
|
||||
/// </summary>
|
||||
/// <param name="visitOtherFunctions">Vist other functions.</param>
|
||||
public ExprVisitor(bool visitOtherFunctions = false)
|
||||
: base(visitOtherFunctions)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Expr"/>.
|
||||
/// </summary>
|
||||
public TExprResult Visit(Expr expr) => Visit(expr, default);
|
||||
|
||||
/// <summary>
|
||||
/// Default visit routine.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <returns>Result.</returns>
|
||||
protected internal virtual TExprResult DefaultVisit(Expr expr) => base.DefaultVisit(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected internal sealed override TExprResult DefaultVisit(Expr expr, Unit context) => DefaultVisit(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Default leaf visit routine.
|
||||
/// </summary>
|
||||
protected virtual TExprResult DefaultVisitLeaf(Expr expr) => base.DefaultVisitLeaf(expr, default);
|
||||
|
||||
protected sealed override TExprResult DefaultVisitLeaf(Expr expr, Unit context) => DefaultVisitLeaf(expr);
|
||||
|
||||
protected virtual TExprResult DispatchVisit(Expr expr) => base.DispatchVisit(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult DispatchVisit(Expr expr, Unit context) => DispatchVisit(expr);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,751 @@
|
|||
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
|
||||
// </auto-generated>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
|
||||
{
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitCall(Call expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafCall(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitFunction(Function expr, TContext context)
|
||||
{
|
||||
if (CanVisitFunctionBody(expr))
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
}
|
||||
|
||||
return VisitLeafFunction(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitFusion(Fusion expr, TContext context)
|
||||
{
|
||||
if (CanVisitFunctionBody(expr))
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
}
|
||||
|
||||
return VisitLeafFusion(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitIf(If expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafIf(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitMarker(Marker expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafMarker(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitNone(None expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafNone(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitOp(Op expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafOp(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context)
|
||||
{
|
||||
if (CanVisitFunctionBody(expr))
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
}
|
||||
|
||||
return VisitLeafPrimFunctionWrapper(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitTensorConst(TensorConst expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafTensorConst(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitTuple(IR.Tuple expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafTuple(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitTupleConst(TupleConst expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafTupleConst(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitVar(Var expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafVar(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitBlock(TIR.Block expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafBlock(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafLogicalBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafPhysicalBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafBufferLoad(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafBufferRegion(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafBufferStore(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitFor(TIR.For expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafFor(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitIfThenElse(TIR.IfThenElse expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafIfThenElse(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitLet(TIR.Let expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafLet(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitPrimFunction(TIR.PrimFunction expr, TContext context)
|
||||
{
|
||||
if (CanVisitFunctionBody(expr))
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
}
|
||||
|
||||
return VisitLeafPrimFunction(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitSequential(TIR.Sequential expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafSequential(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitRange(TIR.Range expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafRange(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafIterVar(expr, context);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="BaseFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Call"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafCall(Call expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Const"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafConst(Const expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Function"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFunction(Function expr, TContext context) => VisitLeafBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFusion(Fusion expr, TContext context) => VisitLeafBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="If"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIf(If expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafMarker(Marker expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="None"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafNone(None expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Op"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafOp(Op expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => VisitLeafBaseFunction(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr, TContext context) => VisitLeafConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTuple(IR.Tuple expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr, TContext context) => VisitLeafConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Var"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafVar(Var expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBlock(TIR.Block expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Buffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFor(TIR.For expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafLet(TIR.Let expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafRange(TIR.Range expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
}
|
||||
|
||||
public partial class ExprVisitor<TExprResult, TTypeResult>
|
||||
{
|
||||
/// <summary>
|
||||
/// Visit <see cref="Call"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Function"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="If"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="None"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Op"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Var"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr);
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="BaseFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr) => base.VisitLeafBaseFunction(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBaseFunction(BaseFunction expr, Unit context) => VisitLeafBaseFunction(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Call"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafCall(Call expr) => base.VisitLeafCall(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafCall(Call expr, Unit context) => VisitLeafCall(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Const"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafConst(Const expr) => base.VisitLeafConst(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafConst(Const expr, Unit context) => VisitLeafConst(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Function"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFunction(Function expr) => base.VisitLeafFunction(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFusion(Fusion expr) => base.VisitLeafFusion(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafFusion(Fusion expr, Unit context) => VisitLeafFusion(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="If"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIf(If expr) => base.VisitLeafIf(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafIf(If expr, Unit context) => VisitLeafIf(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafMarker(Marker expr) => base.VisitLeafMarker(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafMarker(Marker expr, Unit context) => VisitLeafMarker(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="None"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafNone(None expr) => base.VisitLeafNone(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafNone(None expr, Unit context) => VisitLeafNone(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Op"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafOp(Op expr) => base.VisitLeafOp(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafOp(Op expr, Unit context) => VisitLeafOp(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitLeafPrimFunctionWrapper(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitLeafPrimFunctionWrapper(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTuple(IR.Tuple expr) => base.VisitLeafTuple(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafTuple(IR.Tuple expr, Unit context) => VisitLeafTuple(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr) => base.VisitLeafTupleConst(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafTupleConst(TupleConst expr, Unit context) => VisitLeafTupleConst(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Var"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafVar(Var expr) => base.VisitLeafVar(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafVar(Var expr, Unit context) => VisitLeafVar(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBlock(TIR.Block expr) => base.VisitLeafBlock(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBlock(TIR.Block expr, Unit context) => VisitLeafBlock(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Buffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr) => base.VisitLeafBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBuffer(TIR.Buffer expr, Unit context) => VisitLeafBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLeafLogicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLeafLogicalBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitLeafPhysicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitLeafPhysicalBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr) => base.VisitLeafBufferLoad(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, Unit context) => VisitLeafBufferLoad(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr) => base.VisitLeafBufferRegion(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, Unit context) => VisitLeafBufferRegion(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr) => base.VisitLeafBufferStore(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBufferStore(TIR.BufferStore expr, Unit context) => VisitLeafBufferStore(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFor(TIR.For expr) => base.VisitLeafFor(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafFor(TIR.For expr, Unit context) => VisitLeafFor(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr) => base.VisitLeafIfThenElse(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, Unit context) => VisitLeafIfThenElse(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafLet(TIR.Let expr) => base.VisitLeafLet(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafLet(TIR.Let expr, Unit context) => VisitLeafLet(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr) => base.VisitLeafPrimFunction(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, Unit context) => VisitLeafPrimFunction(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr) => base.VisitLeafSequential(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafSequential(TIR.Sequential expr, Unit context) => VisitLeafSequential(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafRange(TIR.Range expr) => base.VisitLeafRange(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafRange(TIR.Range expr, Unit context) => VisitLeafRange(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr) => base.VisitLeafIterVar(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafIterVar(TIR.IterVar expr, Unit context) => VisitLeafIterVar(expr);
|
||||
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
<#@ template debug="false" hostspecific="false" language="C#" #>
|
||||
<#@ assembly name="System.Core" #>
|
||||
<#@ import namespace="System.IO" #>
|
||||
<#@ import namespace="System.Linq" #>
|
||||
<#@ import namespace="System.Text" #>
|
||||
<#@ import namespace="System.Collections.Generic" #>
|
||||
<#@ output extension=".cs" #>
|
||||
<#@ include file="IRListParser.tt"#>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
|
||||
// </auto-generated>
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
|
||||
{
|
||||
<#
|
||||
foreach (var ir in irs.Where(x => x.IsDerived))
|
||||
{
|
||||
#>
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context)
|
||||
{
|
||||
<#
|
||||
var fieldIdent = ir.IsFunction ? " " : string.Empty;
|
||||
if (ir.IsFunction)
|
||||
{
|
||||
#>
|
||||
if (CanVisitFunctionBody(expr))
|
||||
{
|
||||
<#
|
||||
}
|
||||
#>
|
||||
<#=fieldIdent#>VisitOperands(expr, context);
|
||||
<#
|
||||
|
||||
if (ir.IsFunction)
|
||||
{
|
||||
#>
|
||||
}
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
return VisitLeaf<#=ir.Name#>(expr, context);
|
||||
}
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
<#
|
||||
foreach (var ir in irs)
|
||||
{
|
||||
var func = ir.VisitBase == "Default" ? "DefaultVisitLeaf" : $"VisitLeaf{ir.VisitBase}";
|
||||
#>
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context) => <#=func#>(expr, context);
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
}
|
||||
|
||||
public partial class ExprVisitor<TExprResult, TTypeResult>
|
||||
{
|
||||
<#
|
||||
foreach (var ir in irs.Where(x => x.IsDerived))
|
||||
{
|
||||
#>
|
||||
/// <summary>
|
||||
/// Visit <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr) => base.Visit<#=ir.Name#>(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, Unit context) => Visit<#=ir.Name#>(expr);
|
||||
<#
|
||||
}
|
||||
#>
|
||||
<#
|
||||
foreach (var ir in irs)
|
||||
{
|
||||
var func = ir.VisitBase == "Default" ? "DefaultVisitLeaf" : $"VisitLeaf{ir.VisitBase}";
|
||||
#>
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr) => base.VisitLeaf<#=ir.Name#>(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, Unit context) => VisitLeaf<#=ir.Name#>(expr);
|
||||
|
||||
<#
|
||||
}
|
||||
#>
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public abstract class ExprWalker<TContext> : ExprVisitor<Unit, Unit, TContext>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ExprWalker{TContext}"/> class.
|
||||
/// </summary>
|
||||
/// <param name="visitOtherFunctions">Vist other functions.</param>
|
||||
public ExprWalker(bool visitOtherFunctions = false)
|
||||
: base(visitOtherFunctions)
|
||||
{
|
||||
}
|
||||
|
||||
protected override Unit DefaultVisitLeaf(Expr expr, TContext context) => default;
|
||||
}
|
||||
|
||||
public abstract class ExprWalker : ExprVisitor<Unit, Unit>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ExprWalker"/> class.
|
||||
/// </summary>
|
||||
/// <param name="visitOtherFunctions">Vist other functions.</param>
|
||||
public ExprWalker(bool visitOtherFunctions = false)
|
||||
: base(visitOtherFunctions)
|
||||
{
|
||||
}
|
||||
|
||||
protected override Unit DefaultVisitLeaf(Expr expr) => default;
|
||||
}
|
|
@ -3,56 +3,35 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.Utilities;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// the Callable Expr.
|
||||
/// </summary>
|
||||
public abstract record Callable(string Name, string ModuleKind) : Expr
|
||||
{
|
||||
/// <summary>
|
||||
/// StackVM module kind.
|
||||
/// </summary>
|
||||
public static readonly string StackVMModuleKind = "stackvm";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Base function.
|
||||
/// </summary>
|
||||
public abstract record BaseFunction(string Name, string ModuleKind) : Callable(Name, ModuleKind)
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets sched result.
|
||||
/// </summary>
|
||||
public Schedule.SchedFunctionResult SchedResult { get; } = new();
|
||||
|
||||
/// <summary>
|
||||
/// Gets parameter types.
|
||||
/// </summary>
|
||||
public abstract IEnumerable<IRType?> ParameterTypes { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override int GetHashCode() => base.GetHashCode();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Function expression.
|
||||
/// </summary>
|
||||
public sealed record Function(string Name, Expr Body, IRArray<Var> Parameters) : BaseFunction(Name, StackVMModuleKind)
|
||||
public sealed class Function : BaseFunction
|
||||
{
|
||||
private static int _globalFuncIndex;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Function"/> class.
|
||||
/// build function.
|
||||
/// </summary>
|
||||
/// <param name="parameters">Parameters.</param>
|
||||
/// <param name="body">Body.</param>
|
||||
public Function(Expr body, IRArray<Var> parameters)
|
||||
public Function(string name, Expr body, ReadOnlySpan<Var> parameters)
|
||||
: base(name, StackVMModuleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Function"/> class.
|
||||
/// build function.
|
||||
/// </summary>
|
||||
public Function(Expr body, ReadOnlySpan<Var> parameters)
|
||||
: this($"func_{_globalFuncIndex++}", body, parameters)
|
||||
{
|
||||
}
|
||||
|
@ -61,13 +40,33 @@ public sealed record Function(string Name, Expr Body, IRArray<Var> Parameters) :
|
|||
/// Initializes a new instance of the <see cref="Function"/> class.
|
||||
/// build function.
|
||||
/// </summary>
|
||||
public Function(Expr body, params Var[] parameters)
|
||||
: this($"func_{_globalFuncIndex++}", body, new(parameters))
|
||||
public Function(string name, Expr body, params Var[] parameters)
|
||||
: this(name, body, parameters.AsSpan())
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Function"/> class.
|
||||
/// build function.
|
||||
/// </summary>
|
||||
public Function(Expr body, params Var[] parameters)
|
||||
: this(body, parameters.AsSpan())
|
||||
{
|
||||
}
|
||||
|
||||
public Expr Body => Operands[0];
|
||||
|
||||
public ReadOnlySpan<Var> Parameters => SpanUtility.UnsafeCast<Expr, Var>(Operands[1..]);
|
||||
|
||||
/// <summary>
|
||||
/// Gets get all parameter checked types.
|
||||
/// </summary>
|
||||
public override IEnumerable<IRType?> ParameterTypes => Parameters.Select(x => x.CheckedType);
|
||||
public override IEnumerable<IRType?> ParameterTypes => Parameters.AsValueEnumerable().Select(x => x.CheckedType).ToArray();
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitFunction(this, context);
|
||||
|
||||
public Function With(string? name = null, Expr? body = null, Var[]? parameters = null)
|
||||
=> new Function(name ?? Name, body ?? Body, parameters ?? Parameters);
|
||||
}
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.CodeGen;
|
||||
using Nncase.Utilities;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// Fusion expression.
|
||||
/// </summary>
|
||||
public record Fusion(string Name, string ModuleKind, Expr Body, IRArray<Var> Parameters) : BaseFunction(Name, ModuleKind)
|
||||
public sealed class Fusion : BaseFunction
|
||||
{
|
||||
private static int _globalFusionIndex;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Fusion"/> class.
|
||||
/// build function.
|
||||
/// </summary>
|
||||
/// <param name="module_kind">Module kind.</param>
|
||||
/// <param name="parameters">Parameters.</param>
|
||||
/// <param name="body">Body.</param>
|
||||
public Fusion(string module_kind, Expr body, IRArray<Var> parameters)
|
||||
: this($"fusion_{_globalFusionIndex++}", module_kind, body, parameters)
|
||||
public Fusion(string name, string moduleKind, Expr body, ReadOnlySpan<Var> parameters)
|
||||
: base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -27,13 +27,42 @@ public record Fusion(string Name, string ModuleKind, Expr Body, IRArray<Var> Par
|
|||
/// Initializes a new instance of the <see cref="Fusion"/> class.
|
||||
/// build function.
|
||||
/// </summary>
|
||||
public Fusion(string module_kind, Expr body, params Var[] parameters)
|
||||
: this($"fusion_{_globalFusionIndex++}", module_kind, body, new(parameters))
|
||||
public Fusion(string moduleKind, Expr body, ReadOnlySpan<Var> parameters)
|
||||
: this($"func_{_globalFusionIndex++}", moduleKind, body, parameters)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Fusion"/> class.
|
||||
/// build function.
|
||||
/// </summary>
|
||||
public Fusion(string name, string moduleKind, Expr body, params Var[] parameters)
|
||||
: base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Fusion"/> class.
|
||||
/// build function.
|
||||
/// </summary>
|
||||
public Fusion(string moduleKind, Expr body, params Var[] parameters)
|
||||
: this($"func_{_globalFusionIndex++}", moduleKind, body, parameters)
|
||||
{
|
||||
}
|
||||
|
||||
public Expr Body => Operands[0];
|
||||
|
||||
public ReadOnlySpan<Var> Parameters => SpanUtility.UnsafeCast<Expr, Var>(Operands[1..]);
|
||||
|
||||
/// <summary>
|
||||
/// Gets get all parameter checked types.
|
||||
/// </summary>
|
||||
public override IEnumerable<IRType?> ParameterTypes => Parameters.Select(x => x.CheckedType);
|
||||
public override IEnumerable<IRType?> ParameterTypes => Parameters.AsValueEnumerable().Select(x => x.CheckedType).ToArray();
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitFusion(this, context);
|
||||
|
||||
public Fusion With(string? name = null, string? moduleKind = null, Expr? body = null, Var[]? parameters = null)
|
||||
=> new Fusion(name ?? Name, moduleKind ?? ModuleKind, body ?? Body, parameters ?? Parameters);
|
||||
}
|
||||
|
|
|
@ -1,371 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using GiGraph.Dot.Entities.Clusters;
|
||||
using GiGraph.Dot.Entities.Graphs;
|
||||
using GiGraph.Dot.Entities.Html.Table;
|
||||
using GiGraph.Dot.Entities.Nodes;
|
||||
using GiGraph.Dot.Extensions;
|
||||
using GiGraph.Dot.Types.Colors;
|
||||
using GiGraph.Dot.Types.Edges;
|
||||
using GiGraph.Dot.Types.Fonts;
|
||||
using GiGraph.Dot.Types.Graphs;
|
||||
using GiGraph.Dot.Types.Nodes;
|
||||
using GiGraph.Dot.Types.Records;
|
||||
using GiGraph.Dot.Types.Styling;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
internal sealed class ILDotOption
|
||||
{
|
||||
private readonly DotNode? _dotNode;
|
||||
private readonly string? _str;
|
||||
|
||||
public ILDotOption(DotNode dotNode)
|
||||
{
|
||||
_dotNode = dotNode;
|
||||
_str = null;
|
||||
}
|
||||
|
||||
public ILDotOption(string str)
|
||||
{
|
||||
_dotNode = null;
|
||||
_str = str;
|
||||
}
|
||||
|
||||
public DotNode DotNode => _dotNode!;
|
||||
|
||||
public string Str => _str!;
|
||||
|
||||
public bool IsDotNode => _dotNode is not null;
|
||||
}
|
||||
|
||||
internal sealed class ILDotPrintVisitor : ExprVisitor<ILDotOption, string>
|
||||
{
|
||||
private readonly bool _display_callable;
|
||||
private readonly DotGraph _dotGraph;
|
||||
private readonly List<(string, DotGraph)> _subdotGraphs;
|
||||
private int _idCounter;
|
||||
|
||||
private BaseFunction? _entryBaseFunc;
|
||||
|
||||
public ILDotPrintVisitor(bool display_callable)
|
||||
{
|
||||
_display_callable = display_callable;
|
||||
_dotGraph = new(directed: true);
|
||||
_subdotGraphs = new();
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override ILDotOption Visit(BaseFunction baseFunction)
|
||||
{
|
||||
_entryBaseFunc ??= baseFunction;
|
||||
return base.Visit(baseFunction);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override ILDotOption Visit(PrimFunctionWrapper expr)
|
||||
{
|
||||
if (!ExpressionMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
var id = _idCounter++;
|
||||
_ = "\"" + id.ToString() + "\"";
|
||||
result = new(expr.Name);
|
||||
ExpressionMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override ILDotOption Visit(TIR.PrimFunction expr)
|
||||
{
|
||||
if (!ExpressionMemo.TryGetValue(expr, out var result))
|
||||
{
|
||||
var id = _idCounter++;
|
||||
_ = "\"" + id.ToString() + "\"";
|
||||
result = new(expr.Name);
|
||||
ExpressionMemo.Add(expr, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override ILDotOption Visit(Fusion expr)
|
||||
{
|
||||
_entryBaseFunc ??= expr;
|
||||
if (!object.ReferenceEquals(_entryBaseFunc, expr))
|
||||
{
|
||||
if (_display_callable)
|
||||
{
|
||||
var visitor = new ILDotPrintVisitor(_display_callable);
|
||||
visitor.Visit(expr);
|
||||
_subdotGraphs.Add((expr.Name, visitor._dotGraph));
|
||||
_subdotGraphs.AddRange(visitor._subdotGraphs);
|
||||
}
|
||||
|
||||
return new(expr.Name);
|
||||
}
|
||||
|
||||
return base.Visit(expr);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override ILDotOption Visit(Function expr)
|
||||
{
|
||||
_entryBaseFunc ??= expr;
|
||||
if (!object.ReferenceEquals(_entryBaseFunc, expr))
|
||||
{
|
||||
if (_display_callable)
|
||||
{
|
||||
var visitor = new ILDotPrintVisitor(_display_callable);
|
||||
visitor.Visit(expr);
|
||||
_subdotGraphs.Add((expr.Name, visitor._dotGraph));
|
||||
_subdotGraphs.AddRange(visitor._subdotGraphs);
|
||||
}
|
||||
|
||||
return new(expr.Name);
|
||||
}
|
||||
|
||||
return base.Visit(expr);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override ILDotOption VisitLeaf(Fusion expr)
|
||||
{
|
||||
return new(expr.Name);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override ILDotOption VisitLeaf(Function expr)
|
||||
{
|
||||
return new(expr.Name);
|
||||
}
|
||||
|
||||
public override ILDotOption VisitLeaf(Tuple expr)
|
||||
{
|
||||
var id = _idCounter++;
|
||||
string exprId = "\"" + id.ToString() + "\"";
|
||||
|
||||
var table = new DotHtmlTable
|
||||
{
|
||||
BorderWidth = 0,
|
||||
CellBorderWidth = 1,
|
||||
CellSpacing = 0,
|
||||
};
|
||||
|
||||
var connect_list = new List<(Expr, string)>();
|
||||
|
||||
// 1. the connect type.
|
||||
table.AddRow(row =>
|
||||
{
|
||||
row.AddCell("Tuple"); // key wrods type.
|
||||
int count = 0;
|
||||
foreach (var child in expr.Fields)
|
||||
{
|
||||
var childnode = Visit(child);
|
||||
var portName = $"P{count++}";
|
||||
row.AddCell(childnode.IsDotNode ? string.Empty : childnode.Str, cell => cell.PortName = portName);
|
||||
if (childnode.IsDotNode)
|
||||
{
|
||||
connect_list.Add((child, portName));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 3. make crrent node.
|
||||
var dotNode = _dotGraph.Nodes.Add(exprId);
|
||||
dotNode.ToPlainHtmlNode(table);
|
||||
|
||||
// 4. connect edge.
|
||||
foreach (var (child, port_name) in connect_list)
|
||||
{
|
||||
_dotGraph.Edges.Add(Visit(child).DotNode, dotNode, edge =>
|
||||
{
|
||||
edge.Head.Endpoint.Port = new DotEndpointPort(port_name);
|
||||
});
|
||||
}
|
||||
|
||||
return new(dotNode);
|
||||
}
|
||||
|
||||
public override ILDotOption VisitLeaf(Op expr)
|
||||
{
|
||||
return new(expr.GetType().Name + $"({expr.DisplayProperty()})");
|
||||
}
|
||||
|
||||
public override ILDotOption VisitLeaf(Const expr)
|
||||
{
|
||||
return new(CompilerServices.Print(expr));
|
||||
}
|
||||
|
||||
public override ILDotOption VisitLeaf(None expr)
|
||||
{
|
||||
return new("None");
|
||||
}
|
||||
|
||||
public override ILDotOption VisitLeaf(Marker expr)
|
||||
{
|
||||
var id = _idCounter++;
|
||||
string exprId = "\"" + id.ToString() + "\"";
|
||||
|
||||
var table = new DotHtmlTable
|
||||
{
|
||||
BorderWidth = 0,
|
||||
CellBorderWidth = 1,
|
||||
CellSpacing = 0,
|
||||
};
|
||||
var target = Visit(expr.Target);
|
||||
var attr = Visit(expr.Attribute);
|
||||
|
||||
// 1. the connect type.
|
||||
table.AddRow(row =>
|
||||
{
|
||||
row.AddCell("Marker"); // key wrods type.
|
||||
if (target.IsDotNode)
|
||||
{
|
||||
row.AddCell("Target", cell => cell.PortName = "P0"); // target.
|
||||
}
|
||||
else
|
||||
{
|
||||
row.AddCell(target.Str, cell => cell.PortName = "P0");
|
||||
}
|
||||
|
||||
if (attr.IsDotNode)
|
||||
{
|
||||
row.AddCell("Attr", cell => cell.PortName = "P1"); // attr
|
||||
}
|
||||
else
|
||||
{
|
||||
row.AddCell(attr.Str, cell => cell.PortName = "P1");
|
||||
}
|
||||
});
|
||||
table.AddRow(row =>
|
||||
{
|
||||
row.AddCell(expr.CheckedType is null ? "Null" : CompilerServices.Print(expr.CheckedType), cell => cell.ColumnSpan = 3);
|
||||
});
|
||||
|
||||
// 3. make crrent node.
|
||||
var dotNode = _dotGraph.Nodes.Add(exprId);
|
||||
dotNode.ToPlainHtmlNode(table);
|
||||
|
||||
// 4. connect edge.
|
||||
if (target.IsDotNode)
|
||||
{
|
||||
_dotGraph.Edges.Add(target.DotNode, dotNode, edge =>
|
||||
{
|
||||
edge.Head.Endpoint.Port = new DotEndpointPort("P0");
|
||||
});
|
||||
}
|
||||
|
||||
if (attr.IsDotNode)
|
||||
{
|
||||
_dotGraph.Edges.Add(attr.DotNode, dotNode, edge =>
|
||||
{
|
||||
edge.Head.Endpoint.Port = new DotEndpointPort("P1");
|
||||
});
|
||||
}
|
||||
|
||||
return new(dotNode);
|
||||
}
|
||||
|
||||
public override ILDotOption VisitLeaf(Var expr)
|
||||
{
|
||||
var id = _idCounter++;
|
||||
string exprId = "\"" + id.ToString() + "\"";
|
||||
var dotNode = new DotNode(exprId) { Label = expr.Name, Shape = DotNodeShape.Rectangle };
|
||||
_dotGraph.Nodes.Add(dotNode);
|
||||
return new(dotNode);
|
||||
}
|
||||
|
||||
public override ILDotOption VisitLeaf(Call expr)
|
||||
{
|
||||
var id = _idCounter++;
|
||||
string exprId = "\"" + id.ToString() + "\"";
|
||||
|
||||
var table = new DotHtmlTable
|
||||
{
|
||||
BorderWidth = 0,
|
||||
CellBorderWidth = 1,
|
||||
CellSpacing = 0,
|
||||
};
|
||||
|
||||
var connect_list = new List<(Expr, string)>();
|
||||
|
||||
// 1. the connect type.
|
||||
table.AddRow(row =>
|
||||
{
|
||||
row.AddCell("Call"); // key wrods type.
|
||||
row.AddCell(Visit(expr.Target).Str); // target.
|
||||
int count = 0;
|
||||
foreach (var (child, arg_name) in expr.Parameters.Zip(expr.Target switch
|
||||
{
|
||||
Op op => op.Parameters.Select(info => info.Name),
|
||||
Fusion fusion => fusion.Parameters.Select(v => v.Name),
|
||||
Function func => func.Parameters.Select(v => v.Name),
|
||||
PrimFunctionWrapper wrapper => wrapper.Target.Parameters.Select(b => b.Name),
|
||||
_ => throw new NotSupportedException($"Target type {expr.Target.GetType()} is not supported."),
|
||||
}))
|
||||
{
|
||||
if (child is Const or None)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var portName = $"P{count++}";
|
||||
row.AddCell(arg_name, cell => cell.PortName = portName);
|
||||
connect_list.Add((child, portName));
|
||||
}
|
||||
});
|
||||
table.AddRow(row =>
|
||||
{
|
||||
row.AddCell(expr.CheckedType is null ? "Null" : CompilerServices.Print(expr.CheckedType), cell => cell.ColumnSpan = connect_list.Count + 2);
|
||||
});
|
||||
|
||||
// 3. make crrent node.
|
||||
var dotNode = _dotGraph.Nodes.Add(exprId);
|
||||
dotNode.ToPlainHtmlNode(table);
|
||||
|
||||
// 4. connect edge.
|
||||
foreach (var (child, port_name) in connect_list)
|
||||
{
|
||||
_dotGraph.Edges.Add(Visit(child).DotNode, dotNode, edge =>
|
||||
{
|
||||
edge.Head.Endpoint.Port = new DotEndpointPort(port_name);
|
||||
});
|
||||
}
|
||||
|
||||
return new(dotNode);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save the dot to File.
|
||||
/// </summary>
|
||||
/// <param name="name">name.</param>
|
||||
/// <param name="prefix">prefix.</param>
|
||||
/// <param name="dumpDir">dump dir.</param>
|
||||
public void SaveToFile(string name, string prefix, string dumpDir)
|
||||
{
|
||||
SaveToFileCore(_dotGraph, name, prefix, dumpDir);
|
||||
foreach (var (sub_name, subGraph) in _subdotGraphs)
|
||||
{
|
||||
SaveToFileCore(subGraph, name + "_" + sub_name, prefix, dumpDir);
|
||||
}
|
||||
}
|
||||
|
||||
private static void SaveToFileCore(DotGraph dotGraph, string name, string prefix, string dumpDir)
|
||||
{
|
||||
var nprefix = prefix.Any() ? prefix + "_" : prefix;
|
||||
string dump_path = Path.Combine(dumpDir, $"{nprefix}{name}.dot");
|
||||
dotGraph.Build();
|
||||
dotGraph.SaveToFile(dump_path);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public static class IRHelpers
|
||||
{
|
||||
public static IRType? GetRawCheckedType(Expr expr) => expr.RawCheckedType;
|
||||
|
||||
public static void SetRawCheckedType(Expr expr, IRType? value) => expr.RawCheckedType = value;
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
BaseFunction,false,true,Default,,
|
||||
Call,true,false,Default,,Target;@Arguments
|
||||
Const,false,false,Default,,
|
||||
Function,true,true,BaseFunction,,@Parameters;Body
|
||||
Fusion,true,true,BaseFunction,,@Parameters;Body
|
||||
If,true,false,Default,,Condition;Then;Else
|
||||
Marker,true,false,Default,,Target;Attribute
|
||||
None,true,false,Default,,
|
||||
Op,true,false,Default,,
|
||||
PrimFunctionWrapper,true,true,BaseFunction,,Target
|
||||
TensorConst,true,false,Const,,
|
||||
Tuple,true,false,Default,IR.,@Fields
|
||||
TupleConst,true,false,Const,,
|
||||
Var,true,false,Default,,
|
||||
Block,true,false,Default,TIR.,Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate
|
||||
Buffer,false,false,Default,TIR.,
|
||||
LogicalBuffer,true,false,Buffer,TIR.,@Dimensions;@Strides
|
||||
PhysicalBuffer,true,false,Buffer,TIR.,
|
||||
BufferLoad,true,false,Default,TIR.,Buffer;@Indices
|
||||
BufferRegion,true,false,Default,TIR.,Buffer;@Region
|
||||
BufferStore,true,false,Default,TIR.,Buffer;@Indices;Value
|
||||
For,true,false,Default,TIR.,LoopVar;Domain;Body
|
||||
IfThenElse,true,false,Default,TIR.,Condition;Then;Else
|
||||
Let,true,false,Default,TIR.,Var;Expression;Body
|
||||
PrimFunction,true,true,Default,TIR.,@Parameters;Body
|
||||
Sequential,true,false,Default,TIR.,@Fields
|
||||
Range,true,false,Default,TIR.,Start;Stop;Step
|
||||
IterVar,true,false,Default,TIR.,Value;Dom
|
|
|
@ -0,0 +1,32 @@
|
|||
<#@ assembly name="System.Core" #>
|
||||
<#@ import namespace="System.IO" #>
|
||||
<#@ import namespace="System.Linq" #>
|
||||
<#@ import namespace="System.Text" #>
|
||||
<#@ import namespace="System.Collections.Generic" #>
|
||||
<#@ import namespace="System.Diagnostics" #>
|
||||
<#
|
||||
var irs = (from l in File.ReadAllLines("src/Nncase.Core/IR/IRList.csv")
|
||||
where !string.IsNullOrWhiteSpace(l)
|
||||
let columns = l.Split(',')
|
||||
let isDerived = bool.Parse(columns[1])
|
||||
select new IRDef
|
||||
{
|
||||
Name = columns[0],
|
||||
IsDerived = isDerived,
|
||||
IsFunction = bool.Parse(columns[2]),
|
||||
VisitBase = columns[3],
|
||||
Namespace = columns[4],
|
||||
Fields = isDerived ? columns[5].Split(new[]{';'}, StringSplitOptions.RemoveEmptyEntries) : Array.Empty<string>()
|
||||
}).ToArray();
|
||||
#>
|
||||
<#+
|
||||
struct IRDef
|
||||
{
|
||||
public string Name;
|
||||
public bool IsDerived;
|
||||
public bool IsFunction;
|
||||
public string VisitBase;
|
||||
public string Namespace;
|
||||
public string[] Fields;
|
||||
}
|
||||
#>
|
|
@ -4,9 +4,12 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reactive;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
using TorchSharp.Modules;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
|
@ -57,6 +60,7 @@ public sealed class IRModule
|
|||
/// <param name="function">Callable to add.</param>
|
||||
public void Add(BaseFunction function)
|
||||
{
|
||||
CompilerServices.InferenceType(function);
|
||||
_functions.Add(function);
|
||||
}
|
||||
|
||||
|
@ -67,47 +71,9 @@ public sealed class IRModule
|
|||
/// <param name="function">the entry function defination.</param>
|
||||
public void Replace(int index, BaseFunction function)
|
||||
{
|
||||
var old = _functions[index];
|
||||
|
||||
var replacer = new FunctionReplacer(old, function);
|
||||
for (int i = 0; i < _functions.Count; i++)
|
||||
{
|
||||
replacer.Visit(_functions[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < _functions.Count; i++)
|
||||
{
|
||||
var originFunc = _functions[i];
|
||||
if (replacer.ExpressionMemo.TryGetValue(originFunc, out var replace))
|
||||
{
|
||||
_functions[i] = (BaseFunction)replace;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Replace the function call dependencer.
|
||||
/// </summary>
|
||||
private sealed class FunctionReplacer : DeepExprMutator
|
||||
{
|
||||
private readonly BaseFunction _original;
|
||||
private readonly BaseFunction _replace;
|
||||
|
||||
public FunctionReplacer(BaseFunction original, BaseFunction replace)
|
||||
{
|
||||
_original = original;
|
||||
_replace = replace;
|
||||
}
|
||||
|
||||
public override Expr DefaultMutateLeaf(Expr expr)
|
||||
{
|
||||
if (expr is BaseFunction baseFunction
|
||||
&& object.ReferenceEquals(baseFunction, _original))
|
||||
{
|
||||
return _replace;
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
CompilerServices.InferenceType(function);
|
||||
ref var old = ref CollectionsMarshal.AsSpan(_functions)[index];
|
||||
old.ReplaceAllUsesWith(function);
|
||||
old = function;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,16 +12,10 @@ namespace Nncase.IR;
|
|||
/// <summary>
|
||||
/// Tuple interface.
|
||||
/// </summary>
|
||||
public interface ITuple : IReadOnlyList<Expr>
|
||||
public interface ITuple
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets fields.
|
||||
/// Gets fields count.
|
||||
/// </summary>
|
||||
IReadOnlyList<Expr> Fields { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Cast to expression.
|
||||
/// </summary>
|
||||
/// <returns>The expression.</returns>
|
||||
public Expr AsExpr() => (Expr)this;
|
||||
int Count { get; }
|
||||
}
|
||||
|
|
|
@ -12,6 +12,23 @@ namespace Nncase.IR;
|
|||
/// <summary>
|
||||
/// if(Condition) then { Then } else { Else }.
|
||||
/// </summary>
|
||||
public sealed record If(Expr Condition, Expr Then, Expr Else) : Expr
|
||||
public sealed class If : Expr
|
||||
{
|
||||
public If(Expr condition, Expr then, Expr @else)
|
||||
: base(new[] { condition, then, @else })
|
||||
{
|
||||
}
|
||||
|
||||
public Expr Condition => Operands[0];
|
||||
|
||||
public Expr Then => Operands[1];
|
||||
|
||||
public Expr Else => Operands[2];
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitIf(this, context);
|
||||
|
||||
public If With(Expr? condition = null, Expr? then = null, Expr? @else = null)
|
||||
=> new If(condition ?? Condition, then ?? Then, @else ?? Else);
|
||||
}
|
||||
|
|
|
@ -10,49 +10,53 @@ using System.Threading.Tasks;
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Imaging
|
||||
namespace Nncase.IR.Imaging;
|
||||
|
||||
/// <summary>
|
||||
/// ResizeImage expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed partial class ResizeImage : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// ResizeImage expression.
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record ResizeImage(
|
||||
ImageResizeMode ResizeMode,
|
||||
ImageResizeTransformationMode TransformationMode,
|
||||
ImageResizeNearestMode NearestMode, bool IsTFResize = false) : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"));
|
||||
public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"));
|
||||
|
||||
/// <summary>
|
||||
/// Gets roi.
|
||||
/// [axis 0 start,axis 1 end, ... ,axis n start, axis n end].
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Roi = new(typeof(ResizeImage), 1, "roi", IsNoneType() | (IsFloat() & HasRank(1)));
|
||||
/// <summary>
|
||||
/// Gets roi.
|
||||
/// [axis 0 start,axis 1 end, ... ,axis n start, axis n end].
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Roi = new(typeof(ResizeImage), 1, "roi", IsNoneType() | (IsFloat() & HasRank(1)));
|
||||
|
||||
/// <summary>
|
||||
/// Gets new_size.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo NewSize = new(typeof(ResizeImage), 2, "new_size", HasShape(new[] { 4 }));
|
||||
/// <summary>
|
||||
/// Gets new_size.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo NewSize = new(typeof(ResizeImage), 2, "new_size", HasShape(new[] { 4 }));
|
||||
|
||||
/// <summary>
|
||||
/// Gets CubicCoeffA.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo CubicCoeffA = new(typeof(ResizeImage), 3, "cubic_coeff_a", IsNoneType() | IsFloatScalar());
|
||||
/// <summary>
|
||||
/// Gets CubicCoeffA.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo CubicCoeffA = new(typeof(ResizeImage), 3, "cubic_coeff_a", IsNoneType() | IsFloatScalar());
|
||||
|
||||
/// <summary>
|
||||
/// Gets ExcludeOutside.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo ExcludeOutside = new(typeof(ResizeImage), 4, "exclude_outside", IsNoneType() | IsIntegralScalar());
|
||||
/// <summary>
|
||||
/// Gets ExcludeOutside.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo ExcludeOutside = new(typeof(ResizeImage), 4, "exclude_outside", IsNoneType() | IsIntegralScalar());
|
||||
|
||||
/// <summary>
|
||||
/// Gets ExtrapolationValue.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo ExtrapolationValue = new(typeof(ResizeImage), 5, "extrapolation_value", IsNoneType() | IsFloatScalar());
|
||||
/// <summary>
|
||||
/// Gets ExtrapolationValue.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo ExtrapolationValue = new(typeof(ResizeImage), 5, "extrapolation_value", IsNoneType() | IsFloatScalar());
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"ImageResizeMode.{ResizeMode}, ImageResizeTransformationMode.{TransformationMode}, ImageResizeNearestMode.{NearestMode}, {IsTFResize}";
|
||||
}
|
||||
public ImageResizeMode ResizeMode { get; }
|
||||
|
||||
public ImageResizeTransformationMode TransformationMode { get; }
|
||||
|
||||
public ImageResizeNearestMode NearestMode { get; }
|
||||
|
||||
public bool IsTFResize { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"ImageResizeMode.{ResizeMode}, ImageResizeTransformationMode.{TransformationMode}, ImageResizeNearestMode.{NearestMode}, {IsTFResize}";
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Toolkit.HighPerformance.Helpers;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
|
@ -43,12 +44,12 @@ public class LeafExprEqualityComparer : IEqualityComparer<Expr>
|
|||
(Const tx, Const ty) => tx.Equals(ty),
|
||||
|
||||
// note think of fusion/primfunc/primfunc wrapper as a black box.
|
||||
(Fusion tx, Fusion ty) => ReferenceEqualityComparer.Instance.Equals(tx, ty),
|
||||
(TIR.PrimFunction tx, TIR.PrimFunction ty) => ReferenceEqualityComparer.Instance.Equals(tx, ty),
|
||||
(PrimFunctionWrapper tx, PrimFunctionWrapper ty) => ReferenceEqualityComparer.Instance.Equals(tx, ty),
|
||||
(Function tx, Function ty) => tx.Parameters.Equals(ty.Parameters),
|
||||
(Fusion tx, Fusion ty) => ReferenceEquals(tx, ty),
|
||||
(TIR.PrimFunction tx, TIR.PrimFunction ty) => ReferenceEquals(tx, ty),
|
||||
(PrimFunctionWrapper tx, PrimFunctionWrapper ty) => ReferenceEquals(tx, ty),
|
||||
(Function tx, Function ty) => tx.Parameters.Length == ty.Parameters.Length,
|
||||
(Tuple tx, Tuple ty) => tx.Count == ty.Count,
|
||||
(Call tx, Call ty) => tx.Parameters.Count == ty.Parameters.Count,
|
||||
(Call tx, Call ty) => tx.Arguments.Length == ty.Arguments.Length,
|
||||
(Op tx, Op ty) => tx.Equals(ty),
|
||||
(IR.If, IR.If) => true,
|
||||
(Marker tx, Marker ty) => tx.Name == ty.Name,
|
||||
|
@ -64,14 +65,14 @@ public class LeafExprEqualityComparer : IEqualityComparer<Expr>
|
|||
{
|
||||
Var x => x.GetHashCode(),
|
||||
Const x => x.GetHashCode(),
|
||||
Function x => x.Parameters.GetHashCode(),
|
||||
Function x => ReferenceEqualityComparer.Instance.GetHashCode(x),
|
||||
Fusion x => ReferenceEqualityComparer.Instance.GetHashCode(x),
|
||||
TIR.PrimFunction x => ReferenceEqualityComparer.Instance.GetHashCode(x),
|
||||
PrimFunctionWrapper x => ReferenceEqualityComparer.Instance.GetHashCode(x),
|
||||
Tuple x => x.Count.GetHashCode(),
|
||||
Call x => x.Parameters.Count.GetHashCode(),
|
||||
Call x => x.Arguments.Length.GetHashCode(),
|
||||
Op x => x.GetHashCode(),
|
||||
Marker x => x.Name.GetHashCode(StringComparison.InvariantCulture),
|
||||
Marker x => x.Name.GetHashCode(StringComparison.Ordinal),
|
||||
None x => x.GetHashCode(),
|
||||
IR.If x => x.GetType().GetHashCode(),
|
||||
_ => throw new InvalidOperationException("Invalid expression type."),
|
||||
|
|
|
@ -9,6 +9,17 @@ using System.Threading.Tasks;
|
|||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// staic marker name collection.
|
||||
/// </summary>
|
||||
public static class WellknownMarkerNames
|
||||
{
|
||||
/// <summary>
|
||||
/// attribute. <seealso cref="IR.Math.RangeOf"/>
|
||||
/// </summary>
|
||||
public static readonly string RangeOf = "RangeOf";
|
||||
}
|
||||
|
||||
public class MixQuantInfo
|
||||
{
|
||||
public bool HasBindedMixQuantInfo { get; set; }
|
||||
|
@ -38,11 +49,28 @@ public class AdaQuantInfo
|
|||
/// <summary>
|
||||
/// The marker expression, it's can attach the attribute on the target.
|
||||
/// </summary>
|
||||
/// <param name="Name"> Name will belong to <see cref="WellknownMarkerNames"/>. </param>
|
||||
/// <param name="Target"> expr target. </param>
|
||||
/// <param name="Attribute"> expr attribute. </param>
|
||||
public sealed record Marker(string Name, Expr Target, Expr Attribute) : Expr
|
||||
public sealed class Marker : Expr, IEquatable<Marker?>
|
||||
{
|
||||
private readonly string _name;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Marker"/> class.
|
||||
/// </summary>
|
||||
/// <param name="name">Name will belong to <see cref="WellknownMarkerNames"/>.</param>
|
||||
/// <param name="target">expr target.</param>
|
||||
/// <param name="attribute">expr attribute.</param>
|
||||
public Marker(string name, Expr target, Expr attribute)
|
||||
: base(new[] { target, attribute })
|
||||
{
|
||||
_name = name;
|
||||
}
|
||||
|
||||
public string Name => _name;
|
||||
|
||||
public Expr Target => Operands[0];
|
||||
|
||||
public Expr Attribute => Operands[1];
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the mix quant info.
|
||||
/// </summary>
|
||||
|
@ -52,15 +80,36 @@ public sealed record Marker(string Name, Expr Target, Expr Attribute) : Expr
|
|||
/// Gets or sets the ada quant info.
|
||||
/// </summary>
|
||||
public AdaQuantInfo? AdaQuantInfo { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// staic marker name collection.
|
||||
/// </summary>
|
||||
public static class WellknownMarkerNames
|
||||
{
|
||||
/// <summary>
|
||||
/// attribute. <seealso cref="IR.Math.RangeOf"/>
|
||||
/// </summary>
|
||||
public static readonly string RangeOf = "RangeOf";
|
||||
public static bool operator ==(Marker? left, Marker? right) => EqualityComparer<Marker>.Default.Equals(left, right);
|
||||
|
||||
public static bool operator !=(Marker? left, Marker? right) => !(left == right);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitMarker(this, context);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool Equals(object? obj) => Equals(obj as Marker);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public bool Equals(Marker? other)
|
||||
{
|
||||
if (ReferenceEquals(this, other))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return other is not null && base.Equals(other) && Name == other.Name;
|
||||
}
|
||||
|
||||
public Marker With(string? name = null, Expr? target = null, Expr? attribute = null, MixQuantInfo? mixQuantInfo = null, AdaQuantInfo? adaQuantInfo = null)
|
||||
=> new Marker(name ?? Name, target ?? Target, attribute ?? Attribute)
|
||||
{
|
||||
MixQuantInfo = mixQuantInfo ?? MixQuantInfo,
|
||||
AdaQuantInfo = adaQuantInfo ?? AdaQuantInfo,
|
||||
};
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override int GetHashCodeCore() => HashCode.Combine(base.GetHashCodeCore(), Name);
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace Nncase.IR.Math;
|
|||
/// Binary expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Binary(BinaryOp BinaryOp) : Op
|
||||
public sealed partial class Binary : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets lhs.
|
||||
|
@ -27,6 +27,8 @@ public sealed record Binary(BinaryOp BinaryOp) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs");
|
||||
|
||||
public BinaryOp BinaryOp { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty()
|
||||
{
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
|
|||
/// Clamp expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Clamp() : Op
|
||||
public sealed partial class Clamp : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace Nncase.IR.Math;
|
|||
/// Binary expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Compare(CompareOp CompareOp) : Op
|
||||
public sealed partial class Compare : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets lhs.
|
||||
|
@ -27,6 +27,8 @@ public sealed record Compare(CompareOp CompareOp) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo Rhs = new(typeof(Compare), 1, "rhs");
|
||||
|
||||
public CompareOp CompareOp { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"CompareOp.{CompareOp}";
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
|
|||
/// Condition operation.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Condition() : Op
|
||||
public sealed partial class Condition : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets Condition.
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
|
|||
/// CumSum expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record CumSum() : Op
|
||||
public sealed partial class CumSum : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
|
|
@ -10,7 +10,7 @@ namespace Nncase.IR.Math;
|
|||
/// Dequantize expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Dequantize(DataType TargetType) : Op
|
||||
public sealed partial class Dequantize : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -22,6 +22,8 @@ public sealed record Dequantize(DataType TargetType) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo DequantParam = new(typeof(Dequantize), 1, "dequantParam", IsQuantParamType());
|
||||
|
||||
public DataType TargetType { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"{TargetType.GetCSharpName()}";
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ namespace Nncase.IR.Math;
|
|||
/// Fake dequantize expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record FakeDequantize(DataType TargetType) : Op
|
||||
public sealed partial class FakeDequantize : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -21,6 +21,8 @@ public sealed record FakeDequantize(DataType TargetType) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo DequantParam = new(typeof(FakeDequantize), 1, "dequantParam");
|
||||
|
||||
public DataType TargetType { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"{TargetType.GetCSharpName()}";
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ namespace Nncase.IR.Math;
|
|||
/// Fake quantize expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record FakeQuantize(DataType TargetType) : Op
|
||||
public sealed partial class FakeQuantize : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -21,6 +21,8 @@ public sealed record FakeQuantize(DataType TargetType) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo QuantParam = new(typeof(FakeQuantize), 1, "quantParam");
|
||||
|
||||
public DataType TargetType { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"{TargetType.GetCSharpName()}";
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace Nncase.IR.Math;
|
|||
/// MatMul expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record MatMul() : Op
|
||||
public sealed partial class MatMul : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
|
|
@ -10,7 +10,7 @@ namespace Nncase.IR.Math;
|
|||
/// QuantParamOf expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record QuantParamOf(QuantMode QuantMode) : Op
|
||||
public sealed partial class QuantParamOf : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets range.
|
||||
|
@ -22,6 +22,8 @@ public sealed record QuantParamOf(QuantMode QuantMode) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo Bits = new(typeof(QuantParamOf), 1, "bits", IsIntegralScalar());
|
||||
|
||||
public QuantMode QuantMode { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"QuantMode.{QuantMode}";
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ namespace Nncase.IR.Math;
|
|||
/// Quantize expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Quantize(DataType TargetType) : Op
|
||||
public sealed partial class Quantize : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -22,6 +22,8 @@ public sealed record Quantize(DataType TargetType) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo QuantParam = new(typeof(Quantize), 1, "quantParam", IsQuantParamType());
|
||||
|
||||
public DataType TargetType { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"{TargetType.GetCSharpName()}";
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ namespace Nncase.IR.Math;
|
|||
/// RangeOf expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record RangeOf() : Op
|
||||
public sealed partial class RangeOf : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
|
|
@ -10,7 +10,7 @@ namespace Nncase.IR.Math;
|
|||
/// Reduce expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Reduce(ReduceOp ReduceOp) : Op
|
||||
public sealed partial class Reduce : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -32,6 +32,8 @@ public sealed record Reduce(ReduceOp ReduceOp) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo KeepDims = new(typeof(Reduce), 3, "keepDims", IsScalar() & IsIntegral());
|
||||
|
||||
public ReduceOp ReduceOp { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"ReduceOp.{ReduceOp}";
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
|
|||
/// ReduceArg expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record ReduceArg(ReduceArgOp ReduceArgOp, DataType DestType) : Op
|
||||
public sealed partial class ReduceArg : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -40,6 +40,10 @@ public sealed record ReduceArg(ReduceArgOp ReduceArgOp, DataType DestType) : Op
|
|||
/// <remarks>Only used in onnx.</remarks>
|
||||
public static readonly ParameterInfo SelectLastIndex = new(typeof(ReduceArg), 3, "selectLastIndex", IsBoolScalar());
|
||||
|
||||
public ReduceArgOp ReduceArgOp { get; }
|
||||
|
||||
public DataType DestType { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}";
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
|
|||
/// Require expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Require(string Message) : Op
|
||||
public sealed partial class Require : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets Condition.
|
||||
|
@ -28,6 +28,8 @@ public sealed record Require(string Message) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo Value = new(typeof(Require), 1, "value");
|
||||
|
||||
public string Message { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => "\"\"";
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
|
|||
/// Unary expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Select() : Op
|
||||
public sealed partial class Select : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets Condition.
|
||||
|
|
|
@ -15,13 +15,15 @@ namespace Nncase.IR.Math;
|
|||
/// Unary expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Unary(UnaryOp UnaryOp) : Op
|
||||
public sealed partial class Unary : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input");
|
||||
|
||||
public UnaryOp UnaryOp { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty()
|
||||
{
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace Nncase.IR.NN;
|
|||
/// <summary>
|
||||
/// The base class.
|
||||
/// </summary>
|
||||
public abstract record ActivationOp : Op
|
||||
public abstract class ActivationOp : Op
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -23,7 +23,7 @@ public abstract record ActivationOp : Op
|
|||
/// Sigmoid expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Sigmoid() : ActivationOp
|
||||
public sealed partial class Sigmoid : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -35,7 +35,7 @@ public sealed record Sigmoid() : ActivationOp
|
|||
/// Relu expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Relu() : ActivationOp
|
||||
public sealed partial class Relu : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -47,7 +47,7 @@ public sealed record Relu() : ActivationOp
|
|||
/// Relu6 expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Relu6() : ActivationOp
|
||||
public sealed partial class Relu6 : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -59,7 +59,7 @@ public sealed record Relu6() : ActivationOp
|
|||
/// PRelu expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record PRelu() : ActivationOp
|
||||
public sealed partial class PRelu : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -76,7 +76,7 @@ public sealed record PRelu() : ActivationOp
|
|||
/// LeakyRelu expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record LeakyRelu() : ActivationOp
|
||||
public sealed partial class LeakyRelu : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -93,7 +93,7 @@ public sealed record LeakyRelu() : ActivationOp
|
|||
/// Celu expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Celu() : ActivationOp
|
||||
public sealed partial class Celu : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -110,7 +110,7 @@ public sealed record Celu() : ActivationOp
|
|||
/// Selu expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Selu() : ActivationOp
|
||||
public sealed partial class Selu : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -132,7 +132,7 @@ public sealed record Selu() : ActivationOp
|
|||
/// Elu expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Elu() : ActivationOp
|
||||
public sealed partial class Elu : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -149,7 +149,7 @@ public sealed record Elu() : ActivationOp
|
|||
/// HardSwish expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record HardSwish() : ActivationOp
|
||||
public sealed partial class HardSwish : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -161,7 +161,7 @@ public sealed record HardSwish() : ActivationOp
|
|||
/// HardSigmoid expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record HardSigmoid() : ActivationOp
|
||||
public sealed partial class HardSigmoid : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -183,7 +183,7 @@ public sealed record HardSigmoid() : ActivationOp
|
|||
/// Erf expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Erf() : ActivationOp
|
||||
public sealed partial class Erf : ActivationOp
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace Nncase.IR.NN;
|
|||
/// BatchToSpace expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record BatchToSpace() : Op
|
||||
public sealed partial class BatchToSpace : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
|
|
@ -10,55 +10,56 @@ using System.Threading.Tasks;
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.NN
|
||||
namespace Nncase.IR.NN;
|
||||
|
||||
/// <summary>
|
||||
/// Conv2D.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed partial class Conv2D : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Conv2D.
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Conv2D(PadMode PadMode) : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input");
|
||||
|
||||
/// <summary>
|
||||
/// Gets Weights.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4));
|
||||
/// <summary>
|
||||
/// Gets Weights.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4));
|
||||
|
||||
/// <summary>
|
||||
/// Gets Bias.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1));
|
||||
/// <summary>
|
||||
/// Gets Bias.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1));
|
||||
|
||||
/// <summary>
|
||||
/// Gets Stride.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Stride = new(typeof(Conv2D), 3, "stride", HasRank(1) & IsIntegral());
|
||||
/// <summary>
|
||||
/// Gets Stride.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Stride = new(typeof(Conv2D), 3, "stride", HasRank(1) & IsIntegral());
|
||||
|
||||
/// <summary>
|
||||
/// Gets Padding.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Padding = new(typeof(Conv2D), 4, "padding", HasRank(2) & IsIntegral());
|
||||
/// <summary>
|
||||
/// Gets Padding.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Padding = new(typeof(Conv2D), 4, "padding", HasRank(2) & IsIntegral());
|
||||
|
||||
/// <summary>
|
||||
/// Gets Dilation.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Dilation = new(typeof(Conv2D), 5, "dilation", HasRank(1) & IsIntegral());
|
||||
/// <summary>
|
||||
/// Gets Dilation.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Dilation = new(typeof(Conv2D), 5, "dilation", HasRank(1) & IsIntegral());
|
||||
|
||||
/// <summary>
|
||||
/// Gets Groups.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Groups = new(typeof(Conv2D), 6, "groups", IsScalar() & IsIntegral());
|
||||
/// <summary>
|
||||
/// Gets Groups.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Groups = new(typeof(Conv2D), 6, "groups", IsScalar() & IsIntegral());
|
||||
|
||||
/// <summary>
|
||||
/// Gets FusedClamp.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 7, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
|
||||
/// <summary>
|
||||
/// Gets FusedClamp.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 7, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"PadMode.{PadMode}";
|
||||
}
|
||||
public PadMode PadMode { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"PadMode.{PadMode}";
|
||||
}
|
||||
|
|
|
@ -11,65 +11,66 @@ using Nncase.IR.Tensors;
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.NN
|
||||
namespace Nncase.IR.NN;
|
||||
|
||||
/// <summary>
|
||||
/// Conv2DTranspose.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed partial class Conv2DTranspose : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Conv2DTranspose.
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Conv2DTranspose(PadMode PadMode) : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Conv2DTranspose), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Conv2DTranspose), 0, "input");
|
||||
|
||||
/// <summary>
|
||||
/// Gets Weights.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Weights = new(typeof(Conv2DTranspose), 1, "weights");
|
||||
/// <summary>
|
||||
/// Gets Weights.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Weights = new(typeof(Conv2DTranspose), 1, "weights");
|
||||
|
||||
/// <summary>
|
||||
/// Gets Bias.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Bias = new(typeof(Conv2DTranspose), 2, "bias");
|
||||
/// <summary>
|
||||
/// Gets Bias.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Bias = new(typeof(Conv2DTranspose), 2, "bias");
|
||||
|
||||
/// <summary>
|
||||
/// Gets OutputShape.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo OutputShape = new(typeof(Conv2DTranspose), 3, "outputShape");
|
||||
/// <summary>
|
||||
/// Gets OutputShape.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo OutputShape = new(typeof(Conv2DTranspose), 3, "outputShape");
|
||||
|
||||
/// <summary>
|
||||
/// Gets Stride.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Stride = new(typeof(Conv2DTranspose), 4, "stride");
|
||||
/// <summary>
|
||||
/// Gets Stride.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Stride = new(typeof(Conv2DTranspose), 4, "stride");
|
||||
|
||||
/// <summary>
|
||||
/// Gets Padding.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Padding = new(typeof(Conv2DTranspose), 5, "padding");
|
||||
/// <summary>
|
||||
/// Gets Padding.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Padding = new(typeof(Conv2DTranspose), 5, "padding");
|
||||
|
||||
/// <summary>
|
||||
/// Gets Output Padding.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo OutputPadding = new(typeof(Conv2DTranspose), 6, "output_padding");
|
||||
/// <summary>
|
||||
/// Gets Output Padding.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo OutputPadding = new(typeof(Conv2DTranspose), 6, "output_padding");
|
||||
|
||||
/// <summary>
|
||||
/// Gets Dilation.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Dilation = new(typeof(Conv2DTranspose), 7, "dilation");
|
||||
/// <summary>
|
||||
/// Gets Dilation.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Dilation = new(typeof(Conv2DTranspose), 7, "dilation");
|
||||
|
||||
/// <summary>
|
||||
/// Gets Groups.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Groups = new(typeof(Conv2DTranspose), 8, "groups");
|
||||
/// <summary>
|
||||
/// Gets Groups.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Groups = new(typeof(Conv2DTranspose), 8, "groups");
|
||||
|
||||
/// <summary>
|
||||
/// Gets FusedClamp.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 9, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
|
||||
/// <summary>
|
||||
/// Gets FusedClamp.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 9, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"PadMode.{PadMode}";
|
||||
}
|
||||
public PadMode PadMode { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"PadMode.{PadMode}";
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace Nncase.IR.NN;
|
|||
/// Hardmax expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Hardmax() : Op
|
||||
public sealed partial class Hardmax : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
#if false
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.NN
|
||||
{
|
||||
public enum lstm_direction
|
||||
{
|
||||
kForward,
|
||||
kReverse,
|
||||
kBidirectional
|
||||
}
|
||||
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record LSTM(lstm_direction direction, String framework) : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(LSTM), 0, "input", HasDataType(DataTypes.Float32));
|
||||
|
||||
/// <summary>
|
||||
/// Gets w.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo W = new(typeof(LSTM), 1, "w");
|
||||
|
||||
/// <summary>
|
||||
/// Gets r.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo R = new(typeof(LSTM), 2, "r");
|
||||
|
||||
/// <summary>
|
||||
/// Gets b.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo B = new(typeof(LSTM), 3, "b");
|
||||
|
||||
/// <summary>
|
||||
/// Gets initial_h.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo initial_h = new(typeof(LSTM), 4, "initial_h",
|
||||
HasDataType(DataTypes.Float32));
|
||||
|
||||
/// <summary>
|
||||
/// Gets initial_c.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo initial_c = new(typeof(LSTM), 5, "initial_c",
|
||||
HasDataType(DataTypes.Float32));
|
||||
|
||||
/// <summary>
|
||||
/// Gets has_static.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo has_static = new(typeof(LSTM), 6, "has_static", IsBool());
|
||||
}
|
||||
}
|
||||
#endif
|
|
@ -16,7 +16,7 @@ namespace Nncase.IR.NN;
|
|||
/// Hardmax expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record LayerNorm(int Axis, float Epsilon) : Op
|
||||
public sealed partial class LayerNorm : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -32,4 +32,8 @@ public sealed record LayerNorm(int Axis, float Epsilon) : Op
|
|||
/// Gets bias.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias");
|
||||
|
||||
public int Axis { get; }
|
||||
|
||||
public float Epsilon { get; }
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ using static Nncase.IR.TypePatternUtility;
|
|||
namespace Nncase.IR.NN;
|
||||
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record L2Normalization() : Op
|
||||
public sealed partial class L2Normalization : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -17,7 +17,7 @@ public sealed record L2Normalization() : Op
|
|||
}
|
||||
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record BatchNormalization() : Op
|
||||
public sealed partial class BatchNormalization : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -56,7 +56,7 @@ public sealed record BatchNormalization() : Op
|
|||
}
|
||||
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record InstanceNormalization() : Op
|
||||
public sealed partial class InstanceNormalization : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -80,7 +80,7 @@ public sealed record InstanceNormalization() : Op
|
|||
}
|
||||
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record LpNormalization() : Op
|
||||
public sealed partial class LpNormalization : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -99,7 +99,7 @@ public sealed record LpNormalization() : Op
|
|||
}
|
||||
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record LRN() : Op
|
||||
public sealed partial class LRN : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace Nncase.IR.NN;
|
|||
/// OneHot expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record OneHot(OneHotMode OneHotMode) : Op
|
||||
public sealed partial class OneHot : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
|
@ -37,6 +37,8 @@ public sealed record OneHot(OneHotMode OneHotMode) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo Axis = new(typeof(OneHot), 3, "axis");
|
||||
|
||||
public OneHotMode OneHotMode { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"OneHotMode.{OneHotMode}";
|
||||
}
|
||||
|
|
|
@ -9,9 +9,8 @@ namespace Nncase.IR.NN;
|
|||
/// <summary>
|
||||
/// Pad tensor, a little difference with pytorch pad.
|
||||
/// </summary>
|
||||
/// <param name="PadMode">Pad mode.</param>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed record Pad(PadMode PadMode) : Op
|
||||
public sealed partial class Pad : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// input.
|
||||
|
@ -28,6 +27,11 @@ public sealed record Pad(PadMode PadMode) : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo Value = new(typeof(Pad), 2, "value", IsScalar());
|
||||
|
||||
/// <summary>
|
||||
/// Gets pad mode.
|
||||
/// </summary>
|
||||
public PadMode PadMode { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"PadMode.{PadMode}";
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue