Add expr user analysis (#827)

* Make IR mutable

* Add expr user analysis

* fix il dot dump && fusion merge condition

* Merge from rebuild-ir

* Fix fusion tests

* Fix warnings

* Fix sg reference

* Fix merge

* Apply code-format changes

* Fix TestNoneValue

* Fix test dump

* Update with k230 requires

* Fix k230 tests

* ExprRewriter: Use replace opr instead of rauw

* Fix

* Fix

* Optimize performance

* Disable CS0659;CS0661

* Fix invalidate hashcode cache

* Fix new commits

---------

Co-authored-by: 郑启航 <597323109@qq.com>
Co-authored-by: sunnycase <sunnycase@users.noreply.github.com>
pull/837/head
sunnycase 2023-03-14 15:02:56 +08:00 committed by GitHub
parent 00ee74c402
commit e88d9c8d5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
470 changed files with 10066 additions and 7373 deletions

View File

@ -306,7 +306,7 @@ dotnet_diagnostic.IDE1006.severity = suggestion
dotnet_code_quality.copy_analysis = true
dotnet_diagnostic.CA1000.severity = none
dotnet_diagnostic.CA1001.severity = warning
dotnet_diagnostic.CA1001.severity = suggestion
dotnet_diagnostic.CA1018.severity = warning
dotnet_diagnostic.CA1019.severity = warning
dotnet_diagnostic.CA1036.severity = warning

View File

@ -4,7 +4,7 @@
<CodeAnalysisRuleSet>$(MSBuildThisFileDirectory)/tools/StyleCopAnalyzers.ruleset</CodeAnalysisRuleSet>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<Nullable>enable</Nullable>
<NoWarn>$(NoWarn);MSB3270</NoWarn>
<NoWarn>$(NoWarn);MSB3270;CS0659;CS0661</NoWarn>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)' == 'Release'">
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
@ -23,7 +23,7 @@
<PackageVersion Include="Google.Protobuf" Version="3.19.1" />
<PackageVersion Include="Grpc.Tools" Version="2.42.0" />
<PackageVersion Include="Humanizer.Core" Version="2.14.1" />
<PackageVersion Include="LanguageExt.Core" Version="4.0.3" />
<PackageVersion Include="LanguageExt.Core" Version="4.4.0" />
<PackageVersion Include="libtorch-cpu-linux-x64" Version="1.13.0.1" />
<PackageVersion Include="libtorch-cpu-osx-x64" Version="1.13.0.1" />
<PackageVersion Include="libtorch-cpu-win-x64" Version="1.13.0.1" />
@ -55,6 +55,7 @@
<PackageVersion Include="StyleCop.Analyzers" Version="1.2.0-beta.435" />
<PackageVersion Include="System.CommandLine.Hosting" Version="0.3.0-alpha.21216.1" />
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
<PackageVersion Include="System.Reactive" Version="5.0.0" />
<PackageVersion Include="Tomlyn.Extensions.Configuration" Version="1.0.5" />
<PackageVersion Include="TorchSharp" Version="0.99.0" />
<PackageVersion Include="xunit" Version="2.4.1" />

View File

@ -6,12 +6,14 @@
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
<add key="Nncase.Libs" value="https://www.myget.org/F/magicallibs/api/v3/index.json" protocolVersion="3"/>
<add key="myget-xunit" value="https://www.myget.org/F/xunit/api/v3/index.json" />
<add key="design-packages" value="tools/design-packages" />
</packageSources>
<activePackageSource>
<add key="nuget.cnblogs.com" value="https://nuget.cnblogs.com/v3/index.json" protocolVersion="3" />
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
<add key="Nncase.Libs" value="https://www.myget.org/F/magicallibs/api/v3/index.json" protocolVersion="3"/>
<add key="myget-xunit" value="https://www.myget.org/F/xunit/api/v3/index.json" />
<add key="design-packages" value="tools/design-packages" />
</activePackageSource>
<packageSourceMapping>
<!-- key value for <packageSource> should match key values from <packageSources> element -->
@ -26,5 +28,8 @@
<packageSource key="myget-xunit">
<package pattern="xunit.v3.assert" />
</packageSource>
<packageSource key="design-packages">
<package pattern="Nncase.SourceGenerator" />
</packageSource>
</packageSourceMapping>
</configuration>

View File

@ -11,10 +11,10 @@ using Nncase.CodeGen;
using Nncase.CodeGen.K210;
using Nncase.CodeGen.StackVM;
using Nncase.IR;
using Nncase.Passes;
using Nncase.Passes.Rules.K210;
using Nncase.Quantization;
using Nncase.Runtime.K210;
using Nncase.Transform;
using Nncase.Transform.Rules.K210;
namespace Nncase.Targets;

View File

@ -15,7 +15,7 @@ using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.Utility;
namespace Nncase.Transform.Rules.K210;
namespace Nncase.Passes.Rules.K210;
/// <summary>
/// Lower <see cref="IR.NN.Conv2D"/> to <see cref="IR.K210.FakeKPUConv2D"/>.

View File

@ -21,7 +21,7 @@ using static Nncase.PatternMatch.Utility;
using static Nncase.Quantization.Utility;
using Math = Nncase.IR.F.Math;
namespace Nncase.Transform.Rules.K210;
namespace Nncase.Passes.Rules.K210;
/// <summary>
/// Lower <see cref="FakeDequantize"/> to <see cref="Dequantize"/>.

View File

@ -22,7 +22,7 @@ using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.Utility;
using Math = System.Math;
namespace Nncase.Transform.Rules.K210;
namespace Nncase.Passes.Rules.K210;
/// <summary>
/// Lower <see cref="IR.K210.FakeKPUConv2D"/> to <see cref="IR.K210.KPUConv2D"/>.

View File

@ -20,7 +20,7 @@ using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.Utility;
using Math = Nncase.IR.F.Math;
namespace Nncase.Transform.Rules.K210;
namespace Nncase.Passes.Rules.K210;
/// <summary>
/// Lower <see cref="IR.K210.FakeKPUDownload"/> to <see cref="IR.K210.KPUDownload"/>.

View File

@ -21,7 +21,7 @@ using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.Utility;
using Math = Nncase.IR.F.Math;
namespace Nncase.Transform.Rules.K210;
namespace Nncase.Passes.Rules.K210;
/// <summary>
/// Lower <see cref="IR.K210.FakeKPUUpload"/> to <see cref="IR.K210.KPUUpload"/>.

View File

@ -21,7 +21,7 @@ using static Nncase.PatternMatch.Utility;
using static Nncase.Quantization.Utility;
using Math = Nncase.IR.F.Math;
namespace Nncase.Transform.Rules.K210;
namespace Nncase.Passes.Rules.K210;
/// <summary>
/// Lower <see cref="FakeQuantize"/> to <see cref="Quantize"/>.

View File

@ -1,4 +1,4 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/3/3 下午1:12:08 +08:00. */
@ -188,7 +188,7 @@ internal partial class CodeGenVisitor
Emitter.T.GetItem();
break;
case IR.Tensors.LSTM top:
Emitter.T.LSTM(top.Direction, top.Layout, top.Activations);
Emitter.T.LSTM(top.Direction, top.Layout, top.Activations.ToArray());
break;
case IR.Tensors.Prod top:
Emitter.T.Prod();

View File

@ -4,8 +4,10 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
using NetFabric.Hyperlinq;
using Nncase.IR;
namespace Nncase.CodeGen.StackVM;
@ -105,7 +107,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
private StackVMEmitter Emitter => CurrentTextSnippet.Emitter;
public override TextSnippet VisitLeaf(Const expr)
protected override TextSnippet VisitLeafConst(Const expr)
{
if (expr is TensorConst tc)
{
@ -113,11 +115,11 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
}
else
{
return Visit(new IR.Tuple(((TupleConst)expr).Fields));
return Visit(new IR.Tuple(((TupleConst)expr).Value.Select(x => Const.FromValue(x)).ToArray()));
}
}
public override TextSnippet VisitLeaf(Var expr)
protected override TextSnippet VisitLeafVar(Var expr)
{
var snippet = BeginTextSnippet(expr);
var varIndex = ((Function)_function).Parameters.IndexOf(expr);
@ -130,13 +132,13 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
return snippet;
}
public override TextSnippet VisitLeaf(IR.Tuple expr)
protected override TextSnippet VisitLeafTuple(IR.Tuple expr)
{
var snippet = BeginTextSnippet(expr);
foreach (var field in expr.Fields.Reverse())
foreach (var field in expr.Fields.ToArray().Reverse())
{
var inputSnippet = Visit(field);
inputSnippet.MaxUserParameters = Math.Max(inputSnippet.MaxUserParameters, expr.Fields.Count);
inputSnippet.MaxUserParameters = Math.Max(inputSnippet.MaxUserParameters, expr.Fields.Length);
snippet.AddInput(inputSnippet);
}
@ -145,7 +147,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
return snippet;
}
public override TextSnippet Visit(Function expr)
protected override TextSnippet VisitFunction(Function expr)
{
if (ReferenceEquals(expr, _function))
{
@ -157,7 +159,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
}
}
public override TextSnippet Visit(PrimFunctionWrapper expr)
protected override TextSnippet VisitPrimFunctionWrapper(PrimFunctionWrapper expr)
{
if (ReferenceEquals(expr, _function))
{
@ -177,25 +179,25 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
}
}
public override TextSnippet VisitLeaf(Op expr)
protected override TextSnippet VisitLeafOp(Op expr)
{
return null!;
}
public override TextSnippet VisitLeaf(Call expr)
protected override TextSnippet VisitLeafCall(Call expr)
{
var snippet = BeginTextSnippet(expr);
foreach (var param in expr.Parameters.Reverse())
foreach (var param in expr.Arguments.ToArray().Reverse())
{
var paramSnippet = Visit(param);
paramSnippet.MaxUserParameters = Math.Max(paramSnippet.MaxUserParameters, expr.Parameters.Count);
paramSnippet.MaxUserParameters = Math.Max(paramSnippet.MaxUserParameters, expr.Arguments.Length);
snippet.AddInput(paramSnippet);
}
if (expr.Target is CustomOp custom_op)
{
_context.AddCustomCallModule(custom_op.ModuleType);
Emitter.CusCall(custom_op.RegisteredName, custom_op.SerializeFields(), checked((ushort)expr.Parameters.Count));
Emitter.CusCall(custom_op.RegisteredName, custom_op.SerializeFields(), checked((ushort)expr.Arguments.Length));
}
else if (expr.Target is Op op)
{
@ -209,7 +211,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
else if (expr.Target is Function func)
{
LdFunctionId(func);
Emitter.ExtCall(checked((ushort)func.Parameters.Count), false);
Emitter.ExtCall(checked((ushort)func.Parameters.Length), false);
}
else
{
@ -219,15 +221,9 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
return snippet;
}
public override TextSnippet Visit(If expr)
protected override TextSnippet VisitIf(If expr)
{
if (!ExpressionMemo.TryGetValue(expr, out var result))
{
result = VisitLeaf(expr);
ExpressionMemo.Add(expr, result);
}
return result;
return VisitLeafIf(expr);
}
/// <summary>
@ -245,7 +241,7 @@ internal partial class CodeGenVisitor : ExprVisitor<TextSnippet, IRType>
/// </summary>
/// <param name="if">If expr.</param>
/// <returns>TextSnippet.</returns>
public override TextSnippet VisitLeaf(If @if)
protected override TextSnippet VisitLeafIf(If @if)
{
var condSnippet = Visit(@if.Condition);
condSnippet.Emitter.LdScalar();

View File

@ -11,8 +11,8 @@ using Microsoft.Extensions.Options;
using Nncase.CodeGen;
using Nncase.CodeGen.StackVM;
using Nncase.IR;
using Nncase.Passes;
using Nncase.Quantization;
using Nncase.Transform;
namespace Nncase.Targets;

View File

@ -14,6 +14,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{E11A07F5
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Core", "src\Nncase.Core\Nncase.Core.csproj", "{2978CFB1-1530-4EC3-BB5F-130B6F606F85}"
ProjectSection(ProjectDependencies) = postProject
{24DF3895-9473-4C98-A494-8275456D02CC} = {24DF3895-9473-4C98-A494-8275456D02CC}
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Importer", "src\Nncase.Importer\Nncase.Importer.csproj", "{37410602-6F3A-46F4-8CFD-9B8742BB98F4}"
EndProject
@ -47,8 +50,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Targets", "src\Nncas
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Simulator", "src\Nncase.Simulator\Nncase.Simulator.csproj", "{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Transform", "src\Nncase.Transform\Nncase.Transform.csproj", "{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.EGraph", "src\Nncase.EGraph\Nncase.EGraph.csproj", "{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Compiler", "src\Nncase.Compiler\Nncase.Compiler.csproj", "{54B4C0B9-8BF0-4157-AB04-A55BA268444C}"
@ -63,8 +64,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{9859
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Modules.StackVM", "modules\Nncase.Modules.StackVM\Nncase.Modules.StackVM.csproj", "{70D3FA34-B0B6-488F-812D-7E076CC0DFB3}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Modules.K210", "modules\Nncase.Modules.K210\Nncase.Modules.K210.csproj", "{7390DF41-E804-4F12-B441-6EB5119C81BF}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Quantization", "src\Nncase.Quantization\Nncase.Quantization.csproj", "{317C0D8F-75B3-4248-83E8-17AADDCF247A}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{EF1F8779-2B98-4E6F-A3DC-CA3FD2CADAD8}"
@ -78,6 +77,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Diagnostics", "src\N
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Tests.TestFixture", "src\Nncase.Tests.TestFixture\Nncase.Tests.TestFixture.csproj", "{98A03405-CA53-4EC4-9B18-94D1C8DF9453}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Passes", "src\Nncase.Passes\Nncase.Passes.csproj", "{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -136,10 +137,6 @@ Global
{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}.Debug|Any CPU.Build.0 = Debug|Any CPU
{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}.Release|Any CPU.ActiveCfg = Release|Any CPU
{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}.Release|Any CPU.Build.0 = Release|Any CPU
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA}.Release|Any CPU.Build.0 = Release|Any CPU
{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596}.Debug|Any CPU.Build.0 = Debug|Any CPU
{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596}.Release|Any CPU.ActiveCfg = Release|Any CPU
@ -160,10 +157,6 @@ Global
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3}.Release|Any CPU.ActiveCfg = Release|Any CPU
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3}.Release|Any CPU.Build.0 = Release|Any CPU
{7390DF41-E804-4F12-B441-6EB5119C81BF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7390DF41-E804-4F12-B441-6EB5119C81BF}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7390DF41-E804-4F12-B441-6EB5119C81BF}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7390DF41-E804-4F12-B441-6EB5119C81BF}.Release|Any CPU.Build.0 = Release|Any CPU
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Release|Any CPU.ActiveCfg = Release|Any CPU
@ -176,6 +169,10 @@ Global
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Debug|Any CPU.Build.0 = Debug|Any CPU
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Release|Any CPU.ActiveCfg = Release|Any CPU
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Release|Any CPU.Build.0 = Release|Any CPU
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -196,17 +193,16 @@ Global
{8E0E0672-0F96-4EF1-BDCD-D31F96A3DF73} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{56283378-06E3-4C6E-A8BF-7BD85C92D42C} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{2537C687-E989-47A5-AE93-CA7CEB4CB7AA} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{ED7F9E49-E5F1-4A0F-9AB1-D2247D3AC596} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{54B4C0B9-8BF0-4157-AB04-A55BA268444C} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{24DF3895-9473-4C98-A494-8275456D02CC} = {E11A07F5-DDF9-4BCA-94E4-988913EF7F51}
{A79769F9-D567-4236-8965-08CFB440783B} = {E11A07F5-DDF9-4BCA-94E4-988913EF7F51}
{920E3584-7987-44ED-8794-F6929E812038} = {A79769F9-D567-4236-8965-08CFB440783B}
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66}
{7390DF41-E804-4F12-B441-6EB5119C81BF} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66}
{317C0D8F-75B3-4248-83E8-17AADDCF247A} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{98A03405-CA53-4EC4-9B18-94D1C8DF9453} = {E5A4516C-4080-4346-991D-57A7AA76ADA6}
{E6462E82-B48F-4AFA-AE34-725EF0A9CB42} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {9492E141-292E-4D60-9C6E-3738AB234DB2}

View File

@ -14,8 +14,8 @@ using Nncase.CodeGen;
using Nncase.Compiler;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.Passes;
using Nncase.Quantization;
using Nncase.Transform;
namespace Nncase.Cli.Commands;

View File

@ -15,12 +15,8 @@ public class LinkedFunction : ILinkedFunction
public LinkedFunction(uint id, Callable sourceFunction, uint textBegin, uint textLength, IReadOnlyList<ILinkedSection> sections)
{
Id = id;
if (sourceFunction.CheckedType is null)
{
CompilerServices.InferenceType(sourceFunction);
}
ParameterTypes = ((CallableType)sourceFunction.CheckedType!).Parameters.ToArray();
CompilerServices.InferenceType(sourceFunction);
ParameterTypes = ((CallableType)sourceFunction.CheckedType).Parameters.ToArray();
ReturnType = ((CallableType)sourceFunction.CheckedType).ReturnType;
TextBegin = textBegin;
TextLength = textLength;

View File

@ -10,10 +10,10 @@ using Nncase.Diagnostics;
using Nncase.Evaluator;
using Nncase.Hosting;
using Nncase.IR;
using Nncase.Passes;
using Nncase.Passes.Rules.Lower;
using Nncase.Passes.Transforms;
using Nncase.Quantization;
using Nncase.Transform;
using Nncase.Transform.Passes;
using Nncase.Transform.Rules.Lower;
using Nncase.Utilities;
namespace Nncase.Compiler;
@ -61,43 +61,43 @@ internal class Compiler : ICompiler
var quantMode = _compileSession.CompileOptions.QuantizeOptions.ModelQuantMode;
if (quantMode == ModelQuantMode.UsePTQ)
{
passManager.AddWithName<EGraphPass>("NeutralOptimizeTranspose").Configure(p =>
passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
{
p.Add<Transform.Rules.Neutral.FoldConstCall>();
p.Add<Transform.Rules.Neutral.FoldNopTranspose>();
p.Add<Transform.Rules.Neutral.FoldTwoTransposes>();
p.Add<Transform.Rules.Neutral.CombineTransposeUnary>();
p.Add<Transform.Rules.Neutral.CombineTransposePad>();
p.Add<Transform.Rules.Neutral.CombinePadTranspose>();
p.Add<Transform.Rules.Neutral.CombineBinaryTranspose>();
p.Add<Transform.Rules.Neutral.CombineConstBinaryTranspose>();
p.Add<Transform.Rules.Neutral.CombineTransposeConstBinary>();
p.Add<Transform.Rules.Neutral.CombineTransposeReduce>();
p.Add<Transform.Rules.Neutral.CombineTransposeActivations>();
p.Add<Transform.Rules.Neutral.CombineActivationsTranspose>();
p.Add<Transform.Rules.Neutral.CombineTransposeConcat>();
p.Add<Transform.Rules.Neutral.FoldNopPad>();
p.Add<Transform.Rules.Neutral.FoldConv2DPads>();
p.Add<Transform.Rules.Neutral.FoldReduceWindow2DPads>();
p.Add<Transform.Rules.Neutral.SqueezeToReshape>();
p.Add<Transform.Rules.Neutral.UnSqueezeToReshape>();
p.Add<Transform.Rules.Neutral.TransposeToReshape>();
p.Add<Transform.Rules.Neutral.FoldNopReshape>();
p.Add<Transform.Rules.Neutral.FoldTwoReshapes>();
p.Add<Transform.Rules.Neutral.FoldLayerNormPattern1>();
p.Add<Transform.Rules.Neutral.FoldLayerNormPattern2>();
p.Add<Transform.Rules.Neutral.FoldLayerNormPattern3>();
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.Neutral.FoldNopTranspose>();
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
p.Add<Passes.Rules.Neutral.CombineTransposeUnary>();
p.Add<Passes.Rules.Neutral.CombineTransposePad>();
p.Add<Passes.Rules.Neutral.CombinePadTranspose>();
p.Add<Passes.Rules.Neutral.CombineBinaryTranspose>();
p.Add<Passes.Rules.Neutral.CombineConstBinaryTranspose>();
p.Add<Passes.Rules.Neutral.CombineTransposeConstBinary>();
p.Add<Passes.Rules.Neutral.CombineTransposeReduce>();
p.Add<Passes.Rules.Neutral.CombineTransposeActivations>();
p.Add<Passes.Rules.Neutral.CombineActivationsTranspose>();
p.Add<Passes.Rules.Neutral.CombineTransposeConcat>();
p.Add<Passes.Rules.Neutral.FoldNopPad>();
p.Add<Passes.Rules.Neutral.FoldConv2DPads>();
p.Add<Passes.Rules.Neutral.FoldReduceWindow2DPads>();
p.Add<Passes.Rules.Neutral.SqueezeToReshape>();
p.Add<Passes.Rules.Neutral.UnSqueezeToReshape>();
p.Add<Passes.Rules.Neutral.TransposeToReshape>();
p.Add<Passes.Rules.Neutral.FoldNopReshape>();
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern1>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern2>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern3>();
});
// passManager.AddWithName<EGraphPass>("NeutralOptimizeClamp").Configure(p =>
// {
// p.Add<Transform.Rules.Neutral.FoldConstCall>();
// p.Add<Transform.Rules.Neutral.FoldConv2DAddMul>();
// p.Add<Transform.Rules.Neutral.ReluToClamp>();
// p.Add<Transform.Rules.Neutral.Relu6ToClamp>();
// p.Add<Transform.Rules.Neutral.CombineClampAdd>();
// p.Add<Transform.Rules.Neutral.CombineClampMul>();
// p.Add<Transform.Rules.Neutral.FoldNopClamp>();
// p.Add<Passes.Rules.Neutral.FoldConstCall>();
// p.Add<Passes.Rules.Neutral.FoldConv2DAddMul>();
// p.Add<Passes.Rules.Neutral.ReluToClamp>();
// p.Add<Passes.Rules.Neutral.Relu6ToClamp>();
// p.Add<Passes.Rules.Neutral.CombineClampAdd>();
// p.Add<Passes.Rules.Neutral.CombineClampMul>();
// p.Add<Passes.Rules.Neutral.FoldNopClamp>();
// });
}
@ -107,7 +107,7 @@ internal class Compiler : ICompiler
{
passManager.AddWithName<DataflowPass>("AddRangeOfMarker").Configure(p =>
{
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarker>();
p.Add<Passes.Rules.Neutral.AddRangeOfAndMarker>();
});
passManager.AddWithName<EGraphPassWithQuantize>("AssignRanges");
}

View File

@ -52,8 +52,8 @@ public static class CompilerHostBuilderExtensions
.AddGraph()
.AddEGraph()
.AddCodeGen()
.AddStackVM()
.AddK210();
.AddPasses()
.AddStackVM();
}
private static void ConfigureServices(HostBuilderContext context, IServiceCollection services)

View File

@ -9,8 +9,6 @@ using System.Reflection;
using System.Reflection.Metadata;
using System.Reflection.PortableExecutable;
using System.Runtime.CompilerServices;
using System.Runtime.Loader;
using LanguageExt;
using Microsoft.Extensions.Logging;
using Nncase.Compiler;
using Nncase.IR;

View File

@ -171,8 +171,8 @@ public static unsafe class CApi
var samples = (dataset.Length == 0 ?
Array.Empty<Dictionary<Var, IValue>>() :
dataset.Chunk(dataset.Length / (int)samplesCount).Select(inputs => inputs.Zip(fnParams).ToDictionary(
item => item.Item2,
item => item.Item1.ToValue()))).ToAsyncEnumerable();
item => item.Second,
item => item.First.ToValue()))).ToAsyncEnumerable();
return GCHandle.ToIntPtr(GCHandle.Alloc(new CCalibrationDatasetProvider(samples, (int)samplesCount)));
}
@ -288,8 +288,8 @@ public static unsafe class CApi
var fnParams = Get<Var[]>(fnParamsHandle);
var inputs = Get<RTValue[]>(inputsHandle);
var result = CompilerServices.Evaluate(expr, fnParams.Zip(inputs).ToDictionary(
x => x.Item1,
x => x.Item2.ToValue()));
x => x.First,
x => x.Second.ToValue()));
var rtValue = RTValue.FromValue(result);
return GCHandle.ToIntPtr(GCHandle.Alloc(rtValue));
}

View File

@ -15,7 +15,6 @@
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\modules\Nncase.Modules.K210\Nncase.Modules.K210.csproj" />
<ProjectReference Include="..\Nncase.CodeGen\Nncase.CodeGen.csproj" />
<ProjectReference Include="..\Nncase.Core\Nncase.Core.csproj" />
<ProjectReference Include="..\Nncase.Diagnostics\Nncase.Diagnostics.csproj" />
@ -25,7 +24,7 @@
<ProjectReference Include="..\Nncase.Importer\Nncase.Importer.csproj" />
<ProjectReference Include="..\Nncase.Simulator\Nncase.Simulator.csproj" />
<ProjectReference Include="..\Nncase.Quantization\Nncase.Quantization.csproj" />
<ProjectReference Include="..\Nncase.Transform\Nncase.Transform.csproj" />
<ProjectReference Include="..\Nncase.Passes\Nncase.Passes.csproj" />
<ProjectReference Include="..\..\modules\Nncase.Modules.StackVM\Nncase.Modules.StackVM.csproj" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,23 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.Collections;
public static class CollectionsExtensions
{
public static void AddRange<T>(this List<T> list, ReadOnlySpan<T> items)
{
list.Capacity += items.Length;
foreach (var item in items)
{
list.Add(item);
}
}
}

View File

@ -0,0 +1,197 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.Collections;
/// <summary>
/// Represents an ordered sequence of weak references.
/// </summary>
public sealed class WeakList<T> : IEnumerable<T>
where T : class
{
private const int MinimalNonEmptySize = 4;
private WeakReference<T>[] _items;
private int _size;
public WeakList()
{
_items = Array.Empty<WeakReference<T>>();
}
/// <summary>
/// Gets the number of weak references in this list.
/// Note that some of them might not point to live objects anymore.
/// </summary>
public int WeakCount
{
get { return _size; }
}
public WeakReference<T> GetWeakReference(int index)
{
if (index < 0 || index >= _size)
{
throw new ArgumentOutOfRangeException(nameof(index));
}
return _items[index];
}
public void Add(T item)
{
if (_size == _items.Length)
{
Resize();
}
_items[_size++] = new WeakReference<T>(item);
}
public IEnumerator<T> GetEnumerator()
{
int count = _size;
int alive = _size;
int firstDead = -1;
for (int i = 0; i < count; i++)
{
T? item;
if (_items[i].TryGetTarget(out item))
{
yield return item;
}
else
{
// object has been collected
if (firstDead < 0)
{
firstDead = i;
}
alive--;
}
}
if (alive == 0)
{
_items = Array.Empty<WeakReference<T>>();
_size = 0;
}
else if (alive < _items.Length / 4)
{
// If we have just a few items left we shrink the array.
Shrink(firstDead, alive);
}
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
private static int GetExpandedSize(int baseSize)
{
return Math.Max((baseSize * 2) + 1, MinimalNonEmptySize);
}
private void Resize()
{
int alive = _items.Length;
int firstDead = -1;
for (int i = 0; i < _items.Length; i++)
{
if (!_items[i].TryGetTarget(out _))
{
if (firstDead == -1)
{
firstDead = i;
}
alive--;
}
}
if (alive < _items.Length / 4)
{
// If we have just a few items left we shrink the array.
// We avoid expanding the array until the number of new items added exceeds half of its capacity.
Shrink(firstDead, alive);
}
else if (alive >= 3 * _items.Length / 4)
{
// If we have a lot of items alive we expand the array since just compacting them
// wouldn't free up much space (we would end up calling Resize again after adding a few more items).
var newItems = new WeakReference<T>[GetExpandedSize(_items.Length)];
if (firstDead >= 0)
{
Compact(firstDead, newItems);
}
else
{
Array.Copy(_items, 0, newItems, 0, _items.Length);
}
_items = newItems;
}
else
{
// Compact in-place to make space for new items at the end.
// We will free up to length/4 slots in the array.
Compact(firstDead, _items);
}
Debug.Assert(_items.Length > 0 && _size < 3 * _items.Length / 4, "length: " + _items.Length + " size: " + _size);
}
private void Shrink(int firstDead, int alive)
{
int newSize = GetExpandedSize(alive);
var newItems = (newSize == _items.Length) ? _items : new WeakReference<T>[newSize];
Compact(firstDead, newItems);
_items = newItems;
}
/// <summary>
/// Copies all live references from <see cref="_items"/> to <paramref name="result"/>.
/// Assumes that all references prior <paramref name="firstDead"/> are alive.
/// </summary>
private void Compact(int firstDead, WeakReference<T>[] result)
{
if (!ReferenceEquals(_items, result))
{
Array.Copy(_items, 0, result, 0, firstDead);
}
int oldSize = _size;
int j = firstDead;
for (int i = firstDead + 1; i < oldSize; i++)
{
var item = _items[i];
if (item.TryGetTarget(out _))
{
result[j++] = item;
}
}
_size = j;
// free WeakReferences
if (ReferenceEquals(_items, result))
{
while (j < oldSize)
{
_items[j++] = null!;
}
}
}
}

View File

@ -11,7 +11,7 @@ using Microsoft.Extensions.DependencyInjection;
using Nncase.CodeGen;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.Transform;
using Nncase.Passes;
namespace Nncase;
@ -79,9 +79,7 @@ public sealed class CompileSession : IServiceProvider, IDisposable
/// <param name="name">Name.</param>
/// <returns>Created pass manager.</returns>
public IPassManager CreatePassManager(string name)
{
return new PassManager(name, this);
}
=> _serviceProvider.GetRequiredService<IPassManagerFactory>().Create(name, this);
/// <inheritdoc/>
public void Dispose()

View File

@ -9,7 +9,7 @@ using System.Threading.Tasks;
namespace Nncase;
internal struct CompileSessionScope : IDisposable
public struct CompileSessionScope : IDisposable
{
private static readonly AsyncLocal<CompileSession?> _compileSession = new AsyncLocal<CompileSession?>();

View File

@ -14,9 +14,9 @@ using Microsoft.Extensions.DependencyInjection;
using Nncase.CostModel;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.Passes;
using Nncase.PatternMatch;
using Nncase.Targets;
using Nncase.Transform;
namespace Nncase;
@ -108,9 +108,8 @@ public interface ICompilerServicesProvider
/// Evaluate cost of the expression tree.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="varsValues">Optional vars' values.</param>
/// <returns>Evaluate result.</returns>
Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null);
Cost EvaluateCost(Expr expr);
/// <summary>
/// Evaluate cost of operator.
@ -118,7 +117,7 @@ public interface ICompilerServicesProvider
/// <param name="op">Target operator.</param>
/// <param name="context">Evaluate context.</param>
/// <returns>Evaluate result.</returns>
Cost? EvaluateOpCost(Op op, ICostEvaluateContext context);
Cost EvaluateOpCost(Op op, ICostEvaluateContext context);
/// <summary>
/// Match expression.
@ -267,11 +266,10 @@ public static class CompilerServices
/// Evaluate cost of the expression tree.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="varsValues">Optional vars' values.</param>
/// <returns>Evaluate result.</returns>
public static Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null)
public static Cost EvaluateCost(Expr expr)
{
return Provider.EvaluateCost(expr, varsValues);
return Provider.EvaluateCost(expr);
}
/// <summary>
@ -280,7 +278,7 @@ public static class CompilerServices
/// <param name="op">Target operator.</param>
/// <param name="context">Evaluate context.</param>
/// <returns>Evaluate result.</returns>
public static Cost? EvaluateOpCost(Op op, ICostEvaluateContext context)
public static Cost EvaluateOpCost(Op op, ICostEvaluateContext context)
{
return Provider.EvaluateOpCost(op, context);
}
@ -560,13 +558,13 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
}
/// <inheritdoc/>
public Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null)
public Cost EvaluateCost(Expr expr)
{
return _costEvaluateProvider.EvaluateCost(expr, varsValues);
return _costEvaluateProvider.EvaluateCost(expr);
}
/// <inheritdoc/>
public Cost? EvaluateOpCost(Op op, ICostEvaluateContext context)
public Cost EvaluateOpCost(Op op, ICostEvaluateContext context)
{
return _costEvaluateProvider.EvaluateOpCost(op, context);
}

View File

@ -3,7 +3,6 @@
using DryIoc;
using Nncase.Hosting;
using Nncase.IR;
namespace Nncase;
@ -16,7 +15,6 @@ internal class CoreModule : IApplicationPart
{
registrator.RegisterManyInterface<CompilerServicesProvider>(reuse: Reuse.Singleton);
registrator.Register<IDataTypeServiceProvider, DataTypeServiceProvider>(reuse: Reuse.Singleton);
registrator.Register<IIRPrinterProvider, IRPrinterProvider>(reuse: Reuse.Singleton);
// Prim types
registrator.Register<PrimType, BooleanType>(reuse: Reuse.Singleton);

View File

@ -4,6 +4,7 @@
using System;
using NetFabric.Hyperlinq;
using Nncase.IR;
using static NetFabric.Hyperlinq.ArrayExtensions;
namespace Nncase.CostModel;
@ -162,9 +163,26 @@ public static class CostExtensions
/// </summary>
/// <param name="costs">Source.</param>
/// <returns>Result.</returns>
public static Cost Sum(this IEnumerable<Cost?> costs)
public static Cost Sum(this IEnumerable<Cost> costs)
{
return costs.Aggregate((Cost?)Cost.Zero, (x, y) => y == null ? null : x! + y)!;
return costs.Aggregate(Cost.Zero, (x, y) => x + y)!;
}
/// <summary>
/// Sum all costs.
/// </summary>
/// <param name="costs">Source.</param>
/// <returns>Result.</returns>
public static Cost Sum<TSource, TSelector>(this in SpanSelectEnumerable<TSource, Cost, TSelector> costs)
where TSelector : struct, IFunction<TSource, Cost>
{
var sum = Cost.Zero;
foreach (var cost in costs)
{
sum += cost;
}
return sum;
}
}

80
src/Nncase.Core/Either.cs Normal file
View File

@ -0,0 +1,80 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase;
internal enum EitherState
{
None,
T1,
T2,
}
public struct Either<T1, T2>
{
private readonly EitherState _state;
[AllowNull]
private readonly T1 _t1;
[AllowNull]
private readonly T2 _t2;
private Either(T1 value)
{
_state = EitherState.T1;
_t1 = value;
_t2 = default;
}
private Either(T2 value)
{
_state = EitherState.T2;
_t2 = value;
_t1 = default;
}
public T1 Value1 => _state == EitherState.T1 ? _t1 : throw new InvalidOperationException($"Value is not a {typeof(T1)}.");
public T2 Value2 => _state == EitherState.T2 ? _t2 : throw new InvalidOperationException($"Value is not a {typeof(T2)}.");
public static implicit operator Either<T1, T2>(T1 value) => new(value);
public static implicit operator Either<T1, T2>(T2 value) => new(value);
public static explicit operator T1(Either<T1, T2> value) => value.Value1;
public static explicit operator T2(Either<T1, T2> value) => value.Value2;
public static Either<T1, T2> From(T1 value) => new(value);
public static Either<T1, T2> From(T2 value) => new(value);
public bool Is<T>()
{
if (typeof(T) == typeof(T1))
{
return _state == EitherState.T1;
}
if (typeof(T) == typeof(T2))
{
return _state == EitherState.T2;
}
return false;
}
public T Match<T>(Func<T1, T> t1Selector, Func<T2, T> t2Selector)
=> Is<T1>() ? t1Selector(_t1) : t2Selector(_t2);
public object Match(Func<T1, object> t1Selector, Func<T2, object> t2Selector)
=> Is<T1>() ? t1Selector(_t1) : t2Selector(_t2);
}

View File

@ -255,7 +255,7 @@ namespace Nncase.TIR
IsEntryFunc = 549755813888,
/// <summary>
/// Parameters used in the module that should be linked by the codegen.
/// Arguments used in the module that should be linked by the codegen.
/// </summary>
LinkedParams = 1099511627776,

View File

@ -20,9 +20,8 @@ public interface ICostEvaluateProvider
/// Evaluate cost of the expression tree.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="varsValues">Optional vars' values.</param>
/// <returns>Evaluate result.</returns>
Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null);
Cost EvaluateCost(Expr expr);
/// <summary>
/// Evaluate cost of operator.
@ -30,5 +29,5 @@ public interface ICostEvaluateProvider
/// <param name="op">Target operator.</param>
/// <param name="context">Evaluate context.</param>
/// <returns>Evaluate result.</returns>
Cost? EvaluateOpCost(Op op, ICostEvaluateContext context);
Cost EvaluateOpCost(Op op, ICostEvaluateContext context);
}

View File

@ -0,0 +1,31 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR;
/// <summary>
/// Base function.
/// </summary>
public abstract class BaseFunction : Callable
{
public BaseFunction(string name, string moduleKind, Expr[] operands)
: base(name, moduleKind, operands)
{
}
/// <summary>
/// Gets sched result.
/// </summary>
public Schedule.SchedFunctionResult SchedResult { get; } = new();
/// <summary>
/// Gets parameter types.
/// </summary>
public abstract IEnumerable<IRType?> ParameterTypes { get; }
}

View File

@ -6,11 +6,12 @@ using System.Collections.Generic;
using System.Linq;
using Nncase.IR;
namespace Nncase.IR.Buffer;
namespace Nncase.IR.Buffers;
/// <summary>
/// get the buffer basement.
/// </summary>
public record Allocate(TensorType ElemType) : Op
public sealed partial class Allocate : Op
{
public TensorType ElemType { get; }
}

View File

@ -4,12 +4,12 @@ using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Buffer;
namespace Nncase.IR.Buffers;
/// <summary>
/// get the buffer basement.
/// </summary>
public record BaseMentOf() : Op
public sealed partial class BaseMentOf : Op
{
/// <summary>
/// Get the input parameter.

View File

@ -4,18 +4,20 @@ using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Buffer;
namespace Nncase.IR.Buffers;
/// <summary>
/// get the buffer from the input.
/// </summary>
public record BufferOf(Schedule.MemoryLocation MemoryLocation) : Op
public sealed partial class BufferOf : Op
{
/// <summary>
/// Get the input parameter.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(BufferOf), 0, "input", IsTensor());
public Schedule.MemoryLocation MemoryLocation { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"Schedule.MemoryLocation.{MemoryLocation}";
}

View File

@ -5,13 +5,13 @@ using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Buffer;
namespace Nncase.IR.Buffers;
/// <summary>
/// DDrOf expression.
/// </summary>
[PatternFunctionalGenerator]
public record DDrOf() : Op
public sealed partial class DDrOf : Op
{
/// <summary>
/// Get the input parameter.

View File

@ -6,7 +6,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR.Buffer;
using Nncase.IR.Buffers;
namespace Nncase.IR.F;

View File

@ -10,13 +10,13 @@ using System.Threading.Tasks;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Buffer;
namespace Nncase.IR.Buffers;
/// <summary>
/// Shape expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record StrideOf() : Op
public sealed partial class StrideOf : Op
{
/// <summary>
/// Gets input.

View File

@ -4,19 +4,23 @@
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Buffer;
namespace Nncase.IR.Buffers;
/// <summary>
/// Gets input.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Uninitialized(DataType DType, Schedule.MemoryLocation MemoryLocation) : Op
public sealed partial class Uninitialized : Op
{
/// <summary>
/// the shape.
/// </summary>
public static readonly ParameterInfo Shape = new(typeof(Uninitialized), 0, "shape", IsIntegral() & IsTensor() & HasRank(1));
public DataType DType { get; }
public Schedule.MemoryLocation MemoryLocation { get; }
/// <inheritdoc/>
public override bool CanFoldConstCall => false;

View File

@ -7,6 +7,7 @@ using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.Utilities;
namespace Nncase.IR;
@ -24,18 +25,32 @@ public interface IParameterList<T>
/// <summary>
/// Call expression.
/// </summary>
public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParameterList<Expr>
public sealed class Call : Expr, IParameterList<Expr>
{
/// <summary>
/// Initializes a new instance of the <see cref="Call"/> class.
/// </summary>
/// <param name="target">Call target.</param>
/// <param name="parameters">Parameters.</param>
public Call(Expr target, params Expr[] parameters)
: this(target, new IRArray<Expr>(parameters.ToImmutableArray()))
/// <param name="arguments">Arguments.</param>
public Call(Expr target, ReadOnlySpan<Expr> arguments)
: base(ArrayUtility.Concat(target, arguments))
{
}
/// <summary>
/// Initializes a new instance of the <see cref="Call"/> class.
/// </summary>
/// <param name="target">Call target.</param>
/// <param name="arguments">Arguments.</param>
public Call(Expr target, params Expr[] arguments)
: this(target, (ReadOnlySpan<Expr>)arguments)
{
}
public Expr Target => Operands[0];
public ReadOnlySpan<Expr> Arguments => Operands[1..];
// /// <summary>
// /// used by fake ir, represents that whether this op permit int 16 quant.
// /// </summary>
@ -62,7 +77,7 @@ public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParame
var type = Target.GetType();
if (type == parameter.OwnerType)
{
return Parameters[parameter.Index];
return Arguments[parameter.Index];
}
else
{
@ -74,9 +89,16 @@ public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParame
public void ParametersForeach(Action<Expr, ParameterInfo> f)
{
var parameterInfos = ((Op)Target).Parameters.ToArray();
for (int i = 0; i < Parameters.Count; i++)
for (int i = 0; i < Arguments.Length; i++)
{
f(Parameters[i], parameterInfos[i]);
f(Arguments[i], parameterInfos[i]);
}
}
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitCall(this, context);
public Call With(Expr? target = null, Expr[]? arguments = null)
=> new Call(target ?? Target, arguments ?? Arguments);
}

View File

@ -0,0 +1,32 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR;
/// <summary>
/// the Callable Expr.
/// </summary>
public abstract class Callable : Expr
{
/// <summary>
/// StackVM module kind.
/// </summary>
public static readonly string StackVMModuleKind = "stackvm";
public Callable(string name, string moduleKind, Expr[] operands)
: base(operands)
{
Name = name;
ModuleKind = moduleKind;
}
public string Name { get; }
public string ModuleKind { get; }
}

View File

@ -12,8 +12,20 @@ namespace Nncase.IR;
/// <summary>
/// Constant expression.
/// </summary>
public abstract record Const(IRType ValueType) : Expr
public abstract class Const : Expr
{
/// <summary>
/// Initializes a new instance of the <see cref="Const"/> class.
/// </summary>
/// <param name="valueType">Type of value.</param>
public Const(IRType valueType)
: base(Array.Empty<Expr>())
{
ValueType = valueType;
}
public IRType ValueType { get; }
/// <summary>
/// Create constant from a <see cref="byte"/>.
/// </summary>
@ -125,7 +137,7 @@ public abstract record Const(IRType ValueType) : Expr
else
{
var tpv = (TupleValue)value;
return new TupleConst(tpv.Select(x => FromValue(x)).ToArray());
return new TupleConst(tpv);
}
}
}

View File

@ -7,149 +7,148 @@ using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR
namespace Nncase.IR;
/// <summary>
/// Conversion of expression.
/// </summary>
public abstract partial class Expr
{
/// <summary>
/// Conversion of expression.
/// Create <see cref="Expr"/> from a <see cref="byte"/>.
/// </summary>
public abstract partial record Expr
{
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="byte"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(byte value) => (Const)value;
/// <param name="value">Value.</param>
public static implicit operator Expr(byte value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="ushort"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(ushort value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="ushort"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(ushort value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="uint"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(uint value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="uint"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(uint value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="ulong"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(ulong value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="ulong"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(ulong value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="sbyte"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(sbyte value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="sbyte"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(sbyte value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="short"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(short value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="short"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(short value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="int"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(int value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="int"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(int value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="long"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(long value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="long"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(long value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="Half"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(Half value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="Half"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(Half value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="float"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(float value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="float"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(float value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="double"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(double value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="double"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(double value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="BFloat16"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(BFloat16 value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="BFloat16"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(BFloat16 value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="bool"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(bool value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="bool"/>.
/// </summary>
/// <param name="value">Value.</param>
public static implicit operator Expr(bool value) => (Const)value;
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="Shape"/>.
/// </summary>
/// <param name="shape">Shape.</param>
public static implicit operator Expr(Shape shape) => Const.FromShape(shape);
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="Shape"/>.
/// </summary>
/// <param name="shape">Shape.</param>
public static implicit operator Expr(Shape shape) => Const.FromShape(shape);
/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
/// </summary>
/// <param name="array">Array.</param>
public static implicit operator Expr(int[] array) => Tensor.From<int>(array);
/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
/// </summary>
/// <param name="array">Array.</param>
public static implicit operator Expr(int[] array) => Tensor.From<int>(array);
/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
/// </summary>
/// <param name="array">Array.</param>
public static implicit operator Expr(long[] array) => Tensor.From<long>(array);
/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
/// </summary>
/// <param name="array">Array.</param>
public static implicit operator Expr(long[] array) => Tensor.From<long>(array);
/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="float"/>.
/// </summary>
/// <param name="array">Array.</param>
public static implicit operator Expr(float[] array) => Tensor.From<float>(array);
/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="float"/>.
/// </summary>
/// <param name="array">Array.</param>
public static implicit operator Expr(float[] array) => Tensor.From<float>(array);
/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
/// </summary>
/// <param name="array">Array.</param>
public static implicit operator Expr(Array array) => Tensor.FromArray(array);
/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
/// </summary>
/// <param name="array">Array.</param>
public static implicit operator Expr(Array array) => Tensor.FromArray(array);
/// <summary>
/// Create <see cref="Expr"/> from a memory of<see cref="int"/>.
/// </summary>
/// <param name="memory">Span.</param>
public static implicit operator Expr(Memory<int> memory) => Tensor.From<int>(memory);
/// <summary>
/// Create <see cref="Expr"/> from a memory of<see cref="int"/>.
/// </summary>
/// <param name="memory">Span.</param>
public static implicit operator Expr(Memory<int> memory) => Tensor.From<int>(memory);
/// <summary>
/// Create <see cref="Expr"/> from a memory of<see cref="long"/>.
/// </summary>
/// <param name="memory">Span.</param>
public static implicit operator Expr(Memory<long> memory) => Tensor.From<long>(memory);
/// <summary>
/// Create <see cref="Expr"/> from a memory of<see cref="long"/>.
/// </summary>
/// <param name="memory">Span.</param>
public static implicit operator Expr(Memory<long> memory) => Tensor.From<long>(memory);
/// <summary>
/// Create <see cref="Expr"/> from a memory of<see cref="float"/>.
/// </summary>
/// <param name="memory">Span.</param>
public static implicit operator Expr(Memory<float> memory) => Tensor.From<float>(memory);
/// <summary>
/// Create <see cref="Expr"/> from a memory of<see cref="float"/>.
/// </summary>
/// <param name="memory">Span.</param>
public static implicit operator Expr(Memory<float> memory) => Tensor.From<float>(memory);
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="Tensor"/>.
/// </summary>
/// <param name="tensor">Tensor.</param>
public static implicit operator Expr(Tensor tensor) => Const.FromTensor(tensor);
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="Tensor"/>.
/// </summary>
/// <param name="tensor">Tensor.</param>
public static implicit operator Expr(Tensor tensor) => Const.FromTensor(tensor);
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="QuantParam"/>.
/// </summary>
/// <param name="quantParam">QuantParam.</param>
public static implicit operator Expr(QuantParam quantParam) => Tensor.FromScalar(quantParam);
}
/// <summary>
/// Create <see cref="Expr"/> from a <see cref="QuantParam"/>.
/// </summary>
/// <param name="quantParam">QuantParam.</param>
public static implicit operator Expr(QuantParam quantParam) => Tensor.FromScalar(quantParam);
}

View File

@ -13,7 +13,7 @@ namespace Nncase.IR;
/// <summary>
/// Math operators for <see cref="Expr"/>.
/// </summary>
public partial record Expr
public partial class Expr
{
/// <summary>
/// get the item from the expr.

View File

@ -4,27 +4,74 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Toolkit.HighPerformance.Helpers;
namespace Nncase.IR;
/// <summary>
/// Expression.
/// </summary>
public abstract partial record Expr
public abstract partial class Expr : IDisposable
{
private readonly Expr[] _operands;
private readonly HashSet<Expr> _users = new(ReferenceEqualityComparer.Instance);
private IRType? _checkedType;
private int? _hashCodeCache;
private bool _disposedValue;
internal Expr(IEnumerable<Expr> operands)
{
_operands = operands.ToArray();
foreach (var operand in _operands)
{
operand.AddUser(this);
}
}
internal Expr(Expr[] operands)
{
_operands = operands;
foreach (var operand in _operands)
{
operand.AddUser(this);
}
}
/// <summary>
/// Gets or sets checked type.
/// </summary>
public IRType? CheckedType { get; set; }
public IRType CheckedType
{
get
{
if (_checkedType == null)
{
CompilerServices.InferenceType(this);
}
Trace.Assert(_checkedType is not null);
return _checkedType!;
}
set
{
if (_checkedType != value)
{
_checkedType = value;
InvalidateUsersTypeInference();
}
}
}
/// <summary>
/// Gets checked shape.
/// </summary>
public Shape CheckedShape => (CheckedType ?? ((Const)this).ValueType) switch
public Shape CheckedShape => CheckedType switch
{
TensorType type => type.Shape,
_ => throw new InvalidOperationException("Only The Expr Have CheckedType Can Get It's Shape"),
@ -35,30 +82,220 @@ public abstract partial record Expr
/// </summary>
public DataType CheckedDataType => CheckedType switch
{
// todo:more info
TensorType type => type.DType,
_ => throw new InvalidOperationException("Expr don't have a valid tensor type"),
};
/// <summary>
/// Gets or sets hash code cache.
/// Gets users.
/// </summary>
protected int? HashCodeCache { get; set; }
public IReadOnlyCollection<Expr> Users => EnsureAlive()._users;
/// <summary>
/// Gets operands.
/// </summary>
public ReadOnlySpan<Expr> Operands => EnsureAlive()._operands;
/// <summary>
/// Gets or sets raw checked type.
/// </summary>
internal IRType? RawCheckedType
{
get => _checkedType;
set => _checkedType = value;
}
public static bool operator ==(Expr? left, Expr? right) => EqualityComparer<Expr>.Default.Equals(left, right);
public static bool operator !=(Expr? left, Expr? right) => !(left == right);
/// <summary>
/// Accept a <see cref="ExprFunctor{TExprResult, TTypeResult, TContext}"/>.
/// </summary>
/// <typeparam name="TExprResult">Result type of visiting expressions.</typeparam>
/// <typeparam name="TTypeResult">Result type of visiting types.</typeparam>
/// <typeparam name="TContext">Visit context.</typeparam>
/// <param name="functor">Expression functor.</param>
/// <param name="context">Context.</param>
/// <returns>Visit result.</returns>
public abstract TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context);
/// <inheritdoc/>
public virtual bool Equals(Expr? other)
public override string ToString()
{
return !(other is null) && EqualityContract == other.EqualityContract;
return GetType().ToString();
}
/// <inheritdoc/>
public override int GetHashCode()
public override bool Equals(object? obj)
{
return HashCodeCache ??= EqualityComparer<Type>.Default.GetHashCode(EqualityContract);
if (ReferenceEquals(this, obj))
{
return true;
}
return obj is Expr other && GetHashCode() == other.GetHashCode() && Operands.SequenceEqual(other.Operands);
}
protected virtual bool PrintMembers(StringBuilder builder)
/// <inheritdoc/>
public sealed override int GetHashCode() => _hashCodeCache ??= GetHashCodeCore();
public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
public void DisposeIfNoUsers()
{
if (_users.Count == 0)
{
Dispose();
}
}
internal void AddUser(Expr user)
{
EnsureAlive();
Trace.Assert(!ReferenceEquals(this, user));
_users.Add(user.EnsureAlive());
}
internal void RemoveUser(Expr user)
{
_users.Remove(user);
}
internal void ReplaceOperand(int index, Expr newOperand)
{
ref var operand = ref _operands[index];
if (!ReferenceEquals(operand, newOperand))
{
newOperand.AddUser(this);
operand.RemoveUser(this);
operand = newOperand;
OnOperandsReplaced();
}
}
internal void ReplaceAllUsesWith(Expr newOperand)
=> ReplaceScopedUsesWith(newOperand, null);
internal void ReplaceScopedUsesWith(Expr newOperand, IReadOnlySet<Expr>? scope)
{
EnsureAlive();
if (!ReferenceEquals(this, newOperand))
{
foreach (var user in Users.ToArray())
{
if ((scope is null || scope.Contains(user))
&& !newOperand.IsDescendantOf(this))
{
newOperand.AddUser(user);
var operands = user._operands;
for (int i = 0; i < operands.Length; i++)
{
ref var operand = ref operands[i];
if (ReferenceEquals(operand, this))
{
operand = newOperand;
}
}
user.OnOperandsReplaced();
RemoveUser(user);
}
}
}
}
protected virtual int GetHashCodeCore()
{
return HashCode.Combine(GetType(), HashCode<Expr>.Combine(Operands));
}
protected virtual void Dispose(bool disposing)
{
if (!_disposedValue)
{
foreach (var operand in _operands)
{
operand.RemoveUser(this);
operand.DisposeIfNoUsers();
}
_disposedValue = true;
}
}
private bool IsDescendantOf(Expr other)
{
foreach (var operand in _operands)
{
if (ReferenceEquals(operand, other))
{
return true;
}
}
foreach (var operand in _operands)
{
if (operand.IsDescendantOf(other))
{
return true;
}
}
return false;
}
private void OnOperandsReplaced()
{
InvalidateTypeInference();
InvalidateHashCodeCache();
}
private void InvalidateTypeInference()
{
if (_checkedType != null)
{
_checkedType = null;
InvalidateUsersTypeInference();
}
}
private void InvalidateUsersTypeInference()
{
foreach (var user in Users)
{
user.InvalidateTypeInference();
}
}
private void InvalidateHashCodeCache()
{
if (_hashCodeCache != null)
{
_hashCodeCache = null;
InvalidateUsersHashCodeCache();
}
}
private void InvalidateUsersHashCodeCache()
{
foreach (var user in Users)
{
user.InvalidateHashCodeCache();
}
}
private Expr EnsureAlive()
{
if (_disposedValue)
{
throw new ObjectDisposedException(null);
}
return this;
}
}

View File

@ -0,0 +1,53 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
using Nncase.TIR;
namespace Nncase.IR;
public static class ExprClonerExtensions
{
public static T Clone<T>(this T expr, bool cloneOtherFunctions = false)
where T : Expr
{
return new ExprCloner<Unit>(cloneOtherFunctions).Clone(expr, default);
}
}
/// <summary>
/// Expression cloner.
/// </summary>
/// <typeparam name="TContext">Clone context.</typeparam>
public partial class ExprCloner<TContext> : ExprVisitor<Expr, IRType, TContext>
{
/// <summary>
/// Initializes a new instance of the <see cref="ExprCloner{TContext}"/> class.
/// </summary>
/// <param name="cloneOtherFunctions">Clone other functions.</param>
public ExprCloner(bool cloneOtherFunctions = false)
: base(cloneOtherFunctions)
{
}
public T Clone<T>(T expr, TContext context)
where T : Expr
=> (T)Visit(expr, context);
protected T[] CloneArray<T>(ReadOnlySpan<T> values, TContext context)
where T : Expr
{
var array = new T[values.Length];
for (int i = 0; i < values.Length; i++)
{
array[i] = Clone(values[i], context);
}
return array;
}
}

View File

@ -0,0 +1,258 @@

//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
//---------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Reactive;
namespace Nncase.IR;
public partial class ExprCloner<TContext>
{
/// <inheritdoc />
protected override Expr VisitLeafCall(Call expr, TContext context)
{
return expr.With(
target: Clone(expr.Target, context),
arguments: CloneArray(expr.Arguments, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafFunction(Function expr, TContext context)
{
if (!CanVisitFunctionBody(expr))
{
return expr;
}
return expr.With(
parameters: CloneArray(expr.Parameters, context),
body: Clone(expr.Body, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafFusion(Fusion expr, TContext context)
{
if (!CanVisitFunctionBody(expr))
{
return expr;
}
return expr.With(
parameters: CloneArray(expr.Parameters, context),
body: Clone(expr.Body, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafIf(If expr, TContext context)
{
return expr.With(
condition: Clone(expr.Condition, context),
then: Clone(expr.Then, context),
@else: Clone(expr.Else, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafMarker(Marker expr, TContext context)
{
return expr.With(
target: Clone(expr.Target, context),
attribute: Clone(expr.Attribute, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafNone(None expr, TContext context)
{
return expr.With(
);
}
/// <inheritdoc />
protected override Expr VisitLeafOp(Op expr, TContext context)
{
return expr.With(
);
}
/// <inheritdoc />
protected override Expr VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context)
{
if (!CanVisitFunctionBody(expr))
{
return expr;
}
return expr.With(
target: Clone(expr.Target, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafTensorConst(TensorConst expr, TContext context)
{
return expr.With(
);
}
/// <inheritdoc />
protected override Expr VisitLeafTuple(IR.Tuple expr, TContext context)
{
return expr.With(
fields: CloneArray(expr.Fields, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafTupleConst(TupleConst expr, TContext context)
{
return expr.With(
);
}
/// <inheritdoc />
protected override Expr VisitLeafVar(Var expr, TContext context)
{
return expr.With(
);
}
/// <inheritdoc />
protected override Expr VisitLeafBlock(TIR.Block expr, TContext context)
{
return expr.With(
body: Clone(expr.Body, context),
initBody: Clone(expr.InitBody, context),
iterVars: CloneArray(expr.IterVars, context),
reads: CloneArray(expr.Reads, context),
writes: CloneArray(expr.Writes, context),
allocBuffers: CloneArray(expr.AllocBuffers, context),
predicate: Clone(expr.Predicate, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
{
return expr.With(
dimensions: CloneArray(expr.Dimensions, context),
strides: CloneArray(expr.Strides, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
{
return expr.With(
);
}
/// <inheritdoc />
protected override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
{
return expr.With(
buffer: Clone(expr.Buffer, context),
indices: CloneArray(expr.Indices, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
{
return expr.With(
buffer: Clone(expr.Buffer, context),
region: CloneArray(expr.Region, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
{
return expr.With(
buffer: Clone(expr.Buffer, context),
indices: CloneArray(expr.Indices, context),
value: Clone(expr.Value, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafFor(TIR.For expr, TContext context)
{
return expr.With(
loopVar: Clone(expr.LoopVar, context),
domain: Clone(expr.Domain, context),
body: Clone(expr.Body, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafIfThenElse(TIR.IfThenElse expr, TContext context)
{
return expr.With(
condition: Clone(expr.Condition, context),
then: Clone(expr.Then, context),
@else: Clone(expr.Else, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafLet(TIR.Let expr, TContext context)
{
return expr.With(
var: Clone(expr.Var, context),
expression: Clone(expr.Expression, context),
body: Clone(expr.Body, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafPrimFunction(TIR.PrimFunction expr, TContext context)
{
if (!CanVisitFunctionBody(expr))
{
return expr;
}
return expr.With(
parameters: CloneArray(expr.Parameters, context),
body: Clone(expr.Body, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafSequential(TIR.Sequential expr, TContext context)
{
return expr.With(
fields: CloneArray(expr.Fields, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafRange(TIR.Range expr, TContext context)
{
return expr.With(
start: Clone(expr.Start, context),
stop: Clone(expr.Stop, context),
step: Clone(expr.Step, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafIterVar(TIR.IterVar expr, TContext context)
{
return expr.With(
value: Clone(expr.Value, context),
dom: Clone(expr.Dom, context)
);
}
}

View File

@ -0,0 +1,66 @@
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.IO" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
<#@ include file="IRListParser.tt"#>
//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
//---------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Reactive;
namespace Nncase.IR;
public partial class ExprCloner<TContext>
{
<#
foreach (var ir in irs.Where(x => x.IsDerived))
{
#>
/// <inheritdoc />
protected override Expr VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context)
{
<#
if (ir.IsFunction)
{
#>
if (!CanVisitFunctionBody(expr))
{
return expr;
}
<#
}
#>
return expr.With(
<#
for (int i = 0; i < ir.Fields.Length; i++)
{
var field = ir.Fields[i];
var func = field.StartsWith("@") ? "CloneArray" : "Clone";
var fieldName = field.TrimStart('@');
var paramName = $"{char.ToLowerInvariant(fieldName[0])}{fieldName.Substring(1)}";
if (paramName == "else")
{
paramName = "@" + paramName;
}
#>
<#=paramName#>: <#=func#>(expr.<#=fieldName#>, context)<#=i == ir.Fields.Length - 1 ? string.Empty : ","#>
<#
}
#>
);
}
<#
}
#>
}

View File

@ -0,0 +1,32 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
using Nncase.TIR;
namespace Nncase.IR;
public sealed class ExprCollector : ExprWalker<List<Expr>>
{
private ExprCollector()
{
}
public static IReadOnlyList<Expr> Collect(Expr expr)
{
var exprs = new List<Expr>();
new ExprCollector().Visit(expr, exprs);
return exprs;
}
protected override Unit DefaultVisitLeaf(Expr expr, List<Expr> context)
{
context.Add(expr);
return default;
}
}

View File

@ -3,238 +3,159 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR
namespace Nncase.IR;
/// <summary>
/// Expression functor.
/// </summary>
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
/// <typeparam name="TContext">Visit context.</typeparam>
public abstract partial class ExprFunctor<TExprResult, TTypeResult, TContext> : TypeFunctor<TTypeResult, TContext>
{
/// <summary>
/// Expression functor.
/// Gets visit root.
/// </summary>
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
public abstract class ExprFunctor<TExprResult, TTypeResult> : TypeFunctor<TTypeResult>
protected Expr? VisitRoot { get; private set; }
/// <summary>
/// Visit <see cref="Expr"/>.
/// </summary>
public TExprResult Visit(Expr expr, TContext context)
{
/// <summary>
/// Visit expression.
/// </summary>
/// <param name="expr">Expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(Expr expr)
{
return expr switch
{
Var var => Visit(var),
Const con => Visit(con),
Call call => Visit(call),
If @if => Visit(@if),
Tuple tuple => Visit(tuple),
Op op => Visit(op),
None none => Visit(none),
Marker marker => Visit(marker),
BaseFunction basefunc => Visit(basefunc),
TIR.IterVar itvar => Visit(itvar),
TIR.Sequential seq => Visit(seq),
TIR.For @for => Visit(@for),
TIR.Block block => Visit(block),
TIR.BufferLoad bload => Visit(bload),
TIR.BufferStore bstore => Visit(bstore),
TIR.IfThenElse ift => Visit(ift),
TIR.Let let => Visit(let),
TIR.Buffer buffer => Visit(buffer),
TIR.BufferRegion region => Visit(region),
_ => DefaultVisit(expr),
};
}
/// <summary>
/// Visit Basefunction expression.
/// </summary>
/// <param name="baseFunction"> base function. </param>
public virtual TExprResult Visit(BaseFunction baseFunction) => baseFunction switch
{
Function func => Visit(func),
Fusion fusion => Visit(fusion),
PrimFunctionWrapper wrapper => Visit(wrapper),
TIR.PrimFunction primfunc => Visit(primfunc),
_ => DefaultVisit(baseFunction),
};
/// <summary>
/// Visit variable expression.
/// </summary>
/// <param name="expr">Variable expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(Var expr) => DefaultVisit(expr);
/// <summary>
/// Visit constant expression.
/// </summary>
/// <param name="expr">Constant expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(Const expr) => DefaultVisit(expr);
/// <summary>
/// Visit function expression.
/// </summary>
/// <param name="expr">Variable expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(Function expr) => DefaultVisit(expr);
/// <summary>
/// Visit fusion expression.
/// </summary>
/// <param name="expr">Fusion Expression.</param>
public virtual TExprResult Visit(Fusion expr) => DefaultVisit(expr);
/// <summary>
/// Visit prim function wrapper expression.
/// </summary>
/// <param name="expr">PrimFunctionWrapper expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(PrimFunctionWrapper expr) => DefaultVisit(expr);
/// <summary>
/// Visit prim function expression.
/// </summary>
/// <param name="expr">Variable expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.PrimFunction expr) => DefaultVisit(expr);
/// <summary>
/// Visit call expression.
/// </summary>
/// <param name="expr">Call expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(Call expr) => DefaultVisit(expr);
/// <summary>
/// Visit if expression.
/// </summary>
/// <param name="expr">If expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(If expr) => DefaultVisit(expr);
/// <summary>
/// Visit tuple expression.
/// </summary>
/// <param name="expr">Variable expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(Tuple expr) => DefaultVisit(expr);
/// <summary>
/// Visit operator expression.
/// </summary>
/// <param name="expr">Operator expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(Op expr) => DefaultVisit(expr);
/// <summary>
/// Visit None expression.
/// </summary>
/// <param name="expr">None expr.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(None expr) => DefaultVisit(expr);
/// <summary>
/// Visit marker expression.
/// </summary>
/// <param name="expr">Marker expr.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(Marker expr) => DefaultVisit(expr);
/// <summary>
/// Visit IterVar expression.
/// </summary>
/// <param name="expr">IterVar expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.IterVar expr) => DefaultVisit(expr);
/// <summary>
/// Visit Sequential expression.
/// </summary>
/// <param name="expr">Sequential expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.Sequential expr) => DefaultVisit(expr);
/// <summary>
/// Visit For expression.
/// </summary>
/// <param name="expr">For expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.For expr) => DefaultVisit(expr);
/// <summary>
/// Visit block expression.
/// </summary>
/// <param name="expr">block expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.Block expr) => DefaultVisit(expr);
/// <summary>
/// Visit BufferLoad expression.
/// </summary>
/// <param name="expr">BufferLoad expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.BufferLoad expr) => DefaultVisit(expr);
/// <summary>
/// Visit BufferStore expression.
/// </summary>
/// <param name="expr">BufferStore expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.BufferStore expr) => DefaultVisit(expr);
/// <summary>
/// Visit IfThenElse expression.
/// </summary>
/// <param name="expr">IfThenElse expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.IfThenElse expr) => DefaultVisit(expr);
/// <summary>
/// Visit Let expression.
/// </summary>
/// <param name="expr">Let expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.Let expr) => DefaultVisit(expr);
/// <summary>
/// Visit MemRef expression.
/// </summary>
/// <param name="expr">MemRef expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.Buffer expr) => DefaultVisit(expr);
/// <summary>
/// Visit buffer region expression.
/// </summary>
/// <param name="expr">buffer region expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult Visit(TIR.BufferRegion expr) => DefaultVisit(expr);
/// <summary>
/// Visit visitable.
/// </summary>
public virtual object Visit(IVisitable visitable) => DefaultVisit(visitable);
/// <summary>
/// Default Visit IVisitable.
/// </summary>
public virtual object DefaultVisit(IVisitable visitable)
{
return visitable.Visit<TExprResult, TTypeResult>(this);
}
/// <summary>
/// Default visit routine.
/// </summary>
/// <param name="expr">Expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult DefaultVisit(Expr expr)
{
throw new NotImplementedException($"Unhandled visit routine for {expr.GetType()}.");
}
VisitRoot ??= expr;
return DispatchVisit(expr, context);
}
/// <summary>
/// Clear functor states.
/// </summary>
public virtual void Clear()
{
VisitRoot = null;
}
/// <summary>
/// Default visit routine.
/// </summary>
/// <param name="expr">Expression.</param>
/// <param name="context">Context.</param>
/// <returns>Result.</returns>
protected internal virtual TExprResult DefaultVisit(Expr expr, TContext context)
{
throw new NotImplementedException($"Unhandled visit routine for {expr.GetType()}.");
}
protected virtual TExprResult DispatchVisit(Expr expr, TContext context) => expr.Accept(this, context);
}
/// <summary>
/// Expression functor.
/// </summary>
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
public partial class ExprFunctor<TExprResult, TTypeResult> : ExprFunctor<TExprResult, TTypeResult, Unit>
{
/// <summary>
/// Visit <see cref="Expr"/>.
/// </summary>
public TExprResult Visit(Expr expr) => Visit(expr, default);
/// <summary>
/// Visit type.
/// </summary>
/// <param name="type">Type.</param>
/// <returns>Result.</returns>
public TTypeResult VisitType(IRType type) => VisitType(type, default);
/// <summary>
/// Visit any type.
/// </summary>
/// <param name="type">Any type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(AnyType type) => base.VisitType(type, default);
/// <summary>
/// Visit None type.
/// </summary>
/// <param name="type">None type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(NoneType type) => base.VisitType(type, default);
/// <summary>
/// Visit invalid type.
/// </summary>
/// <param name="type">Invalid type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(InvalidType type) => base.VisitType(type, default);
/// <summary>
/// Visit tensor type.
/// </summary>
/// <param name="type">Tensor type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(TensorType type) => base.VisitType(type, default);
/// <summary>
/// Visit tuple type.
/// </summary>
/// <param name="type">Tuple type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(TupleType type) => base.VisitType(type, default);
/// <summary>
/// Visit callable type.
/// </summary>
/// <param name="type">Callable type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(CallableType type) => base.VisitType(type, default);
/// <summary>
/// Default visit routine.
/// </summary>
/// <param name="type">Type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult DefaultVisitType(IRType type) => base.DefaultVisitType(type, default);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(AnyType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(NoneType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(InvalidType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(TensorType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(TupleType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(CallableType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult DefaultVisitType(IRType type, Unit context) => DefaultVisitType(type);
/// <summary>
/// Default visit routine.
/// </summary>
/// <param name="expr">Expression.</param>
/// <returns>Result.</returns>
protected internal virtual TExprResult DefaultVisit(Expr expr) => base.DefaultVisit(expr, default);
/// <inheritdoc/>
protected internal sealed override TExprResult DefaultVisit(Expr expr, Unit context) => DefaultVisit(expr);
protected virtual TExprResult DispatchVisit(Expr expr) => base.DispatchVisit(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult DispatchVisit(Expr expr, Unit context) => DispatchVisit(expr);
}

View File

@ -0,0 +1,357 @@

//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
//---------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Reactive;
namespace Nncase.IR;
public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
{
/// <summary>
/// Visit <see cref="BaseFunction"/>.
/// </summary>
internal protected virtual TExprResult VisitBaseFunction(BaseFunction expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="Call"/>.
/// </summary>
internal protected virtual TExprResult VisitCall(Call expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="Const"/>.
/// </summary>
internal protected virtual TExprResult VisitConst(Const expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="Function"/>.
/// </summary>
internal protected virtual TExprResult VisitFunction(Function expr, TContext context) => VisitBaseFunction(expr, context);
/// <summary>
/// Visit <see cref="Fusion"/>.
/// </summary>
internal protected virtual TExprResult VisitFusion(Fusion expr, TContext context) => VisitBaseFunction(expr, context);
/// <summary>
/// Visit <see cref="If"/>.
/// </summary>
internal protected virtual TExprResult VisitIf(If expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="Marker"/>.
/// </summary>
internal protected virtual TExprResult VisitMarker(Marker expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="None"/>.
/// </summary>
internal protected virtual TExprResult VisitNone(None expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="Op"/>.
/// </summary>
internal protected virtual TExprResult VisitOp(Op expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="PrimFunctionWrapper"/>.
/// </summary>
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => VisitBaseFunction(expr, context);
/// <summary>
/// Visit <see cref="TensorConst"/>.
/// </summary>
internal protected virtual TExprResult VisitTensorConst(TensorConst expr, TContext context) => VisitConst(expr, context);
/// <summary>
/// Visit <see cref="IR.Tuple"/>.
/// </summary>
internal protected virtual TExprResult VisitTuple(IR.Tuple expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TupleConst"/>.
/// </summary>
internal protected virtual TExprResult VisitTupleConst(TupleConst expr, TContext context) => VisitConst(expr, context);
/// <summary>
/// Visit <see cref="Var"/>.
/// </summary>
internal protected virtual TExprResult VisitVar(Var expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.Block"/>.
/// </summary>
internal protected virtual TExprResult VisitBlock(TIR.Block expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.Buffer"/>.
/// </summary>
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.LogicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitBuffer(expr, context);
/// <summary>
/// Visit <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitBuffer(expr, context);
/// <summary>
/// Visit <see cref="TIR.BufferLoad"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.BufferRegion"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.BufferStore"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.For"/>.
/// </summary>
internal protected virtual TExprResult VisitFor(TIR.For expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.IfThenElse"/>.
/// </summary>
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.Let"/>.
/// </summary>
internal protected virtual TExprResult VisitLet(TIR.Let expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.PrimFunction"/>.
/// </summary>
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.Sequential"/>.
/// </summary>
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.Range"/>.
/// </summary>
internal protected virtual TExprResult VisitRange(TIR.Range expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.IterVar"/>.
/// </summary>
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr, TContext context) => DefaultVisit(expr, context);
}
public partial class ExprFunctor<TExprResult, TTypeResult>
{
/// <summary>
/// Visit <see cref="BaseFunction"/>.
/// </summary>
internal protected virtual TExprResult VisitBaseFunction(BaseFunction expr) => base.VisitBaseFunction(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBaseFunction(BaseFunction expr, Unit context) => VisitBaseFunction(expr);
/// <summary>
/// Visit <see cref="Call"/>.
/// </summary>
internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr);
/// <summary>
/// Visit <see cref="Const"/>.
/// </summary>
internal protected virtual TExprResult VisitConst(Const expr) => base.VisitConst(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitConst(Const expr, Unit context) => VisitConst(expr);
/// <summary>
/// Visit <see cref="Function"/>.
/// </summary>
internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr);
/// <summary>
/// Visit <see cref="Fusion"/>.
/// </summary>
internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr);
/// <summary>
/// Visit <see cref="If"/>.
/// </summary>
internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr);
/// <summary>
/// Visit <see cref="Marker"/>.
/// </summary>
internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr);
/// <summary>
/// Visit <see cref="None"/>.
/// </summary>
internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr);
/// <summary>
/// Visit <see cref="Op"/>.
/// </summary>
internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr);
/// <summary>
/// Visit <see cref="PrimFunctionWrapper"/>.
/// </summary>
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr);
/// <summary>
/// Visit <see cref="TensorConst"/>.
/// </summary>
internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
/// <summary>
/// Visit <see cref="IR.Tuple"/>.
/// </summary>
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr);
/// <summary>
/// Visit <see cref="TupleConst"/>.
/// </summary>
internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
/// <summary>
/// Visit <see cref="Var"/>.
/// </summary>
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr);
/// <summary>
/// Visit <see cref="TIR.Block"/>.
/// </summary>
internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr);
/// <summary>
/// Visit <see cref="TIR.Buffer"/>.
/// </summary>
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr) => base.VisitBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.LogicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.BufferLoad"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
/// <summary>
/// Visit <see cref="TIR.BufferRegion"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
/// <summary>
/// Visit <see cref="TIR.BufferStore"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
/// <summary>
/// Visit <see cref="TIR.For"/>.
/// </summary>
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr);
/// <summary>
/// Visit <see cref="TIR.IfThenElse"/>.
/// </summary>
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr);
/// <summary>
/// Visit <see cref="TIR.Let"/>.
/// </summary>
internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr);
/// <summary>
/// Visit <see cref="TIR.PrimFunction"/>.
/// </summary>
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr);
/// <summary>
/// Visit <see cref="TIR.Sequential"/>.
/// </summary>
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr);
/// <summary>
/// Visit <see cref="TIR.Range"/>.
/// </summary>
internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr);
/// <summary>
/// Visit <see cref="TIR.IterVar"/>.
/// </summary>
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr);
}

View File

@ -0,0 +1,55 @@
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.IO" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
<#@ include file="IRListParser.tt"#>
//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
//---------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Reactive;
namespace Nncase.IR;
public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
{
<#
foreach (var ir in irs)
{
var func = ir.VisitBase == "Default" ? "DefaultVisit" : $"Visit{ir.VisitBase}";
#>
/// <summary>
/// Visit <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
/// </summary>
internal protected virtual TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context) => <#=func#>(expr, context);
<#
}
#>
}
public partial class ExprFunctor<TExprResult, TTypeResult>
{
<#
foreach (var ir in irs)
{
#>
/// <summary>
/// Visit <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
/// </summary>
internal protected virtual TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr) => base.Visit<#=ir.Name#>(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, Unit context) => Visit<#=ir.Name#>(expr);
<#
}
#>
}

View File

@ -1,781 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR;
/// <summary>
/// IVisitable interface for the custom class visit leaf.
/// </summary>
public interface IVisitable
{
/// <summary>
/// accept the visit.
/// </summary>
object Visit<TExprResult, TTypeResult>(ExprFunctor<TExprResult, TTypeResult> functor);
}
/// <summary>
/// IMutatable Define.
/// </summary>
public interface IMutatable : IVisitable
{
#if false
/// <summary>
/// mutate the current object.
/// NOTE In order to ensure the consistency of coding, please return a new object.
/// </summary>
/// <param name="mutator">ExprMutator.</param>
/// <returns> new instance. </returns>
// object MutateLeaf(ExprMutator mutator);
#endif
/// <summary>
/// recursive build new object.
/// </summary>
/// <param name="mutator">ExprMutator.</param>
object WithNew(ExprVisitor<Expr, IRType> mutator);
}
/// <summary>
/// Deep Expression matutor.
/// </summary>
public abstract class DeepExprMutator : ExprVisitor<Expr, IRType>
{
/// <summary>
/// The Struct Equal Memo folding the const/op.
/// </summary>
private readonly Dictionary<Expr, Expr> _exprSEqualMemo = new();
/// <summary>
/// Gets the Struct Equal Memo.
/// </summary>
public Dictionary<Expr, Expr> ExpressionStructMemo => _exprSEqualMemo;
/// <summary>
/// Gets or sets a value indicating whether for speedup the Mutator, If is Mutated we need MutateLeaf recursive.
/// </summary>
public bool IsMutated { get; set; }
/// <inheritdoc/>
public override Expr VisitLeaf(Call expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Target = Visit(expr.Target),
Parameters = MutateArray(expr.Parameters, Visit),
};
}
public override Expr VisitLeaf(If expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Condition = Visit(expr.Condition),
Then = Visit(expr.Then),
Else = Visit(expr.Else),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(Const expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return StructEqualFolding(expr);
}
/// <inheritdoc/>
public override Expr VisitLeaf(Function expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Body = Visit(expr.Body),
Parameters = new(expr.Parameters.Select(x => (Var)Visit(x))),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(Fusion expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Body = Visit(expr.Body),
Parameters = new(expr.Parameters.Select(x => (Var)Visit(x))),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(PrimFunctionWrapper expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Target = (TIR.PrimFunction)Visit((IR.BaseFunction)expr.Target),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.PrimFunction expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Body = (TIR.Sequential)Visit(expr.Body),
Parameters = new(expr.Parameters.Select(x => (TIR.PhysicalBuffer)Visit(x))),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(Op expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr;
}
/// <inheritdoc/>
public override Expr VisitLeaf(Tuple expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Fields = MutateArray(expr.Fields, Visit),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(Var expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr;
}
/// <inheritdoc/>
public override Expr VisitLeaf(None expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr;
}
/// <inheritdoc/>
public override Expr VisitLeaf(Marker expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Target = Visit(expr.Target),
Attribute = Visit(expr.Attribute),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.IterVar expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Dom = (TIR.Range)Visit(expr.Dom),
Value = (Var)Visit(expr.Value),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.Sequential expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Fields = MutateArray(expr.Fields, Visit),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.For expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
LoopVar = (Var)Visit(expr.LoopVar),
Domain = (TIR.Range)Visit(expr.Domain),
Body = (TIR.Sequential)Visit(expr.Body),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.IfThenElse expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Condition = Visit(expr.Condition),
Then = (TIR.Sequential)Visit(expr.Then),
Else = (TIR.Sequential)Visit(expr.Else),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.Block expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
// the block realize
InitBody = expr.InitBody.Fields.IsDefaultOrEmpty ? expr.InitBody : (TIR.Sequential)Visit(expr.InitBody),
Predicate = Visit(expr.Predicate),
IterVars = expr.IterVars.IsDefaultOrEmpty ? expr.IterVars : MutateArray(expr.IterVars, x => (TIR.IterVar)Visit(x)),
// the block internal.
Body = expr.Body.Fields.IsDefaultOrEmpty ? expr.Body : (TIR.Sequential)Visit(expr.Body),
Reads = expr.Reads.IsDefaultOrEmpty ? expr.Reads : MutateArray(expr.Reads, b => (TIR.BufferRegion)Visit(b)),
Writes = expr.Writes.IsDefaultOrEmpty ? expr.Writes : MutateArray(expr.Writes, b => (TIR.BufferRegion)Visit(b)),
AllocBuffers = expr.AllocBuffers.IsDefaultOrEmpty ? expr.AllocBuffers : MutateArray(expr.AllocBuffers, b => (TIR.Buffer)Visit(b)),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.BufferStore expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Value = Visit(expr.Value),
Indices = MutateArray(expr.Indices, Visit),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.BufferLoad expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Indices = MutateArray(expr.Indices, Visit),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.Let expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Var = (Var)Visit(expr.Var),
Expression = Visit(expr.Expression),
Body = (TIR.Sequential)Visit(expr.Body),
};
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.Buffer expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr;
}
/// <inheritdoc/>
public override Expr VisitLeaf(TIR.BufferRegion expr)
{
var nexpr = MutateLeaf(expr);
if (!object.ReferenceEquals(expr, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return expr;
}
return expr with
{
Buffer = (TIR.Buffer)Visit(expr.Buffer),
Region = MutateArray(expr.Region, rg => (TIR.Range)Visit(rg)),
};
}
/// <inheritdoc/>
public override object VisitLeaf(IVisitable visitable)
{
if (visitable is IMutatable mutatable)
{
var nexpr = MutateLeaf(mutatable);
if (!object.ReferenceEquals(mutatable, nexpr))
{
IsMutated = true;
return nexpr;
}
if (!IsMutated)
{
return mutatable;
}
return mutatable.WithNew(this);
}
throw new NotSupportedException($"IVisitable {visitable.GetType().Name} Is Not IMutatable!");
}
/// <summary>
/// defulat mutate leaf is not mutate.
/// </summary>
public virtual Expr DefaultMutateLeaf(Expr expr) => expr;
/// <summary>
/// default mutate leaf is not mutate.
/// </summary>
public virtual IMutatable DefaultMutateLeaf(IMutatable mutatable) => mutatable;
/// <summary>
/// mutate the call.
/// </summary>
public virtual Expr MutateLeaf(Call expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the if.
/// </summary>
public virtual Expr MutateLeaf(If expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the const.
/// </summary>
public virtual Expr MutateLeaf(Const expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the function.
/// </summary>
public virtual Expr MutateLeaf(Function expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the fusion.
/// </summary>
public virtual Expr MutateLeaf(Fusion expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the prim function wrapper.
/// </summary>
public virtual Expr MutateLeaf(PrimFunctionWrapper expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the prim function.
/// </summary>
public virtual Expr MutateLeaf(TIR.PrimFunction expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the op.
/// </summary>
public virtual Expr MutateLeaf(Op expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the tuple.
/// </summary>
public virtual Expr MutateLeaf(Tuple expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the var.
/// </summary>
public virtual Expr MutateLeaf(Var expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the var.
/// </summary>
public virtual Expr MutateLeaf(None expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the marker.
/// </summary>
public virtual Expr MutateLeaf(Marker expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the itervar.
/// </summary>
public virtual Expr MutateLeaf(TIR.IterVar expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the sequential.
/// </summary>
public virtual Expr MutateLeaf(TIR.Sequential expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the for.
/// </summary>
public virtual Expr MutateLeaf(TIR.For expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the for.
/// </summary>
public virtual Expr MutateLeaf(TIR.IfThenElse expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the block.
/// </summary>
public virtual Expr MutateLeaf(TIR.Block expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the bufferstore.
/// </summary>
public virtual Expr MutateLeaf(TIR.BufferStore expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the buffer load.
/// </summary>
public virtual Expr MutateLeaf(TIR.BufferLoad expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the let.
/// </summary>
/// <param name="expr">let expr.</param>
/// <returns>new expr.</returns>
public virtual Expr MutateLeaf(TIR.Let expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the memref.
/// </summary>
/// <param name="expr">new memref.</param>
/// <returns>new expr.</returns>
public virtual Expr MutateLeaf(TIR.Buffer expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the buffer region.
/// </summary>
/// <param name="expr">new memref.</param>
/// <returns>new expr.</returns>
public virtual Expr MutateLeaf(TIR.BufferRegion expr) => DefaultMutateLeaf(expr);
/// <summary>
/// mutate the imutatable.
/// </summary>
/// <param name="mutatable">IMutatable instance.</param>
/// <returns>new expr.</returns>
public virtual IMutatable MutateLeaf(IMutatable mutatable) => DefaultMutateLeaf(mutatable);
/// <summary>
/// Mutate IRArray.
/// </summary>
public virtual IRArray<TResult> MutateArray<TInput, TResult>(IRArray<TInput> array, Func<TInput, TResult> visitor)
{
return new(array.Select(visitor));
}
/// <summary>
/// fold the expr by struct comparer.
/// </summary>
public virtual Expr StructEqualFolding(Expr expr)
{
if (!_exprSEqualMemo.TryGetValue(expr, out var folded))
{
folded = expr;
_exprSEqualMemo.Add(expr, folded);
}
return folded;
}
}
/// <summary>
/// NOTE the mutator only visit the only one basefunction and skip other basefunction.
/// </summary>
public abstract class ExprMutator : DeepExprMutator
{
private BaseFunction? _entryBaseFunc;
/// <inheritdoc/>
public override Expr Visit(BaseFunction baseFunction)
{
if (_entryBaseFunc is null)
{
_entryBaseFunc = baseFunction;
}
return base.Visit(baseFunction);
}
/// <inheritdoc/>
public override Expr Visit(Fusion expr)
{
if (_entryBaseFunc is null)
{
_entryBaseFunc = expr;
return base.Visit(expr);
}
return object.ReferenceEquals(_entryBaseFunc, expr) ? base.Visit(expr) : expr;
}
/// <inheritdoc/>
public override Expr Visit(Function expr)
{
if (_entryBaseFunc is null)
{
_entryBaseFunc = expr;
return base.Visit(expr);
}
return object.ReferenceEquals(_entryBaseFunc, expr) ? base.Visit(expr) : expr;
}
/// <inheritdoc/>
public override Expr Visit(PrimFunctionWrapper expr)
{
if (_entryBaseFunc is null)
{
_entryBaseFunc = expr;
return base.Visit(expr);
}
return object.ReferenceEquals(_entryBaseFunc, expr) ? base.Visit(expr) : expr;
}
/// <inheritdoc/>
public override Expr Visit(TIR.PrimFunction expr)
{
if (_entryBaseFunc is null)
{
_entryBaseFunc = expr;
return base.Visit(expr);
}
return object.ReferenceEquals(_entryBaseFunc, expr) ? base.Visit(expr) : expr;
}
}

View File

@ -0,0 +1,51 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR;
public sealed class ExprPinner : IDisposable
{
private static readonly ExprUser _user = new();
private readonly Expr[] _exprs;
private bool _disposed;
public ExprPinner(params Expr[] exprs)
{
_exprs = exprs;
foreach (var expr in _exprs)
{
expr.AddUser(_user);
}
}
public void Dispose()
{
if (!_disposed)
{
foreach (var expr in _exprs)
{
expr.RemoveUser(_user);
}
_disposed = true;
}
}
private sealed class ExprUser : Expr
{
public ExprUser()
: base(Array.Empty<Expr>())
{
}
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context) => throw new NotSupportedException();
}
}

View File

@ -0,0 +1,109 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
using TorchSharp.Modules;
namespace Nncase.IR;
/// <summary>
/// Expression rewriter.
/// </summary>
/// <typeparam name="TContext">Rewrite context.</typeparam>
public abstract partial class ExprRewriter<TContext> : ExprVisitor<Expr, IRType, TContext>
{
/// <summary>
/// Initializes a new instance of the <see cref="ExprRewriter{TContext}"/> class.
/// </summary>
/// <param name="visitOtherFunctions">Vist other functions.</param>
public ExprRewriter(bool visitOtherFunctions = false)
: base(visitOtherFunctions)
{
}
/// <summary>
/// Gets a value indicating whether expression is mutated.
/// </summary>
public bool IsMutated { get; private set; }
/// <summary>
/// Rewrite expression.
/// </summary>
/// <param name="expr">Expression to rewrite.</param>
/// <param name="context">Context.</param>
/// <returns>Rewritten expression.</returns>
public Expr Rewrite(Expr expr, TContext context)
{
var newExpr = Visit(expr, context);
DCE(newExpr);
return newExpr;
}
/// <summary>
/// Default rewrite leaf routine.
/// </summary>
protected virtual Expr DefaultRewriteLeaf(Expr expr, TContext context) => expr;
protected void SetMutated() => IsMutated = true;
protected override void VisitOperands(Expr expr, TContext context)
{
var operands = expr.Operands;
for (int i = 0; i < operands.Length; i++)
{
var operand = operands[i];
var newOperand = Visit(operand, context);
if (!ReferenceEquals(operand, newOperand))
{
expr.ReplaceOperand(i, newOperand);
SetMutated();
}
}
}
private void DCE(Expr root)
{
using var exprPin = new ExprPinner(root);
foreach (var expr in ExprMemo)
{
expr.Key.DisposeIfNoUsers();
expr.Value.DisposeIfNoUsers();
}
}
}
/// <summary>
/// Expression rewriter.
/// </summary>
public abstract partial class ExprRewriter : ExprRewriter<Unit>
{
/// <summary>
/// Initializes a new instance of the <see cref="ExprRewriter"/> class.
/// </summary>
/// <param name="visitOtherFunctions">Vist other functions.</param>
protected ExprRewriter(bool visitOtherFunctions = false)
: base(visitOtherFunctions)
{
}
/// <summary>
/// Rewrite expression.
/// </summary>
/// <param name="expr">Expression to rewrite.</param>
/// <returns>Rewritten expression.</returns>
public Expr Rewrite(Expr expr) => Rewrite(expr, default);
/// <summary>
/// Default rewrite leaf routine.
/// </summary>
protected virtual Expr DefaultRewriteLeaf(Expr expr) => base.DefaultRewriteLeaf(expr, default);
/// <inheritdoc/>
protected sealed override Expr DefaultRewriteLeaf(Expr expr, Unit context) => DefaultRewriteLeaf(expr);
}

View File

@ -0,0 +1,553 @@

//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
//---------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Reactive;
namespace Nncase.IR;
public partial class ExprRewriter<TContext>
{
/// <inheritdoc/>
protected sealed override Expr VisitLeafBaseFunction(BaseFunction expr, TContext context)
{
return RewriteLeafBaseFunction(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafCall(Call expr, TContext context)
{
return RewriteLeafCall(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafConst(Const expr, TContext context)
{
return RewriteLeafConst(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafFunction(Function expr, TContext context)
{
return RewriteLeafFunction(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafFusion(Fusion expr, TContext context)
{
return RewriteLeafFusion(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafIf(If expr, TContext context)
{
return RewriteLeafIf(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafMarker(Marker expr, TContext context)
{
return RewriteLeafMarker(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafNone(None expr, TContext context)
{
return RewriteLeafNone(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafOp(Op expr, TContext context)
{
return RewriteLeafOp(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context)
{
return RewriteLeafPrimFunctionWrapper(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafTensorConst(TensorConst expr, TContext context)
{
return RewriteLeafTensorConst(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafTuple(IR.Tuple expr, TContext context)
{
return RewriteLeafTuple(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafTupleConst(TupleConst expr, TContext context)
{
return RewriteLeafTupleConst(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafVar(Var expr, TContext context)
{
return RewriteLeafVar(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafBlock(TIR.Block expr, TContext context)
{
return RewriteLeafBlock(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafBuffer(TIR.Buffer expr, TContext context)
{
return RewriteLeafBuffer(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
{
return RewriteLeafLogicalBuffer(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
{
return RewriteLeafPhysicalBuffer(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
{
return RewriteLeafBufferLoad(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
{
return RewriteLeafBufferRegion(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
{
return RewriteLeafBufferStore(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafFor(TIR.For expr, TContext context)
{
return RewriteLeafFor(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafIfThenElse(TIR.IfThenElse expr, TContext context)
{
return RewriteLeafIfThenElse(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafLet(TIR.Let expr, TContext context)
{
return RewriteLeafLet(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafPrimFunction(TIR.PrimFunction expr, TContext context)
{
return RewriteLeafPrimFunction(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafSequential(TIR.Sequential expr, TContext context)
{
return RewriteLeafSequential(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafRange(TIR.Range expr, TContext context)
{
return RewriteLeafRange(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafIterVar(TIR.IterVar expr, TContext context)
{
return RewriteLeafIterVar(expr, context);
}
/// <summary>
/// Rewrite leaf <see cref="BaseFunction"/>.
/// </summary>
protected virtual Expr RewriteLeafBaseFunction(BaseFunction expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="Call"/>.
/// </summary>
protected virtual Expr RewriteLeafCall(Call expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="Const"/>.
/// </summary>
protected virtual Expr RewriteLeafConst(Const expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="Function"/>.
/// </summary>
protected virtual Expr RewriteLeafFunction(Function expr, TContext context) => RewriteLeafBaseFunction(expr, context);
/// <summary>
/// Rewrite leaf <see cref="Fusion"/>.
/// </summary>
protected virtual Expr RewriteLeafFusion(Fusion expr, TContext context) => RewriteLeafBaseFunction(expr, context);
/// <summary>
/// Rewrite leaf <see cref="If"/>.
/// </summary>
protected virtual Expr RewriteLeafIf(If expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="Marker"/>.
/// </summary>
protected virtual Expr RewriteLeafMarker(Marker expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="None"/>.
/// </summary>
protected virtual Expr RewriteLeafNone(None expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="Op"/>.
/// </summary>
protected virtual Expr RewriteLeafOp(Op expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="PrimFunctionWrapper"/>.
/// </summary>
protected virtual Expr RewriteLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => RewriteLeafBaseFunction(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TensorConst"/>.
/// </summary>
protected virtual Expr RewriteLeafTensorConst(TensorConst expr, TContext context) => RewriteLeafConst(expr, context);
/// <summary>
/// Rewrite leaf <see cref="IR.Tuple"/>.
/// </summary>
protected virtual Expr RewriteLeafTuple(IR.Tuple expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TupleConst"/>.
/// </summary>
protected virtual Expr RewriteLeafTupleConst(TupleConst expr, TContext context) => RewriteLeafConst(expr, context);
/// <summary>
/// Rewrite leaf <see cref="Var"/>.
/// </summary>
protected virtual Expr RewriteLeafVar(Var expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.Block"/>.
/// </summary>
protected virtual Expr RewriteLeafBlock(TIR.Block expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.Buffer"/>.
/// </summary>
protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.LogicalBuffer"/>.
/// </summary>
protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferLoad"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferRegion"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferStore"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.For"/>.
/// </summary>
protected virtual Expr RewriteLeafFor(TIR.For expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.IfThenElse"/>.
/// </summary>
protected virtual Expr RewriteLeafIfThenElse(TIR.IfThenElse expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.Let"/>.
/// </summary>
protected virtual Expr RewriteLeafLet(TIR.Let expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.PrimFunction"/>.
/// </summary>
protected virtual Expr RewriteLeafPrimFunction(TIR.PrimFunction expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.Sequential"/>.
/// </summary>
protected virtual Expr RewriteLeafSequential(TIR.Sequential expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.Range"/>.
/// </summary>
protected virtual Expr RewriteLeafRange(TIR.Range expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.IterVar"/>.
/// </summary>
protected virtual Expr RewriteLeafIterVar(TIR.IterVar expr, TContext context) => DefaultRewriteLeaf(expr, context);
}
public partial class ExprRewriter
{
/// <summary>
/// Rewrite leaf <see cref="BaseFunction"/>.
/// </summary>
protected virtual Expr RewriteLeafBaseFunction(BaseFunction expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafBaseFunction(BaseFunction expr, Unit context) => RewriteLeafBaseFunction(expr);
/// <summary>
/// Rewrite leaf <see cref="Call"/>.
/// </summary>
protected virtual Expr RewriteLeafCall(Call expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafCall(Call expr, Unit context) => RewriteLeafCall(expr);
/// <summary>
/// Rewrite leaf <see cref="Const"/>.
/// </summary>
protected virtual Expr RewriteLeafConst(Const expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafConst(Const expr, Unit context) => RewriteLeafConst(expr);
/// <summary>
/// Rewrite leaf <see cref="Function"/>.
/// </summary>
protected virtual Expr RewriteLeafFunction(Function expr) => RewriteLeafBaseFunction(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafFunction(Function expr, Unit context) => RewriteLeafFunction(expr);
/// <summary>
/// Rewrite leaf <see cref="Fusion"/>.
/// </summary>
protected virtual Expr RewriteLeafFusion(Fusion expr) => RewriteLeafBaseFunction(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafFusion(Fusion expr, Unit context) => RewriteLeafFusion(expr);
/// <summary>
/// Rewrite leaf <see cref="If"/>.
/// </summary>
protected virtual Expr RewriteLeafIf(If expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafIf(If expr, Unit context) => RewriteLeafIf(expr);
/// <summary>
/// Rewrite leaf <see cref="Marker"/>.
/// </summary>
protected virtual Expr RewriteLeafMarker(Marker expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafMarker(Marker expr, Unit context) => RewriteLeafMarker(expr);
/// <summary>
/// Rewrite leaf <see cref="None"/>.
/// </summary>
protected virtual Expr RewriteLeafNone(None expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafNone(None expr, Unit context) => RewriteLeafNone(expr);
/// <summary>
/// Rewrite leaf <see cref="Op"/>.
/// </summary>
protected virtual Expr RewriteLeafOp(Op expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafOp(Op expr, Unit context) => RewriteLeafOp(expr);
/// <summary>
/// Rewrite leaf <see cref="PrimFunctionWrapper"/>.
/// </summary>
protected virtual Expr RewriteLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => RewriteLeafBaseFunction(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => RewriteLeafPrimFunctionWrapper(expr);
/// <summary>
/// Rewrite leaf <see cref="TensorConst"/>.
/// </summary>
protected virtual Expr RewriteLeafTensorConst(TensorConst expr) => RewriteLeafConst(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafTensorConst(TensorConst expr, Unit context) => RewriteLeafTensorConst(expr);
/// <summary>
/// Rewrite leaf <see cref="IR.Tuple"/>.
/// </summary>
protected virtual Expr RewriteLeafTuple(IR.Tuple expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafTuple(IR.Tuple expr, Unit context) => RewriteLeafTuple(expr);
/// <summary>
/// Rewrite leaf <see cref="TupleConst"/>.
/// </summary>
protected virtual Expr RewriteLeafTupleConst(TupleConst expr) => RewriteLeafConst(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafTupleConst(TupleConst expr, Unit context) => RewriteLeafTupleConst(expr);
/// <summary>
/// Rewrite leaf <see cref="Var"/>.
/// </summary>
protected virtual Expr RewriteLeafVar(Var expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafVar(Var expr, Unit context) => RewriteLeafVar(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.Block"/>.
/// </summary>
protected virtual Expr RewriteLeafBlock(TIR.Block expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafBlock(TIR.Block expr, Unit context) => RewriteLeafBlock(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.Buffer"/>.
/// </summary>
protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafBuffer(TIR.Buffer expr, Unit context) => RewriteLeafBuffer(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.LogicalBuffer"/>.
/// </summary>
protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr) => RewriteLeafBuffer(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => RewriteLeafLogicalBuffer(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => RewriteLeafBuffer(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => RewriteLeafPhysicalBuffer(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferLoad"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, Unit context) => RewriteLeafBufferLoad(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferRegion"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, Unit context) => RewriteLeafBufferRegion(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferStore"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafBufferStore(TIR.BufferStore expr, Unit context) => RewriteLeafBufferStore(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.For"/>.
/// </summary>
protected virtual Expr RewriteLeafFor(TIR.For expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafFor(TIR.For expr, Unit context) => RewriteLeafFor(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.IfThenElse"/>.
/// </summary>
protected virtual Expr RewriteLeafIfThenElse(TIR.IfThenElse expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafIfThenElse(TIR.IfThenElse expr, Unit context) => RewriteLeafIfThenElse(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.Let"/>.
/// </summary>
protected virtual Expr RewriteLeafLet(TIR.Let expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafLet(TIR.Let expr, Unit context) => RewriteLeafLet(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.PrimFunction"/>.
/// </summary>
protected virtual Expr RewriteLeafPrimFunction(TIR.PrimFunction expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafPrimFunction(TIR.PrimFunction expr, Unit context) => RewriteLeafPrimFunction(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.Sequential"/>.
/// </summary>
protected virtual Expr RewriteLeafSequential(TIR.Sequential expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafSequential(TIR.Sequential expr, Unit context) => RewriteLeafSequential(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.Range"/>.
/// </summary>
protected virtual Expr RewriteLeafRange(TIR.Range expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafRange(TIR.Range expr, Unit context) => RewriteLeafRange(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.IterVar"/>.
/// </summary>
protected virtual Expr RewriteLeafIterVar(TIR.IterVar expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafIterVar(TIR.IterVar expr, Unit context) => RewriteLeafIterVar(expr);
}

View File

@ -0,0 +1,70 @@
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.IO" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
<#@ include file="IRListParser.tt"#>
//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
//---------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Reactive;
namespace Nncase.IR;
public partial class ExprRewriter<TContext>
{
<#
foreach (var ir in irs)
{
#>
/// <inheritdoc/>
protected sealed override Expr VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context)
{
return RewriteLeaf<#=ir.Name#>(expr, context);
}
<#
}
#>
<#
foreach (var ir in irs)
{
var func = ir.VisitBase == "Default" ? "DefaultRewriteLeaf" : $"RewriteLeaf{ir.VisitBase}";
#>
/// <summary>
/// Rewrite leaf <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
/// </summary>
protected virtual Expr RewriteLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context) => <#=func#>(expr, context);
<#
}
#>
}
public partial class ExprRewriter
{
<#
foreach (var ir in irs)
{
var func = ir.VisitBase == "Default" ? "DefaultRewriteLeaf" : $"RewriteLeaf{ir.VisitBase}";
#>
/// <summary>
/// Rewrite leaf <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
/// </summary>
protected virtual Expr RewriteLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr) => <#=func#>(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, Unit context) => RewriteLeaf<#=ir.Name#>(expr);
<#
}
#>
}

View File

@ -3,765 +3,262 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR
namespace Nncase.IR;
/// <summary>
/// Expression visitor.
/// </summary>
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
/// <typeparam name="TContext">Visit context.</typeparam>
public abstract partial class ExprVisitor<TExprResult, TTypeResult, TContext> : ExprFunctor<TExprResult, TTypeResult, TContext>
{
private readonly bool _visitOtherFunctions;
/// <summary>
/// Expression visitor.
/// Initializes a new instance of the <see cref="ExprVisitor{TExprResult, TTypeResult, TContext}"/> class.
/// </summary>
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
public abstract class ExprVisitor<TExprResult, TTypeResult> : ExprFunctor<TExprResult, TTypeResult>
/// <param name="visitOtherFunctions">Vist other functions.</param>
public ExprVisitor(bool visitOtherFunctions = false)
{
private readonly Dictionary<Expr, TExprResult> _exprMemo = new Dictionary<Expr, TExprResult>(ReferenceEqualityComparer.Instance);
private readonly Dictionary<IVisitable, object> _visitableMemo = new Dictionary<IVisitable, object>(ReferenceEqualityComparer.Instance);
private readonly Dictionary<IRType, TTypeResult> _typeMemo = new Dictionary<IRType, TTypeResult>();
private readonly Dictionary<string, Action<Expr>> _callbacksAfterCall = new();
private readonly Dictionary<string, Action<Expr>> _callbacksBeforeCall = new();
_visitOtherFunctions = visitOtherFunctions;
}
/// <summary>
/// Gets expression visit result memo.
/// </summary>
public Dictionary<Expr, TExprResult> ExpressionMemo => _exprMemo;
/// <summary>
/// Gets expression memo.
/// </summary>
public Dictionary<Expr, TExprResult> ExprMemo { get; } = new(ReferenceEqualityComparer.Instance);
/// <summary>
/// Gets visitable visit result memo.
/// </summary>
public Dictionary<IVisitable, object> VisitAbleMemo => _visitableMemo;
/// <summary>
/// Gets type memo.
/// </summary>
public Dictionary<IRType, TTypeResult> TypeMemo { get; } = new(ReferenceEqualityComparer.Instance);
/// <inheritdoc/>
public override TExprResult Visit(Call expr)
/// <inheritdoc/>
public override TTypeResult VisitType(AnyType type, TContext context)
{
if (HasVisited(type, out var result))
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Target);
foreach (var param in expr.Parameters)
{
Visit(param);
}
CallbacksBeforeCall(expr);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
CallbacksAfterCall(expr);
}
return result;
}
/// <inheritdoc />
public override TExprResult Visit(If expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Condition);
Visit(expr.Then);
Visit(expr.Else);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return MarkVisited(type, VisitTypeLeaf(type, context));
}
/// <inheritdoc/>
public override TTypeResult VisitType(CallableType type, TContext context)
{
if (HasVisited(type, out var result))
{
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(Const expr)
foreach (var param in type.Parameters)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
VisitType(param, context);
}
VisitType(type.ReturnType, context);
return MarkVisited(type, VisitTypeLeaf(type, context));
}
/// <inheritdoc/>
public override TTypeResult VisitType(InvalidType type, TContext context)
{
if (HasVisited(type, out var result))
{
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(Function expr)
return MarkVisited(type, VisitTypeLeaf(type, context));
}
/// <inheritdoc/>
public override TTypeResult VisitType(TensorType type, TContext context)
{
if (HasVisited(type, out var result))
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
foreach (var param in expr.Parameters)
{
Visit(param);
}
Visit(expr.Body);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(Fusion expr)
return MarkVisited(type, VisitTypeLeaf(type, context));
}
/// <inheritdoc/>
public override TTypeResult VisitType(TupleType type, TContext context)
{
if (HasVisited(type, out var result))
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
foreach (var param in expr.Parameters)
{
Visit(param);
}
Visit(expr.Body);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(PrimFunctionWrapper expr)
foreach (var field in type.Fields)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Target);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
VisitType(field, context);
}
return MarkVisited(type, VisitTypeLeaf(type, context));
}
/// <summary>
/// Visit any type leaf.
/// </summary>
public virtual TTypeResult VisitTypeLeaf(AnyType type, TContext context) => DefaultVisitTypeLeaf(type, context);
/// <summary>
/// Visit invalid type leaf.
/// </summary>
public virtual TTypeResult VisitTypeLeaf(InvalidType type, TContext context) => DefaultVisitTypeLeaf(type, context);
/// <summary>
/// Visit tensor type leaf.
/// </summary>
public virtual TTypeResult VisitTypeLeaf(TensorType type, TContext context) => DefaultVisitTypeLeaf(type, context);
/// <summary>
/// Visit tuple type leaf.
/// </summary>
public virtual TTypeResult VisitTypeLeaf(TupleType type, TContext context) => DefaultVisitTypeLeaf(type, context);
/// <summary>
/// Visit tuple type leaf.
/// </summary>
public virtual TTypeResult VisitTypeLeaf(CallableType type, TContext context) => DefaultVisitTypeLeaf(type, context);
/// <summary>
/// Default visit leaf routine.
/// </summary>
public virtual TTypeResult DefaultVisitTypeLeaf(IRType type, TContext context)
{
throw new NotImplementedException($"Unhandled visit leaf routine for {type.GetType()}.");
}
/// <inheritdoc/>
public override void Clear()
{
ExprMemo.Clear();
TypeMemo.Clear();
base.Clear();
}
/// <summary>
/// Whether this expression is not visited before.
/// </summary>
protected bool HasVisited(Expr expr, [MaybeNullWhen(false)] out TExprResult result)
=> ExprMemo.TryGetValue(expr, out result);
/// <summary>
/// Whether this type is not visited before.
/// </summary>
protected bool HasVisited(IRType type, [MaybeNullWhen(false)] out TTypeResult result)
=> TypeMemo.TryGetValue(type, out result);
/// <summary>
/// Mark expression is visited.
/// </summary>
/// <param name="expr">Expression to visit.</param>
/// <param name="result">Visit result.</param>
protected TExprResult MarkVisited(Expr expr, TExprResult result)
{
ExprMemo[expr] = result;
return result;
}
/// <summary>
/// Mark type is visited.
/// </summary>
/// <param name="type">Type to visit.</param>
/// <param name="result">Visit result.</param>
protected TTypeResult MarkVisited(IRType type, TTypeResult result)
{
TypeMemo[type] = result;
return result;
}
protected bool CanVisitFunctionBody(BaseFunction baseFunction)
{
if (_visitOtherFunctions)
{
return true;
}
return ReferenceEquals(baseFunction, VisitRoot);
}
/// <summary>
/// Default leaf visit routine.
/// </summary>
protected virtual TExprResult DefaultVisitLeaf(Expr expr, TContext context)
{
throw new NotImplementedException($"Unhandled visit leaf routine for {expr.GetType()}.");
}
/// <inheritdoc/>
protected override TExprResult DispatchVisit(Expr expr, TContext context)
{
if (HasVisited(expr, out var result))
{
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(TIR.PrimFunction expr)
return MarkVisited(expr, base.DispatchVisit(expr, context));
}
protected virtual void VisitOperands(Expr expr, TContext context)
{
foreach (var operand in expr.Operands)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
foreach (var param in expr.Parameters)
{
Visit(param);
}
Visit(expr.Body);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(Op expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(Tuple expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
foreach (var field in expr.Fields)
{
Visit(field);
}
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(Var expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(None expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(Marker expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Target);
Visit(expr.Attribute);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(TIR.IterVar expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Value);
Visit(expr.Dom);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(TIR.Sequential expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
foreach (var item in expr.Fields)
{
Visit(item);
}
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(TIR.For expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.LoopVar);
Visit(expr.Domain);
Visit(expr.Body);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(TIR.Block expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Body);
Visit(expr.InitBody);
foreach (var iterVar in expr.IterVars)
{
Visit(iterVar);
}
foreach (var reads in expr.Reads)
{
Visit(reads);
}
foreach (var writes in expr.Writes)
{
Visit(writes);
}
foreach (var buffer in expr.AllocBuffers)
{
Visit(buffer);
}
Visit(expr.Predicate);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(TIR.BufferLoad expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
foreach (var index in expr.Indices)
{
Visit(index);
}
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(TIR.BufferStore expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Buffer);
foreach (var index in expr.Indices)
{
Visit(index);
}
Visit(expr.Value);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TExprResult Visit(TIR.IfThenElse expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Condition);
Visit(expr.Then);
Visit(expr.Else);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(TIR.Let expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Var);
Visit(expr.Expression);
Visit(expr.Body);
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(TIR.Buffer expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override TExprResult Visit(TIR.BufferRegion expr)
{
if (!_exprMemo.TryGetValue(expr, out var result))
{
Visit(expr.Buffer);
foreach (var param in expr.Region)
{
Visit(param);
}
result = VisitLeaf(expr);
_exprMemo.Add(expr, result);
}
return result;
}
/// <summary>
/// visit ivisitable.
/// </summary>
public override object Visit(IVisitable visitable)
{
if (!_visitableMemo.TryGetValue(visitable, out var result))
{
visitable.Visit<TExprResult, TTypeResult>(this);
result = VisitLeaf(visitable);
_visitableMemo.Add(visitable, result);
}
return result;
}
/// <summary>
/// Visit expression.
/// </summary>
/// <param name="expr">Expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Expr expr)
{
return expr switch
{
Var var => VisitLeaf(var),
Const con => VisitLeaf(con),
Function func => VisitLeaf(func),
Fusion fusion => VisitLeaf(fusion),
Call call => VisitLeaf(call),
If @if => VisitLeaf(@if),
Tuple tuple => VisitLeaf(tuple),
Op op => VisitLeaf(op),
None none => VisitLeaf(none),
TIR.Sequential seq => VisitLeaf(seq),
TIR.For @for => VisitLeaf(@for),
TIR.Block block => VisitLeaf(block),
TIR.BufferLoad bufload => VisitLeaf(bufload),
TIR.BufferStore bufstore => VisitLeaf(bufstore),
TIR.IfThenElse ift => VisitLeaf(ift),
TIR.Let let => VisitLeaf(let),
TIR.Buffer buffer => VisitLeaf(buffer),
_ => DefaultVisitLeaf(expr),
};
}
/// <summary>
/// Visit leaf variable expression.
/// </summary>
/// <param name="expr">Variable expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Var expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf constant expression.
/// </summary>
/// <param name="expr">Constant expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Const expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf function expression.
/// </summary>
/// <param name="expr">Function expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Function expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf fusion expression.
/// </summary>
/// <param name="expr">Fusion expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Fusion expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf prim function wrapper expression.
/// </summary>
/// <param name="expr">Function expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(PrimFunctionWrapper expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf prim function expression.
/// </summary>
/// <param name="expr">PrimFunction expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.PrimFunction expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf call expression.
/// </summary>
/// <param name="expr">Call expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Call expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf if expression.
/// </summary>
/// <param name="expr">If expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(If expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf tuple expression.
/// </summary>
/// <param name="expr">Variable expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Tuple expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf operator expression.
/// </summary>
/// <param name="expr">Operator expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Op expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf None expression.
/// </summary>
/// <param name="expr">None expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(None expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf marker expression.
/// </summary>
/// <param name="expr">None expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(Marker expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf IterVar expression.
/// </summary>
/// <param name="expr">IterVar expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.IterVar expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf sequential expression.
/// </summary>
/// <param name="expr">sequential expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.Sequential expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf For expression.
/// </summary>
/// <param name="expr">For expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.For expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf Block expression.
/// </summary>
/// <param name="expr">Block expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.Block expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf BufferLoad expression.
/// </summary>
/// <param name="expr">BufferLoad expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.BufferLoad expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf BufferRead expression.
/// </summary>
/// <param name="expr">BufferRead expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.BufferStore expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf IfThenElse expression.
/// </summary>
/// <param name="expr">IfThenElse expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.IfThenElse expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf Let expression.
/// </summary>
/// <param name="expr">Let expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.Let expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf MemRef expression.
/// </summary>
/// <param name="expr">MemRef expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.Buffer expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf buffer region expression.
/// </summary>
/// <param name="expr">buffer region expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult VisitLeaf(TIR.BufferRegion expr) => DefaultVisitLeaf(expr);
/// <summary>
/// Visit leaf ifunctable.
/// </summary>
public virtual object VisitLeaf(IVisitable visitable) => DefaultVisitLeaf(visitable);
/// <summary>
/// Default leaf visit routine.
/// </summary>
public virtual object DefaultVisitLeaf(IVisitable visitable)
{
throw new NotImplementedException($"Unhandled visit leaf routine for {visitable.GetType()}.");
}
/// <summary>
/// Default leaf visit routine.
/// </summary>
/// <param name="expr">Expression.</param>
/// <returns>Result.</returns>
public virtual TExprResult DefaultVisitLeaf(Expr expr)
{
throw new NotImplementedException($"Unhandled visit leaf routine for {expr.GetType()}.");
}
/// <inheritdoc/>
public sealed override TTypeResult VisitType(AnyType type)
{
if (!_typeMemo.TryGetValue(type, out var result))
{
result = VisitTypeLeaf(type);
_typeMemo.Add(type, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TTypeResult VisitType(CallableType type)
{
if (!_typeMemo.TryGetValue(type, out var result))
{
foreach (var param in type.Parameters)
{
VisitType(param);
}
VisitType(type.ReturnType);
result = VisitTypeLeaf(type);
_typeMemo.Add(type, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TTypeResult VisitType(InvalidType type)
{
if (!_typeMemo.TryGetValue(type, out var result))
{
result = VisitTypeLeaf(type);
_typeMemo.Add(type, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TTypeResult VisitType(TensorType type)
{
if (!_typeMemo.TryGetValue(type, out var result))
{
result = VisitTypeLeaf(type);
_typeMemo.Add(type, result);
}
return result;
}
/// <inheritdoc/>
public sealed override TTypeResult VisitType(TupleType type)
{
if (!_typeMemo.TryGetValue(type, out var result))
{
foreach (var field in type.Fields)
{
VisitType(field);
}
result = VisitTypeLeaf(type);
_typeMemo.Add(type, result);
}
return result;
}
/// <summary>
/// Visit any type leaf.
/// </summary>
/// <param name="type">Any type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitTypeLeaf(AnyType type) => DefaultVisitTypeLeaf(type);
/// <summary>
/// Visit invalid type leaf.
/// </summary>
/// <param name="type">Invalid type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitTypeLeaf(InvalidType type) => DefaultVisitTypeLeaf(type);
/// <summary>
/// Visit tensor type leaf.
/// </summary>
/// <param name="type">Tensor type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitTypeLeaf(TensorType type) => DefaultVisitTypeLeaf(type);
/// <summary>
/// Visit tuple type leaf.
/// </summary>
/// <param name="type">Tuple type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitTypeLeaf(TupleType type) => DefaultVisitTypeLeaf(type);
/// <summary>
/// Visit tuple type leaf.
/// </summary>
/// <param name="type">Callable type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitTypeLeaf(CallableType type) => DefaultVisitTypeLeaf(type);
/// <summary>
/// Default visit leaf routine.
/// </summary>
/// <param name="type">Type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult DefaultVisitTypeLeaf(IRType type)
{
throw new NotImplementedException($"Unhandled visit leaf routine for {type.GetType()}.");
}
/// <summary>
/// clear the Memo!.
/// </summary>
public virtual void Clear()
{
_exprMemo.Clear();
_typeMemo.Clear();
}
protected void RegisterAfterCallback(string name, Action<Expr> callback)
{
_callbacksAfterCall[name] = callback;
}
protected void RegisterBeforeCallback(string name, Action<Expr> callback)
{
_callbacksBeforeCall[name] = callback;
}
private void CallbacksBeforeCall(Expr expr)
{
foreach (var (name, callback) in _callbacksBeforeCall)
{
callback(expr);
}
}
private void CallbacksAfterCall(Expr expr)
{
foreach (var (name, callback) in _callbacksAfterCall)
{
callback(expr);
}
Visit(operand, context);
}
}
}
/// <summary>
/// Expression visitor.
/// </summary>
/// <typeparam name="TExprResult">Expression visit result type.</typeparam>
/// <typeparam name="TTypeResult">Type visit result type.</typeparam>
public abstract partial class ExprVisitor<TExprResult, TTypeResult> : ExprVisitor<TExprResult, TTypeResult, Unit>
{
/// <summary>
/// Initializes a new instance of the <see cref="ExprVisitor{TExprResult, TTypeResult}"/> class.
/// </summary>
/// <param name="visitOtherFunctions">Vist other functions.</param>
public ExprVisitor(bool visitOtherFunctions = false)
: base(visitOtherFunctions)
{
}
/// <summary>
/// Visit <see cref="Expr"/>.
/// </summary>
public TExprResult Visit(Expr expr) => Visit(expr, default);
/// <summary>
/// Default visit routine.
/// </summary>
/// <param name="expr">Expression.</param>
/// <returns>Result.</returns>
protected internal virtual TExprResult DefaultVisit(Expr expr) => base.DefaultVisit(expr, default);
/// <inheritdoc/>
protected internal sealed override TExprResult DefaultVisit(Expr expr, Unit context) => DefaultVisit(expr);
/// <summary>
/// Default leaf visit routine.
/// </summary>
protected virtual TExprResult DefaultVisitLeaf(Expr expr) => base.DefaultVisitLeaf(expr, default);
protected sealed override TExprResult DefaultVisitLeaf(Expr expr, Unit context) => DefaultVisitLeaf(expr);
protected virtual TExprResult DispatchVisit(Expr expr) => base.DispatchVisit(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult DispatchVisit(Expr expr, Unit context) => DispatchVisit(expr);
}

View File

@ -0,0 +1,751 @@

//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
//---------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Reactive;
namespace Nncase.IR;
public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
{
/// <inheritdoc />
protected internal override TExprResult VisitCall(Call expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafCall(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitFunction(Function expr, TContext context)
{
if (CanVisitFunctionBody(expr))
{
VisitOperands(expr, context);
}
return VisitLeafFunction(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitFusion(Fusion expr, TContext context)
{
if (CanVisitFunctionBody(expr))
{
VisitOperands(expr, context);
}
return VisitLeafFusion(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitIf(If expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafIf(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitMarker(Marker expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafMarker(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitNone(None expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafNone(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitOp(Op expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafOp(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context)
{
if (CanVisitFunctionBody(expr))
{
VisitOperands(expr, context);
}
return VisitLeafPrimFunctionWrapper(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitTensorConst(TensorConst expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafTensorConst(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitTuple(IR.Tuple expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafTuple(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitTupleConst(TupleConst expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafTupleConst(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitVar(Var expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafVar(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitBlock(TIR.Block expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafBlock(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafLogicalBuffer(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafPhysicalBuffer(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafBufferLoad(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafBufferRegion(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafBufferStore(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitFor(TIR.For expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafFor(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitIfThenElse(TIR.IfThenElse expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafIfThenElse(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitLet(TIR.Let expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafLet(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitPrimFunction(TIR.PrimFunction expr, TContext context)
{
if (CanVisitFunctionBody(expr))
{
VisitOperands(expr, context);
}
return VisitLeafPrimFunction(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitSequential(TIR.Sequential expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafSequential(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitRange(TIR.Range expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafRange(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafIterVar(expr, context);
}
/// <summary>
/// Visit leaf <see cref="BaseFunction"/>.
/// </summary>
protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="Call"/>.
/// </summary>
protected virtual TExprResult VisitLeafCall(Call expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="Const"/>.
/// </summary>
protected virtual TExprResult VisitLeafConst(Const expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="Function"/>.
/// </summary>
protected virtual TExprResult VisitLeafFunction(Function expr, TContext context) => VisitLeafBaseFunction(expr, context);
/// <summary>
/// Visit leaf <see cref="Fusion"/>.
/// </summary>
protected virtual TExprResult VisitLeafFusion(Fusion expr, TContext context) => VisitLeafBaseFunction(expr, context);
/// <summary>
/// Visit leaf <see cref="If"/>.
/// </summary>
protected virtual TExprResult VisitLeafIf(If expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="Marker"/>.
/// </summary>
protected virtual TExprResult VisitLeafMarker(Marker expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="None"/>.
/// </summary>
protected virtual TExprResult VisitLeafNone(None expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="Op"/>.
/// </summary>
protected virtual TExprResult VisitLeafOp(Op expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="PrimFunctionWrapper"/>.
/// </summary>
protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => VisitLeafBaseFunction(expr, context);
/// <summary>
/// Visit leaf <see cref="TensorConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr, TContext context) => VisitLeafConst(expr, context);
/// <summary>
/// Visit leaf <see cref="IR.Tuple"/>.
/// </summary>
protected virtual TExprResult VisitLeafTuple(IR.Tuple expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TupleConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr, TContext context) => VisitLeafConst(expr, context);
/// <summary>
/// Visit leaf <see cref="Var"/>.
/// </summary>
protected virtual TExprResult VisitLeafVar(Var expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.Block"/>.
/// </summary>
protected virtual TExprResult VisitLeafBlock(TIR.Block expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.Buffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.LogicalBuffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.BufferLoad"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.BufferRegion"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.BufferStore"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.For"/>.
/// </summary>
protected virtual TExprResult VisitLeafFor(TIR.For expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.IfThenElse"/>.
/// </summary>
protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.Let"/>.
/// </summary>
protected virtual TExprResult VisitLeafLet(TIR.Let expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.PrimFunction"/>.
/// </summary>
protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.Sequential"/>.
/// </summary>
protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.Range"/>.
/// </summary>
protected virtual TExprResult VisitLeafRange(TIR.Range expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.IterVar"/>.
/// </summary>
protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr, TContext context) => DefaultVisitLeaf(expr, context);
}
public partial class ExprVisitor<TExprResult, TTypeResult>
{
/// <summary>
/// Visit <see cref="Call"/>.
/// </summary>
internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr);
/// <summary>
/// Visit <see cref="Function"/>.
/// </summary>
internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr);
/// <summary>
/// Visit <see cref="Fusion"/>.
/// </summary>
internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr);
/// <summary>
/// Visit <see cref="If"/>.
/// </summary>
internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr);
/// <summary>
/// Visit <see cref="Marker"/>.
/// </summary>
internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr);
/// <summary>
/// Visit <see cref="None"/>.
/// </summary>
internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr);
/// <summary>
/// Visit <see cref="Op"/>.
/// </summary>
internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr);
/// <summary>
/// Visit <see cref="PrimFunctionWrapper"/>.
/// </summary>
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr);
/// <summary>
/// Visit <see cref="TensorConst"/>.
/// </summary>
internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
/// <summary>
/// Visit <see cref="IR.Tuple"/>.
/// </summary>
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr);
/// <summary>
/// Visit <see cref="TupleConst"/>.
/// </summary>
internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
/// <summary>
/// Visit <see cref="Var"/>.
/// </summary>
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr);
/// <summary>
/// Visit <see cref="TIR.Block"/>.
/// </summary>
internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr);
/// <summary>
/// Visit <see cref="TIR.LogicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.BufferLoad"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
/// <summary>
/// Visit <see cref="TIR.BufferRegion"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
/// <summary>
/// Visit <see cref="TIR.BufferStore"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
/// <summary>
/// Visit <see cref="TIR.For"/>.
/// </summary>
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr);
/// <summary>
/// Visit <see cref="TIR.IfThenElse"/>.
/// </summary>
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr);
/// <summary>
/// Visit <see cref="TIR.Let"/>.
/// </summary>
internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr);
/// <summary>
/// Visit <see cref="TIR.PrimFunction"/>.
/// </summary>
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr);
/// <summary>
/// Visit <see cref="TIR.Sequential"/>.
/// </summary>
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr);
/// <summary>
/// Visit <see cref="TIR.Range"/>.
/// </summary>
internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr);
/// <summary>
/// Visit <see cref="TIR.IterVar"/>.
/// </summary>
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr);
/// <summary>
/// Visit leaf <see cref="BaseFunction"/>.
/// </summary>
protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr) => base.VisitLeafBaseFunction(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBaseFunction(BaseFunction expr, Unit context) => VisitLeafBaseFunction(expr);
/// <summary>
/// Visit leaf <see cref="Call"/>.
/// </summary>
protected virtual TExprResult VisitLeafCall(Call expr) => base.VisitLeafCall(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafCall(Call expr, Unit context) => VisitLeafCall(expr);
/// <summary>
/// Visit leaf <see cref="Const"/>.
/// </summary>
protected virtual TExprResult VisitLeafConst(Const expr) => base.VisitLeafConst(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafConst(Const expr, Unit context) => VisitLeafConst(expr);
/// <summary>
/// Visit leaf <see cref="Function"/>.
/// </summary>
protected virtual TExprResult VisitLeafFunction(Function expr) => base.VisitLeafFunction(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr);
/// <summary>
/// Visit leaf <see cref="Fusion"/>.
/// </summary>
protected virtual TExprResult VisitLeafFusion(Fusion expr) => base.VisitLeafFusion(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafFusion(Fusion expr, Unit context) => VisitLeafFusion(expr);
/// <summary>
/// Visit leaf <see cref="If"/>.
/// </summary>
protected virtual TExprResult VisitLeafIf(If expr) => base.VisitLeafIf(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafIf(If expr, Unit context) => VisitLeafIf(expr);
/// <summary>
/// Visit leaf <see cref="Marker"/>.
/// </summary>
protected virtual TExprResult VisitLeafMarker(Marker expr) => base.VisitLeafMarker(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafMarker(Marker expr, Unit context) => VisitLeafMarker(expr);
/// <summary>
/// Visit leaf <see cref="None"/>.
/// </summary>
protected virtual TExprResult VisitLeafNone(None expr) => base.VisitLeafNone(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafNone(None expr, Unit context) => VisitLeafNone(expr);
/// <summary>
/// Visit leaf <see cref="Op"/>.
/// </summary>
protected virtual TExprResult VisitLeafOp(Op expr) => base.VisitLeafOp(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafOp(Op expr, Unit context) => VisitLeafOp(expr);
/// <summary>
/// Visit leaf <see cref="PrimFunctionWrapper"/>.
/// </summary>
protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitLeafPrimFunctionWrapper(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitLeafPrimFunctionWrapper(expr);
/// <summary>
/// Visit leaf <see cref="TensorConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr);
/// <summary>
/// Visit leaf <see cref="IR.Tuple"/>.
/// </summary>
protected virtual TExprResult VisitLeafTuple(IR.Tuple expr) => base.VisitLeafTuple(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafTuple(IR.Tuple expr, Unit context) => VisitLeafTuple(expr);
/// <summary>
/// Visit leaf <see cref="TupleConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr) => base.VisitLeafTupleConst(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafTupleConst(TupleConst expr, Unit context) => VisitLeafTupleConst(expr);
/// <summary>
/// Visit leaf <see cref="Var"/>.
/// </summary>
protected virtual TExprResult VisitLeafVar(Var expr) => base.VisitLeafVar(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafVar(Var expr, Unit context) => VisitLeafVar(expr);
/// <summary>
/// Visit leaf <see cref="TIR.Block"/>.
/// </summary>
protected virtual TExprResult VisitLeafBlock(TIR.Block expr) => base.VisitLeafBlock(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBlock(TIR.Block expr, Unit context) => VisitLeafBlock(expr);
/// <summary>
/// Visit leaf <see cref="TIR.Buffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr) => base.VisitLeafBuffer(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBuffer(TIR.Buffer expr, Unit context) => VisitLeafBuffer(expr);
/// <summary>
/// Visit leaf <see cref="TIR.LogicalBuffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLeafLogicalBuffer(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLeafLogicalBuffer(expr);
/// <summary>
/// Visit leaf <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitLeafPhysicalBuffer(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitLeafPhysicalBuffer(expr);
/// <summary>
/// Visit leaf <see cref="TIR.BufferLoad"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr) => base.VisitLeafBufferLoad(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, Unit context) => VisitLeafBufferLoad(expr);
/// <summary>
/// Visit leaf <see cref="TIR.BufferRegion"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr) => base.VisitLeafBufferRegion(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, Unit context) => VisitLeafBufferRegion(expr);
/// <summary>
/// Visit leaf <see cref="TIR.BufferStore"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr) => base.VisitLeafBufferStore(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBufferStore(TIR.BufferStore expr, Unit context) => VisitLeafBufferStore(expr);
/// <summary>
/// Visit leaf <see cref="TIR.For"/>.
/// </summary>
protected virtual TExprResult VisitLeafFor(TIR.For expr) => base.VisitLeafFor(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafFor(TIR.For expr, Unit context) => VisitLeafFor(expr);
/// <summary>
/// Visit leaf <see cref="TIR.IfThenElse"/>.
/// </summary>
protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr) => base.VisitLeafIfThenElse(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, Unit context) => VisitLeafIfThenElse(expr);
/// <summary>
/// Visit leaf <see cref="TIR.Let"/>.
/// </summary>
protected virtual TExprResult VisitLeafLet(TIR.Let expr) => base.VisitLeafLet(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafLet(TIR.Let expr, Unit context) => VisitLeafLet(expr);
/// <summary>
/// Visit leaf <see cref="TIR.PrimFunction"/>.
/// </summary>
protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr) => base.VisitLeafPrimFunction(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, Unit context) => VisitLeafPrimFunction(expr);
/// <summary>
/// Visit leaf <see cref="TIR.Sequential"/>.
/// </summary>
protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr) => base.VisitLeafSequential(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafSequential(TIR.Sequential expr, Unit context) => VisitLeafSequential(expr);
/// <summary>
/// Visit leaf <see cref="TIR.Range"/>.
/// </summary>
protected virtual TExprResult VisitLeafRange(TIR.Range expr) => base.VisitLeafRange(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafRange(TIR.Range expr, Unit context) => VisitLeafRange(expr);
/// <summary>
/// Visit leaf <see cref="TIR.IterVar"/>.
/// </summary>
protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr) => base.VisitLeafIterVar(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafIterVar(TIR.IterVar expr, Unit context) => VisitLeafIterVar(expr);
}

View File

@ -0,0 +1,105 @@
<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.IO" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ output extension=".cs" #>
<#@ include file="IRListParser.tt"#>
//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
// Changes to this file may cause incorrect behavior and will be lost if the code is regenerated.
// </auto-generated>
//---------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Reactive;
namespace Nncase.IR;
public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
{
<#
foreach (var ir in irs.Where(x => x.IsDerived))
{
#>
/// <inheritdoc />
protected internal override TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context)
{
<#
var fieldIdent = ir.IsFunction ? " " : string.Empty;
if (ir.IsFunction)
{
#>
if (CanVisitFunctionBody(expr))
{
<#
}
#>
<#=fieldIdent#>VisitOperands(expr, context);
<#
if (ir.IsFunction)
{
#>
}
<#
}
#>
return VisitLeaf<#=ir.Name#>(expr, context);
}
<#
}
#>
<#
foreach (var ir in irs)
{
var func = ir.VisitBase == "Default" ? "DefaultVisitLeaf" : $"VisitLeaf{ir.VisitBase}";
#>
/// <summary>
/// Visit leaf <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
/// </summary>
protected virtual TExprResult VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, TContext context) => <#=func#>(expr, context);
<#
}
#>
}
public partial class ExprVisitor<TExprResult, TTypeResult>
{
<#
foreach (var ir in irs.Where(x => x.IsDerived))
{
#>
/// <summary>
/// Visit <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
/// </summary>
internal protected virtual TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr) => base.Visit<#=ir.Name#>(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult Visit<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, Unit context) => Visit<#=ir.Name#>(expr);
<#
}
#>
<#
foreach (var ir in irs)
{
var func = ir.VisitBase == "Default" ? "DefaultVisitLeaf" : $"VisitLeaf{ir.VisitBase}";
#>
/// <summary>
/// Visit leaf <see cref="<#=ir.Namespace#><#=ir.Name#>"/>.
/// </summary>
protected virtual TExprResult VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr) => base.VisitLeaf<#=ir.Name#>(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeaf<#=ir.Name#>(<#=ir.Namespace#><#=ir.Name#> expr, Unit context) => VisitLeaf<#=ir.Name#>(expr);
<#
}
#>
}

View File

@ -0,0 +1,39 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR;
public abstract class ExprWalker<TContext> : ExprVisitor<Unit, Unit, TContext>
{
/// <summary>
/// Initializes a new instance of the <see cref="ExprWalker{TContext}"/> class.
/// </summary>
/// <param name="visitOtherFunctions">Vist other functions.</param>
public ExprWalker(bool visitOtherFunctions = false)
: base(visitOtherFunctions)
{
}
protected override Unit DefaultVisitLeaf(Expr expr, TContext context) => default;
}
public abstract class ExprWalker : ExprVisitor<Unit, Unit>
{
/// <summary>
/// Initializes a new instance of the <see cref="ExprWalker"/> class.
/// </summary>
/// <param name="visitOtherFunctions">Vist other functions.</param>
public ExprWalker(bool visitOtherFunctions = false)
: base(visitOtherFunctions)
{
}
protected override Unit DefaultVisitLeaf(Expr expr) => default;
}

View File

@ -3,56 +3,35 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using NetFabric.Hyperlinq;
using Nncase.Utilities;
namespace Nncase.IR;
/// <summary>
/// the Callable Expr.
/// </summary>
public abstract record Callable(string Name, string ModuleKind) : Expr
{
/// <summary>
/// StackVM module kind.
/// </summary>
public static readonly string StackVMModuleKind = "stackvm";
}
/// <summary>
/// Base function.
/// </summary>
public abstract record BaseFunction(string Name, string ModuleKind) : Callable(Name, ModuleKind)
{
/// <summary>
/// Gets sched result.
/// </summary>
public Schedule.SchedFunctionResult SchedResult { get; } = new();
/// <summary>
/// Gets parameter types.
/// </summary>
public abstract IEnumerable<IRType?> ParameterTypes { get; }
/// <inheritdoc/>
public override int GetHashCode() => base.GetHashCode();
}
/// <summary>
/// Function expression.
/// </summary>
public sealed record Function(string Name, Expr Body, IRArray<Var> Parameters) : BaseFunction(Name, StackVMModuleKind)
public sealed class Function : BaseFunction
{
private static int _globalFuncIndex;
/// <summary>
/// Initializes a new instance of the <see cref="Function"/> class.
/// build function.
/// </summary>
/// <param name="parameters">Parameters.</param>
/// <param name="body">Body.</param>
public Function(Expr body, IRArray<Var> parameters)
public Function(string name, Expr body, ReadOnlySpan<Var> parameters)
: base(name, StackVMModuleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
{
}
/// <summary>
/// Initializes a new instance of the <see cref="Function"/> class.
/// build function.
/// </summary>
public Function(Expr body, ReadOnlySpan<Var> parameters)
: this($"func_{_globalFuncIndex++}", body, parameters)
{
}
@ -61,13 +40,33 @@ public sealed record Function(string Name, Expr Body, IRArray<Var> Parameters) :
/// Initializes a new instance of the <see cref="Function"/> class.
/// build function.
/// </summary>
public Function(Expr body, params Var[] parameters)
: this($"func_{_globalFuncIndex++}", body, new(parameters))
public Function(string name, Expr body, params Var[] parameters)
: this(name, body, parameters.AsSpan())
{
}
/// <summary>
/// Initializes a new instance of the <see cref="Function"/> class.
/// build function.
/// </summary>
public Function(Expr body, params Var[] parameters)
: this(body, parameters.AsSpan())
{
}
public Expr Body => Operands[0];
public ReadOnlySpan<Var> Parameters => SpanUtility.UnsafeCast<Expr, Var>(Operands[1..]);
/// <summary>
/// Gets get all parameter checked types.
/// </summary>
public override IEnumerable<IRType?> ParameterTypes => Parameters.Select(x => x.CheckedType);
public override IEnumerable<IRType?> ParameterTypes => Parameters.AsValueEnumerable().Select(x => x.CheckedType).ToArray();
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitFunction(this, context);
public Function With(string? name = null, Expr? body = null, Var[]? parameters = null)
=> new Function(name ?? Name, body ?? Body, parameters ?? Parameters);
}

View File

@ -1,25 +1,25 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using NetFabric.Hyperlinq;
using Nncase.CodeGen;
using Nncase.Utilities;
namespace Nncase.IR;
/// <summary>
/// Fusion expression.
/// </summary>
public record Fusion(string Name, string ModuleKind, Expr Body, IRArray<Var> Parameters) : BaseFunction(Name, ModuleKind)
public sealed class Fusion : BaseFunction
{
private static int _globalFusionIndex;
/// <summary>
/// Initializes a new instance of the <see cref="Fusion"/> class.
/// build function.
/// </summary>
/// <param name="module_kind">Module kind.</param>
/// <param name="parameters">Parameters.</param>
/// <param name="body">Body.</param>
public Fusion(string module_kind, Expr body, IRArray<Var> parameters)
: this($"fusion_{_globalFusionIndex++}", module_kind, body, parameters)
public Fusion(string name, string moduleKind, Expr body, ReadOnlySpan<Var> parameters)
: base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
{
}
@ -27,13 +27,42 @@ public record Fusion(string Name, string ModuleKind, Expr Body, IRArray<Var> Par
/// Initializes a new instance of the <see cref="Fusion"/> class.
/// build function.
/// </summary>
public Fusion(string module_kind, Expr body, params Var[] parameters)
: this($"fusion_{_globalFusionIndex++}", module_kind, body, new(parameters))
public Fusion(string moduleKind, Expr body, ReadOnlySpan<Var> parameters)
: this($"func_{_globalFusionIndex++}", moduleKind, body, parameters)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="Fusion"/> class.
/// build function.
/// </summary>
public Fusion(string name, string moduleKind, Expr body, params Var[] parameters)
: base(name, moduleKind, ArrayUtility.Concat(body, SpanUtility.UnsafeCast<Var, Expr>(parameters)))
{
}
/// <summary>
/// Initializes a new instance of the <see cref="Fusion"/> class.
/// build function.
/// </summary>
public Fusion(string moduleKind, Expr body, params Var[] parameters)
: this($"func_{_globalFusionIndex++}", moduleKind, body, parameters)
{
}
public Expr Body => Operands[0];
public ReadOnlySpan<Var> Parameters => SpanUtility.UnsafeCast<Expr, Var>(Operands[1..]);
/// <summary>
/// Gets get all parameter checked types.
/// </summary>
public override IEnumerable<IRType?> ParameterTypes => Parameters.Select(x => x.CheckedType);
public override IEnumerable<IRType?> ParameterTypes => Parameters.AsValueEnumerable().Select(x => x.CheckedType).ToArray();
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitFusion(this, context);
public Fusion With(string? name = null, string? moduleKind = null, Expr? body = null, Var[]? parameters = null)
=> new Fusion(name ?? Name, moduleKind ?? ModuleKind, body ?? Body, parameters ?? Parameters);
}

View File

@ -1,371 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using GiGraph.Dot.Entities.Clusters;
using GiGraph.Dot.Entities.Graphs;
using GiGraph.Dot.Entities.Html.Table;
using GiGraph.Dot.Entities.Nodes;
using GiGraph.Dot.Extensions;
using GiGraph.Dot.Types.Colors;
using GiGraph.Dot.Types.Edges;
using GiGraph.Dot.Types.Fonts;
using GiGraph.Dot.Types.Graphs;
using GiGraph.Dot.Types.Nodes;
using GiGraph.Dot.Types.Records;
using GiGraph.Dot.Types.Styling;
using Nncase.IR;
namespace Nncase.IR;
internal sealed class ILDotOption
{
private readonly DotNode? _dotNode;
private readonly string? _str;
public ILDotOption(DotNode dotNode)
{
_dotNode = dotNode;
_str = null;
}
public ILDotOption(string str)
{
_dotNode = null;
_str = str;
}
public DotNode DotNode => _dotNode!;
public string Str => _str!;
public bool IsDotNode => _dotNode is not null;
}
internal sealed class ILDotPrintVisitor : ExprVisitor<ILDotOption, string>
{
private readonly bool _display_callable;
private readonly DotGraph _dotGraph;
private readonly List<(string, DotGraph)> _subdotGraphs;
private int _idCounter;
private BaseFunction? _entryBaseFunc;
public ILDotPrintVisitor(bool display_callable)
{
_display_callable = display_callable;
_dotGraph = new(directed: true);
_subdotGraphs = new();
}
/// <inheritdoc/>
public override ILDotOption Visit(BaseFunction baseFunction)
{
_entryBaseFunc ??= baseFunction;
return base.Visit(baseFunction);
}
/// <inheritdoc/>
public override ILDotOption Visit(PrimFunctionWrapper expr)
{
if (!ExpressionMemo.TryGetValue(expr, out var result))
{
var id = _idCounter++;
_ = "\"" + id.ToString() + "\"";
result = new(expr.Name);
ExpressionMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override ILDotOption Visit(TIR.PrimFunction expr)
{
if (!ExpressionMemo.TryGetValue(expr, out var result))
{
var id = _idCounter++;
_ = "\"" + id.ToString() + "\"";
result = new(expr.Name);
ExpressionMemo.Add(expr, result);
}
return result;
}
/// <inheritdoc/>
public override ILDotOption Visit(Fusion expr)
{
_entryBaseFunc ??= expr;
if (!object.ReferenceEquals(_entryBaseFunc, expr))
{
if (_display_callable)
{
var visitor = new ILDotPrintVisitor(_display_callable);
visitor.Visit(expr);
_subdotGraphs.Add((expr.Name, visitor._dotGraph));
_subdotGraphs.AddRange(visitor._subdotGraphs);
}
return new(expr.Name);
}
return base.Visit(expr);
}
/// <inheritdoc/>
public override ILDotOption Visit(Function expr)
{
_entryBaseFunc ??= expr;
if (!object.ReferenceEquals(_entryBaseFunc, expr))
{
if (_display_callable)
{
var visitor = new ILDotPrintVisitor(_display_callable);
visitor.Visit(expr);
_subdotGraphs.Add((expr.Name, visitor._dotGraph));
_subdotGraphs.AddRange(visitor._subdotGraphs);
}
return new(expr.Name);
}
return base.Visit(expr);
}
/// <inheritdoc/>
public override ILDotOption VisitLeaf(Fusion expr)
{
return new(expr.Name);
}
/// <inheritdoc/>
public override ILDotOption VisitLeaf(Function expr)
{
return new(expr.Name);
}
public override ILDotOption VisitLeaf(Tuple expr)
{
var id = _idCounter++;
string exprId = "\"" + id.ToString() + "\"";
var table = new DotHtmlTable
{
BorderWidth = 0,
CellBorderWidth = 1,
CellSpacing = 0,
};
var connect_list = new List<(Expr, string)>();
// 1. the connect type.
table.AddRow(row =>
{
row.AddCell("Tuple"); // key wrods type.
int count = 0;
foreach (var child in expr.Fields)
{
var childnode = Visit(child);
var portName = $"P{count++}";
row.AddCell(childnode.IsDotNode ? string.Empty : childnode.Str, cell => cell.PortName = portName);
if (childnode.IsDotNode)
{
connect_list.Add((child, portName));
}
}
});
// 3. make crrent node.
var dotNode = _dotGraph.Nodes.Add(exprId);
dotNode.ToPlainHtmlNode(table);
// 4. connect edge.
foreach (var (child, port_name) in connect_list)
{
_dotGraph.Edges.Add(Visit(child).DotNode, dotNode, edge =>
{
edge.Head.Endpoint.Port = new DotEndpointPort(port_name);
});
}
return new(dotNode);
}
public override ILDotOption VisitLeaf(Op expr)
{
return new(expr.GetType().Name + $"({expr.DisplayProperty()})");
}
public override ILDotOption VisitLeaf(Const expr)
{
return new(CompilerServices.Print(expr));
}
public override ILDotOption VisitLeaf(None expr)
{
return new("None");
}
public override ILDotOption VisitLeaf(Marker expr)
{
var id = _idCounter++;
string exprId = "\"" + id.ToString() + "\"";
var table = new DotHtmlTable
{
BorderWidth = 0,
CellBorderWidth = 1,
CellSpacing = 0,
};
var target = Visit(expr.Target);
var attr = Visit(expr.Attribute);
// 1. the connect type.
table.AddRow(row =>
{
row.AddCell("Marker"); // key wrods type.
if (target.IsDotNode)
{
row.AddCell("Target", cell => cell.PortName = "P0"); // target.
}
else
{
row.AddCell(target.Str, cell => cell.PortName = "P0");
}
if (attr.IsDotNode)
{
row.AddCell("Attr", cell => cell.PortName = "P1"); // attr
}
else
{
row.AddCell(attr.Str, cell => cell.PortName = "P1");
}
});
table.AddRow(row =>
{
row.AddCell(expr.CheckedType is null ? "Null" : CompilerServices.Print(expr.CheckedType), cell => cell.ColumnSpan = 3);
});
// 3. make crrent node.
var dotNode = _dotGraph.Nodes.Add(exprId);
dotNode.ToPlainHtmlNode(table);
// 4. connect edge.
if (target.IsDotNode)
{
_dotGraph.Edges.Add(target.DotNode, dotNode, edge =>
{
edge.Head.Endpoint.Port = new DotEndpointPort("P0");
});
}
if (attr.IsDotNode)
{
_dotGraph.Edges.Add(attr.DotNode, dotNode, edge =>
{
edge.Head.Endpoint.Port = new DotEndpointPort("P1");
});
}
return new(dotNode);
}
public override ILDotOption VisitLeaf(Var expr)
{
var id = _idCounter++;
string exprId = "\"" + id.ToString() + "\"";
var dotNode = new DotNode(exprId) { Label = expr.Name, Shape = DotNodeShape.Rectangle };
_dotGraph.Nodes.Add(dotNode);
return new(dotNode);
}
public override ILDotOption VisitLeaf(Call expr)
{
var id = _idCounter++;
string exprId = "\"" + id.ToString() + "\"";
var table = new DotHtmlTable
{
BorderWidth = 0,
CellBorderWidth = 1,
CellSpacing = 0,
};
var connect_list = new List<(Expr, string)>();
// 1. the connect type.
table.AddRow(row =>
{
row.AddCell("Call"); // key wrods type.
row.AddCell(Visit(expr.Target).Str); // target.
int count = 0;
foreach (var (child, arg_name) in expr.Parameters.Zip(expr.Target switch
{
Op op => op.Parameters.Select(info => info.Name),
Fusion fusion => fusion.Parameters.Select(v => v.Name),
Function func => func.Parameters.Select(v => v.Name),
PrimFunctionWrapper wrapper => wrapper.Target.Parameters.Select(b => b.Name),
_ => throw new NotSupportedException($"Target type {expr.Target.GetType()} is not supported."),
}))
{
if (child is Const or None)
{
continue;
}
var portName = $"P{count++}";
row.AddCell(arg_name, cell => cell.PortName = portName);
connect_list.Add((child, portName));
}
});
table.AddRow(row =>
{
row.AddCell(expr.CheckedType is null ? "Null" : CompilerServices.Print(expr.CheckedType), cell => cell.ColumnSpan = connect_list.Count + 2);
});
// 3. make crrent node.
var dotNode = _dotGraph.Nodes.Add(exprId);
dotNode.ToPlainHtmlNode(table);
// 4. connect edge.
foreach (var (child, port_name) in connect_list)
{
_dotGraph.Edges.Add(Visit(child).DotNode, dotNode, edge =>
{
edge.Head.Endpoint.Port = new DotEndpointPort(port_name);
});
}
return new(dotNode);
}
/// <summary>
/// Save the dot to File.
/// </summary>
/// <param name="name">name.</param>
/// <param name="prefix">prefix.</param>
/// <param name="dumpDir">dump dir.</param>
public void SaveToFile(string name, string prefix, string dumpDir)
{
SaveToFileCore(_dotGraph, name, prefix, dumpDir);
foreach (var (sub_name, subGraph) in _subdotGraphs)
{
SaveToFileCore(subGraph, name + "_" + sub_name, prefix, dumpDir);
}
}
private static void SaveToFileCore(DotGraph dotGraph, string name, string prefix, string dumpDir)
{
var nprefix = prefix.Any() ? prefix + "_" : prefix;
string dump_path = Path.Combine(dumpDir, $"{nprefix}{name}.dot");
dotGraph.Build();
dotGraph.SaveToFile(dump_path);
}
}

View File

@ -0,0 +1,17 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR;
public static class IRHelpers
{
public static IRType? GetRawCheckedType(Expr expr) => expr.RawCheckedType;
public static void SetRawCheckedType(Expr expr, IRType? value) => expr.RawCheckedType = value;
}

View File

@ -0,0 +1,28 @@
BaseFunction,false,true,Default,,
Call,true,false,Default,,Target;@Arguments
Const,false,false,Default,,
Function,true,true,BaseFunction,,@Parameters;Body
Fusion,true,true,BaseFunction,,@Parameters;Body
If,true,false,Default,,Condition;Then;Else
Marker,true,false,Default,,Target;Attribute
None,true,false,Default,,
Op,true,false,Default,,
PrimFunctionWrapper,true,true,BaseFunction,,Target
TensorConst,true,false,Const,,
Tuple,true,false,Default,IR.,@Fields
TupleConst,true,false,Const,,
Var,true,false,Default,,
Block,true,false,Default,TIR.,Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate
Buffer,false,false,Default,TIR.,
LogicalBuffer,true,false,Buffer,TIR.,@Dimensions;@Strides
PhysicalBuffer,true,false,Buffer,TIR.,
BufferLoad,true,false,Default,TIR.,Buffer;@Indices
BufferRegion,true,false,Default,TIR.,Buffer;@Region
BufferStore,true,false,Default,TIR.,Buffer;@Indices;Value
For,true,false,Default,TIR.,LoopVar;Domain;Body
IfThenElse,true,false,Default,TIR.,Condition;Then;Else
Let,true,false,Default,TIR.,Var;Expression;Body
PrimFunction,true,true,Default,TIR.,@Parameters;Body
Sequential,true,false,Default,TIR.,@Fields
Range,true,false,Default,TIR.,Start;Stop;Step
IterVar,true,false,Default,TIR.,Value;Dom
1 BaseFunction false true Default
2 Call true false Default Target;@Arguments
3 Const false false Default
4 Function true true BaseFunction @Parameters;Body
5 Fusion true true BaseFunction @Parameters;Body
6 If true false Default Condition;Then;Else
7 Marker true false Default Target;Attribute
8 None true false Default
9 Op true false Default
10 PrimFunctionWrapper true true BaseFunction Target
11 TensorConst true false Const
12 Tuple true false Default IR. @Fields
13 TupleConst true false Const
14 Var true false Default
15 Block true false Default TIR. Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate
16 Buffer false false Default TIR.
17 LogicalBuffer true false Buffer TIR. @Dimensions;@Strides
18 PhysicalBuffer true false Buffer TIR.
19 BufferLoad true false Default TIR. Buffer;@Indices
20 BufferRegion true false Default TIR. Buffer;@Region
21 BufferStore true false Default TIR. Buffer;@Indices;Value
22 For true false Default TIR. LoopVar;Domain;Body
23 IfThenElse true false Default TIR. Condition;Then;Else
24 Let true false Default TIR. Var;Expression;Body
25 PrimFunction true true Default TIR. @Parameters;Body
26 Sequential true false Default TIR. @Fields
27 Range true false Default TIR. Start;Stop;Step
28 IterVar true false Default TIR. Value;Dom

View File

@ -0,0 +1,32 @@
<#@ assembly name="System.Core" #>
<#@ import namespace="System.IO" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ import namespace="System.Diagnostics" #>
<#
var irs = (from l in File.ReadAllLines("src/Nncase.Core/IR/IRList.csv")
where !string.IsNullOrWhiteSpace(l)
let columns = l.Split(',')
let isDerived = bool.Parse(columns[1])
select new IRDef
{
Name = columns[0],
IsDerived = isDerived,
IsFunction = bool.Parse(columns[2]),
VisitBase = columns[3],
Namespace = columns[4],
Fields = isDerived ? columns[5].Split(new[]{';'}, StringSplitOptions.RemoveEmptyEntries) : Array.Empty<string>()
}).ToArray();
#>
<#+
struct IRDef
{
public string Name;
public bool IsDerived;
public bool IsFunction;
public string VisitBase;
public string Namespace;
public string[] Fields;
}
#>

View File

@ -4,9 +4,12 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
using TorchSharp.Modules;
namespace Nncase.IR;
@ -57,6 +60,7 @@ public sealed class IRModule
/// <param name="function">Callable to add.</param>
public void Add(BaseFunction function)
{
CompilerServices.InferenceType(function);
_functions.Add(function);
}
@ -67,47 +71,9 @@ public sealed class IRModule
/// <param name="function">the entry function defination.</param>
public void Replace(int index, BaseFunction function)
{
var old = _functions[index];
var replacer = new FunctionReplacer(old, function);
for (int i = 0; i < _functions.Count; i++)
{
replacer.Visit(_functions[i]);
}
for (int i = 0; i < _functions.Count; i++)
{
var originFunc = _functions[i];
if (replacer.ExpressionMemo.TryGetValue(originFunc, out var replace))
{
_functions[i] = (BaseFunction)replace;
}
}
}
/// <summary>
/// Replace the function call dependencer.
/// </summary>
private sealed class FunctionReplacer : DeepExprMutator
{
private readonly BaseFunction _original;
private readonly BaseFunction _replace;
public FunctionReplacer(BaseFunction original, BaseFunction replace)
{
_original = original;
_replace = replace;
}
public override Expr DefaultMutateLeaf(Expr expr)
{
if (expr is BaseFunction baseFunction
&& object.ReferenceEquals(baseFunction, _original))
{
return _replace;
}
return expr;
}
CompilerServices.InferenceType(function);
ref var old = ref CollectionsMarshal.AsSpan(_functions)[index];
old.ReplaceAllUsesWith(function);
old = function;
}
}

View File

@ -12,16 +12,10 @@ namespace Nncase.IR;
/// <summary>
/// Tuple interface.
/// </summary>
public interface ITuple : IReadOnlyList<Expr>
public interface ITuple
{
/// <summary>
/// Gets fields.
/// Gets fields count.
/// </summary>
IReadOnlyList<Expr> Fields { get; }
/// <summary>
/// Cast to expression.
/// </summary>
/// <returns>The expression.</returns>
public Expr AsExpr() => (Expr)this;
int Count { get; }
}

View File

@ -12,6 +12,23 @@ namespace Nncase.IR;
/// <summary>
/// if(Condition) then { Then } else { Else }.
/// </summary>
public sealed record If(Expr Condition, Expr Then, Expr Else) : Expr
public sealed class If : Expr
{
public If(Expr condition, Expr then, Expr @else)
: base(new[] { condition, then, @else })
{
}
public Expr Condition => Operands[0];
public Expr Then => Operands[1];
public Expr Else => Operands[2];
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitIf(this, context);
public If With(Expr? condition = null, Expr? then = null, Expr? @else = null)
=> new If(condition ?? Condition, then ?? Then, @else ?? Else);
}

View File

@ -10,49 +10,53 @@ using System.Threading.Tasks;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Imaging
namespace Nncase.IR.Imaging;
/// <summary>
/// ResizeImage expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed partial class ResizeImage : Op
{
/// <summary>
/// ResizeImage expression.
/// Gets input.
/// </summary>
[PatternFunctionalGenerator]
public sealed record ResizeImage(
ImageResizeMode ResizeMode,
ImageResizeTransformationMode TransformationMode,
ImageResizeNearestMode NearestMode, bool IsTFResize = false) : Op
{
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"));
public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"));
/// <summary>
/// Gets roi.
/// [axis 0 start,axis 1 end, ... ,axis n start, axis n end].
/// </summary>
public static readonly ParameterInfo Roi = new(typeof(ResizeImage), 1, "roi", IsNoneType() | (IsFloat() & HasRank(1)));
/// <summary>
/// Gets roi.
/// [axis 0 start,axis 1 end, ... ,axis n start, axis n end].
/// </summary>
public static readonly ParameterInfo Roi = new(typeof(ResizeImage), 1, "roi", IsNoneType() | (IsFloat() & HasRank(1)));
/// <summary>
/// Gets new_size.
/// </summary>
public static readonly ParameterInfo NewSize = new(typeof(ResizeImage), 2, "new_size", HasShape(new[] { 4 }));
/// <summary>
/// Gets new_size.
/// </summary>
public static readonly ParameterInfo NewSize = new(typeof(ResizeImage), 2, "new_size", HasShape(new[] { 4 }));
/// <summary>
/// Gets CubicCoeffA.
/// </summary>
public static readonly ParameterInfo CubicCoeffA = new(typeof(ResizeImage), 3, "cubic_coeff_a", IsNoneType() | IsFloatScalar());
/// <summary>
/// Gets CubicCoeffA.
/// </summary>
public static readonly ParameterInfo CubicCoeffA = new(typeof(ResizeImage), 3, "cubic_coeff_a", IsNoneType() | IsFloatScalar());
/// <summary>
/// Gets ExcludeOutside.
/// </summary>
public static readonly ParameterInfo ExcludeOutside = new(typeof(ResizeImage), 4, "exclude_outside", IsNoneType() | IsIntegralScalar());
/// <summary>
/// Gets ExcludeOutside.
/// </summary>
public static readonly ParameterInfo ExcludeOutside = new(typeof(ResizeImage), 4, "exclude_outside", IsNoneType() | IsIntegralScalar());
/// <summary>
/// Gets ExtrapolationValue.
/// </summary>
public static readonly ParameterInfo ExtrapolationValue = new(typeof(ResizeImage), 5, "extrapolation_value", IsNoneType() | IsFloatScalar());
/// <summary>
/// Gets ExtrapolationValue.
/// </summary>
public static readonly ParameterInfo ExtrapolationValue = new(typeof(ResizeImage), 5, "extrapolation_value", IsNoneType() | IsFloatScalar());
/// <inheritdoc/>
public override string DisplayProperty() => $"ImageResizeMode.{ResizeMode}, ImageResizeTransformationMode.{TransformationMode}, ImageResizeNearestMode.{NearestMode}, {IsTFResize}";
}
public ImageResizeMode ResizeMode { get; }
public ImageResizeTransformationMode TransformationMode { get; }
public ImageResizeNearestMode NearestMode { get; }
public bool IsTFResize { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"ImageResizeMode.{ResizeMode}, ImageResizeTransformationMode.{TransformationMode}, ImageResizeNearestMode.{NearestMode}, {IsTFResize}";
}

View File

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Toolkit.HighPerformance.Helpers;
namespace Nncase.IR;
@ -43,12 +44,12 @@ public class LeafExprEqualityComparer : IEqualityComparer<Expr>
(Const tx, Const ty) => tx.Equals(ty),
// note think of fusion/primfunc/primfunc wrapper as a black box.
(Fusion tx, Fusion ty) => ReferenceEqualityComparer.Instance.Equals(tx, ty),
(TIR.PrimFunction tx, TIR.PrimFunction ty) => ReferenceEqualityComparer.Instance.Equals(tx, ty),
(PrimFunctionWrapper tx, PrimFunctionWrapper ty) => ReferenceEqualityComparer.Instance.Equals(tx, ty),
(Function tx, Function ty) => tx.Parameters.Equals(ty.Parameters),
(Fusion tx, Fusion ty) => ReferenceEquals(tx, ty),
(TIR.PrimFunction tx, TIR.PrimFunction ty) => ReferenceEquals(tx, ty),
(PrimFunctionWrapper tx, PrimFunctionWrapper ty) => ReferenceEquals(tx, ty),
(Function tx, Function ty) => tx.Parameters.Length == ty.Parameters.Length,
(Tuple tx, Tuple ty) => tx.Count == ty.Count,
(Call tx, Call ty) => tx.Parameters.Count == ty.Parameters.Count,
(Call tx, Call ty) => tx.Arguments.Length == ty.Arguments.Length,
(Op tx, Op ty) => tx.Equals(ty),
(IR.If, IR.If) => true,
(Marker tx, Marker ty) => tx.Name == ty.Name,
@ -64,14 +65,14 @@ public class LeafExprEqualityComparer : IEqualityComparer<Expr>
{
Var x => x.GetHashCode(),
Const x => x.GetHashCode(),
Function x => x.Parameters.GetHashCode(),
Function x => ReferenceEqualityComparer.Instance.GetHashCode(x),
Fusion x => ReferenceEqualityComparer.Instance.GetHashCode(x),
TIR.PrimFunction x => ReferenceEqualityComparer.Instance.GetHashCode(x),
PrimFunctionWrapper x => ReferenceEqualityComparer.Instance.GetHashCode(x),
Tuple x => x.Count.GetHashCode(),
Call x => x.Parameters.Count.GetHashCode(),
Call x => x.Arguments.Length.GetHashCode(),
Op x => x.GetHashCode(),
Marker x => x.Name.GetHashCode(StringComparison.InvariantCulture),
Marker x => x.Name.GetHashCode(StringComparison.Ordinal),
None x => x.GetHashCode(),
IR.If x => x.GetType().GetHashCode(),
_ => throw new InvalidOperationException("Invalid expression type."),

View File

@ -9,6 +9,17 @@ using System.Threading.Tasks;
namespace Nncase.IR;
/// <summary>
/// staic marker name collection.
/// </summary>
public static class WellknownMarkerNames
{
/// <summary>
/// attribute. <seealso cref="IR.Math.RangeOf"/>
/// </summary>
public static readonly string RangeOf = "RangeOf";
}
public class MixQuantInfo
{
public bool HasBindedMixQuantInfo { get; set; }
@ -38,11 +49,28 @@ public class AdaQuantInfo
/// <summary>
/// The marker expression, it's can attach the attribute on the target.
/// </summary>
/// <param name="Name"> Name will belong to <see cref="WellknownMarkerNames"/>. </param>
/// <param name="Target"> expr target. </param>
/// <param name="Attribute"> expr attribute. </param>
public sealed record Marker(string Name, Expr Target, Expr Attribute) : Expr
public sealed class Marker : Expr, IEquatable<Marker?>
{
private readonly string _name;
/// <summary>
/// Initializes a new instance of the <see cref="Marker"/> class.
/// </summary>
/// <param name="name">Name will belong to <see cref="WellknownMarkerNames"/>.</param>
/// <param name="target">expr target.</param>
/// <param name="attribute">expr attribute.</param>
public Marker(string name, Expr target, Expr attribute)
: base(new[] { target, attribute })
{
_name = name;
}
public string Name => _name;
public Expr Target => Operands[0];
public Expr Attribute => Operands[1];
/// <summary>
/// Gets or sets the mix quant info.
/// </summary>
@ -52,15 +80,36 @@ public sealed record Marker(string Name, Expr Target, Expr Attribute) : Expr
/// Gets or sets the ada quant info.
/// </summary>
public AdaQuantInfo? AdaQuantInfo { get; set; }
}
/// <summary>
/// staic marker name collection.
/// </summary>
public static class WellknownMarkerNames
{
/// <summary>
/// attribute. <seealso cref="IR.Math.RangeOf"/>
/// </summary>
public static readonly string RangeOf = "RangeOf";
public static bool operator ==(Marker? left, Marker? right) => EqualityComparer<Marker>.Default.Equals(left, right);
public static bool operator !=(Marker? left, Marker? right) => !(left == right);
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitMarker(this, context);
/// <inheritdoc/>
public override bool Equals(object? obj) => Equals(obj as Marker);
/// <inheritdoc/>
public bool Equals(Marker? other)
{
if (ReferenceEquals(this, other))
{
return true;
}
return other is not null && base.Equals(other) && Name == other.Name;
}
public Marker With(string? name = null, Expr? target = null, Expr? attribute = null, MixQuantInfo? mixQuantInfo = null, AdaQuantInfo? adaQuantInfo = null)
=> new Marker(name ?? Name, target ?? Target, attribute ?? Attribute)
{
MixQuantInfo = mixQuantInfo ?? MixQuantInfo,
AdaQuantInfo = adaQuantInfo ?? AdaQuantInfo,
};
/// <inheritdoc/>
protected override int GetHashCodeCore() => HashCode.Combine(base.GetHashCodeCore(), Name);
}

View File

@ -15,7 +15,7 @@ namespace Nncase.IR.Math;
/// Binary expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Binary(BinaryOp BinaryOp) : Op
public sealed partial class Binary : Op
{
/// <summary>
/// Gets lhs.
@ -27,6 +27,8 @@ public sealed record Binary(BinaryOp BinaryOp) : Op
/// </summary>
public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs");
public BinaryOp BinaryOp { get; }
/// <inheritdoc/>
public override string DisplayProperty()
{

View File

@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
/// Clamp expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Clamp() : Op
public sealed partial class Clamp : Op
{
/// <summary>
/// Gets input.

View File

@ -15,7 +15,7 @@ namespace Nncase.IR.Math;
/// Binary expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Compare(CompareOp CompareOp) : Op
public sealed partial class Compare : Op
{
/// <summary>
/// Gets lhs.
@ -27,6 +27,8 @@ public sealed record Compare(CompareOp CompareOp) : Op
/// </summary>
public static readonly ParameterInfo Rhs = new(typeof(Compare), 1, "rhs");
public CompareOp CompareOp { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"CompareOp.{CompareOp}";
}

View File

@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
/// Condition operation.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Condition() : Op
public sealed partial class Condition : Op
{
/// <summary>
/// Gets Condition.

View File

@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
/// CumSum expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record CumSum() : Op
public sealed partial class CumSum : Op
{
/// <summary>
/// Gets input.

View File

@ -10,7 +10,7 @@ namespace Nncase.IR.Math;
/// Dequantize expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Dequantize(DataType TargetType) : Op
public sealed partial class Dequantize : Op
{
/// <summary>
/// Gets input.
@ -22,6 +22,8 @@ public sealed record Dequantize(DataType TargetType) : Op
/// </summary>
public static readonly ParameterInfo DequantParam = new(typeof(Dequantize), 1, "dequantParam", IsQuantParamType());
public DataType TargetType { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"{TargetType.GetCSharpName()}";
}

View File

@ -9,7 +9,7 @@ namespace Nncase.IR.Math;
/// Fake dequantize expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record FakeDequantize(DataType TargetType) : Op
public sealed partial class FakeDequantize : Op
{
/// <summary>
/// Gets input.
@ -21,6 +21,8 @@ public sealed record FakeDequantize(DataType TargetType) : Op
/// </summary>
public static readonly ParameterInfo DequantParam = new(typeof(FakeDequantize), 1, "dequantParam");
public DataType TargetType { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"{TargetType.GetCSharpName()}";
}

View File

@ -9,7 +9,7 @@ namespace Nncase.IR.Math;
/// Fake quantize expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record FakeQuantize(DataType TargetType) : Op
public sealed partial class FakeQuantize : Op
{
/// <summary>
/// Gets input.
@ -21,6 +21,8 @@ public sealed record FakeQuantize(DataType TargetType) : Op
/// </summary>
public static readonly ParameterInfo QuantParam = new(typeof(FakeQuantize), 1, "quantParam");
public DataType TargetType { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"{TargetType.GetCSharpName()}";
}

View File

@ -15,7 +15,7 @@ namespace Nncase.IR.Math;
/// MatMul expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record MatMul() : Op
public sealed partial class MatMul : Op
{
/// <summary>
/// Gets input.

View File

@ -10,7 +10,7 @@ namespace Nncase.IR.Math;
/// QuantParamOf expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record QuantParamOf(QuantMode QuantMode) : Op
public sealed partial class QuantParamOf : Op
{
/// <summary>
/// Gets range.
@ -22,6 +22,8 @@ public sealed record QuantParamOf(QuantMode QuantMode) : Op
/// </summary>
public static readonly ParameterInfo Bits = new(typeof(QuantParamOf), 1, "bits", IsIntegralScalar());
public QuantMode QuantMode { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"QuantMode.{QuantMode}";
}

View File

@ -10,7 +10,7 @@ namespace Nncase.IR.Math;
/// Quantize expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Quantize(DataType TargetType) : Op
public sealed partial class Quantize : Op
{
/// <summary>
/// Gets input.
@ -22,6 +22,8 @@ public sealed record Quantize(DataType TargetType) : Op
/// </summary>
public static readonly ParameterInfo QuantParam = new(typeof(Quantize), 1, "quantParam", IsQuantParamType());
public DataType TargetType { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"{TargetType.GetCSharpName()}";
}

View File

@ -9,7 +9,7 @@ namespace Nncase.IR.Math;
/// RangeOf expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record RangeOf() : Op
public sealed partial class RangeOf : Op
{
/// <summary>
/// Gets input.

View File

@ -10,7 +10,7 @@ namespace Nncase.IR.Math;
/// Reduce expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Reduce(ReduceOp ReduceOp) : Op
public sealed partial class Reduce : Op
{
/// <summary>
/// Gets input.
@ -32,6 +32,8 @@ public sealed record Reduce(ReduceOp ReduceOp) : Op
/// </summary>
public static readonly ParameterInfo KeepDims = new(typeof(Reduce), 3, "keepDims", IsScalar() & IsIntegral());
public ReduceOp ReduceOp { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"ReduceOp.{ReduceOp}";
}

View File

@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
/// ReduceArg expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record ReduceArg(ReduceArgOp ReduceArgOp, DataType DestType) : Op
public sealed partial class ReduceArg : Op
{
/// <summary>
/// Gets input.
@ -40,6 +40,10 @@ public sealed record ReduceArg(ReduceArgOp ReduceArgOp, DataType DestType) : Op
/// <remarks>Only used in onnx.</remarks>
public static readonly ParameterInfo SelectLastIndex = new(typeof(ReduceArg), 3, "selectLastIndex", IsBoolScalar());
public ReduceArgOp ReduceArgOp { get; }
public DataType DestType { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}";
}

View File

@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
/// Require expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Require(string Message) : Op
public sealed partial class Require : Op
{
/// <summary>
/// Gets Condition.
@ -28,6 +28,8 @@ public sealed record Require(string Message) : Op
/// </summary>
public static readonly ParameterInfo Value = new(typeof(Require), 1, "value");
public string Message { get; }
/// <inheritdoc/>
public override string DisplayProperty() => "\"\"";
}

View File

@ -16,7 +16,7 @@ namespace Nncase.IR.Math;
/// Unary expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Select() : Op
public sealed partial class Select : Op
{
/// <summary>
/// Gets Condition.

View File

@ -15,13 +15,15 @@ namespace Nncase.IR.Math;
/// Unary expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Unary(UnaryOp UnaryOp) : Op
public sealed partial class Unary : Op
{
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input");
public UnaryOp UnaryOp { get; }
/// <inheritdoc/>
public override string DisplayProperty()
{

View File

@ -15,7 +15,7 @@ namespace Nncase.IR.NN;
/// <summary>
/// The base class.
/// </summary>
public abstract record ActivationOp : Op
public abstract class ActivationOp : Op
{
}
@ -23,7 +23,7 @@ public abstract record ActivationOp : Op
/// Sigmoid expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Sigmoid() : ActivationOp
public sealed partial class Sigmoid : ActivationOp
{
/// <summary>
/// Gets input.
@ -35,7 +35,7 @@ public sealed record Sigmoid() : ActivationOp
/// Relu expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Relu() : ActivationOp
public sealed partial class Relu : ActivationOp
{
/// <summary>
/// Gets input.
@ -47,7 +47,7 @@ public sealed record Relu() : ActivationOp
/// Relu6 expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Relu6() : ActivationOp
public sealed partial class Relu6 : ActivationOp
{
/// <summary>
/// Gets input.
@ -59,7 +59,7 @@ public sealed record Relu6() : ActivationOp
/// PRelu expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record PRelu() : ActivationOp
public sealed partial class PRelu : ActivationOp
{
/// <summary>
/// Gets input.
@ -76,7 +76,7 @@ public sealed record PRelu() : ActivationOp
/// LeakyRelu expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record LeakyRelu() : ActivationOp
public sealed partial class LeakyRelu : ActivationOp
{
/// <summary>
/// Gets input.
@ -93,7 +93,7 @@ public sealed record LeakyRelu() : ActivationOp
/// Celu expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Celu() : ActivationOp
public sealed partial class Celu : ActivationOp
{
/// <summary>
/// Gets input.
@ -110,7 +110,7 @@ public sealed record Celu() : ActivationOp
/// Selu expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Selu() : ActivationOp
public sealed partial class Selu : ActivationOp
{
/// <summary>
/// Gets input.
@ -132,7 +132,7 @@ public sealed record Selu() : ActivationOp
/// Elu expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Elu() : ActivationOp
public sealed partial class Elu : ActivationOp
{
/// <summary>
/// Gets input.
@ -149,7 +149,7 @@ public sealed record Elu() : ActivationOp
/// HardSwish expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record HardSwish() : ActivationOp
public sealed partial class HardSwish : ActivationOp
{
/// <summary>
/// Gets input.
@ -161,7 +161,7 @@ public sealed record HardSwish() : ActivationOp
/// HardSigmoid expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record HardSigmoid() : ActivationOp
public sealed partial class HardSigmoid : ActivationOp
{
/// <summary>
/// Gets input.
@ -183,7 +183,7 @@ public sealed record HardSigmoid() : ActivationOp
/// Erf expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Erf() : ActivationOp
public sealed partial class Erf : ActivationOp
{
/// <summary>
/// Gets input.

View File

@ -15,7 +15,7 @@ namespace Nncase.IR.NN;
/// BatchToSpace expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record BatchToSpace() : Op
public sealed partial class BatchToSpace : Op
{
/// <summary>
/// Gets input.

View File

@ -10,55 +10,56 @@ using System.Threading.Tasks;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.NN
namespace Nncase.IR.NN;
/// <summary>
/// Conv2D.
/// </summary>
[PatternFunctionalGenerator]
public sealed partial class Conv2D : Op
{
/// <summary>
/// Conv2D.
/// Gets input.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Conv2D(PadMode PadMode) : Op
{
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input");
/// <summary>
/// Gets Weights.
/// </summary>
public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4));
/// <summary>
/// Gets Weights.
/// </summary>
public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4));
/// <summary>
/// Gets Bias.
/// </summary>
public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1));
/// <summary>
/// Gets Bias.
/// </summary>
public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1));
/// <summary>
/// Gets Stride.
/// </summary>
public static readonly ParameterInfo Stride = new(typeof(Conv2D), 3, "stride", HasRank(1) & IsIntegral());
/// <summary>
/// Gets Stride.
/// </summary>
public static readonly ParameterInfo Stride = new(typeof(Conv2D), 3, "stride", HasRank(1) & IsIntegral());
/// <summary>
/// Gets Padding.
/// </summary>
public static readonly ParameterInfo Padding = new(typeof(Conv2D), 4, "padding", HasRank(2) & IsIntegral());
/// <summary>
/// Gets Padding.
/// </summary>
public static readonly ParameterInfo Padding = new(typeof(Conv2D), 4, "padding", HasRank(2) & IsIntegral());
/// <summary>
/// Gets Dilation.
/// </summary>
public static readonly ParameterInfo Dilation = new(typeof(Conv2D), 5, "dilation", HasRank(1) & IsIntegral());
/// <summary>
/// Gets Dilation.
/// </summary>
public static readonly ParameterInfo Dilation = new(typeof(Conv2D), 5, "dilation", HasRank(1) & IsIntegral());
/// <summary>
/// Gets Groups.
/// </summary>
public static readonly ParameterInfo Groups = new(typeof(Conv2D), 6, "groups", IsScalar() & IsIntegral());
/// <summary>
/// Gets Groups.
/// </summary>
public static readonly ParameterInfo Groups = new(typeof(Conv2D), 6, "groups", IsScalar() & IsIntegral());
/// <summary>
/// Gets FusedClamp.
/// </summary>
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 7, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
/// <summary>
/// Gets FusedClamp.
/// </summary>
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 7, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
/// <inheritdoc/>
public override string DisplayProperty() => $"PadMode.{PadMode}";
}
public PadMode PadMode { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"PadMode.{PadMode}";
}

View File

@ -11,65 +11,66 @@ using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.NN
namespace Nncase.IR.NN;
/// <summary>
/// Conv2DTranspose.
/// </summary>
[PatternFunctionalGenerator]
public sealed partial class Conv2DTranspose : Op
{
/// <summary>
/// Conv2DTranspose.
/// Gets input.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Conv2DTranspose(PadMode PadMode) : Op
{
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Conv2DTranspose), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Conv2DTranspose), 0, "input");
/// <summary>
/// Gets Weights.
/// </summary>
public static readonly ParameterInfo Weights = new(typeof(Conv2DTranspose), 1, "weights");
/// <summary>
/// Gets Weights.
/// </summary>
public static readonly ParameterInfo Weights = new(typeof(Conv2DTranspose), 1, "weights");
/// <summary>
/// Gets Bias.
/// </summary>
public static readonly ParameterInfo Bias = new(typeof(Conv2DTranspose), 2, "bias");
/// <summary>
/// Gets Bias.
/// </summary>
public static readonly ParameterInfo Bias = new(typeof(Conv2DTranspose), 2, "bias");
/// <summary>
/// Gets OutputShape.
/// </summary>
public static readonly ParameterInfo OutputShape = new(typeof(Conv2DTranspose), 3, "outputShape");
/// <summary>
/// Gets OutputShape.
/// </summary>
public static readonly ParameterInfo OutputShape = new(typeof(Conv2DTranspose), 3, "outputShape");
/// <summary>
/// Gets Stride.
/// </summary>
public static readonly ParameterInfo Stride = new(typeof(Conv2DTranspose), 4, "stride");
/// <summary>
/// Gets Stride.
/// </summary>
public static readonly ParameterInfo Stride = new(typeof(Conv2DTranspose), 4, "stride");
/// <summary>
/// Gets Padding.
/// </summary>
public static readonly ParameterInfo Padding = new(typeof(Conv2DTranspose), 5, "padding");
/// <summary>
/// Gets Padding.
/// </summary>
public static readonly ParameterInfo Padding = new(typeof(Conv2DTranspose), 5, "padding");
/// <summary>
/// Gets Output Padding.
/// </summary>
public static readonly ParameterInfo OutputPadding = new(typeof(Conv2DTranspose), 6, "output_padding");
/// <summary>
/// Gets Output Padding.
/// </summary>
public static readonly ParameterInfo OutputPadding = new(typeof(Conv2DTranspose), 6, "output_padding");
/// <summary>
/// Gets Dilation.
/// </summary>
public static readonly ParameterInfo Dilation = new(typeof(Conv2DTranspose), 7, "dilation");
/// <summary>
/// Gets Dilation.
/// </summary>
public static readonly ParameterInfo Dilation = new(typeof(Conv2DTranspose), 7, "dilation");
/// <summary>
/// Gets Groups.
/// </summary>
public static readonly ParameterInfo Groups = new(typeof(Conv2DTranspose), 8, "groups");
/// <summary>
/// Gets Groups.
/// </summary>
public static readonly ParameterInfo Groups = new(typeof(Conv2DTranspose), 8, "groups");
/// <summary>
/// Gets FusedClamp.
/// </summary>
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 9, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
/// <summary>
/// Gets FusedClamp.
/// </summary>
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 9, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
/// <inheritdoc/>
public override string DisplayProperty() => $"PadMode.{PadMode}";
}
public PadMode PadMode { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"PadMode.{PadMode}";
}

View File

@ -16,7 +16,7 @@ namespace Nncase.IR.NN;
/// Hardmax expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record Hardmax() : Op
public sealed partial class Hardmax : Op
{
/// <summary>
/// Gets input.

View File

@ -1,64 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
#if false
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.NN
{
public enum lstm_direction
{
kForward,
kReverse,
kBidirectional
}
[PatternFunctionalGenerator]
public sealed record LSTM(lstm_direction direction, String framework) : Op
{
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(LSTM), 0, "input", HasDataType(DataTypes.Float32));
/// <summary>
/// Gets w.
/// </summary>
public static readonly ParameterInfo W = new(typeof(LSTM), 1, "w");
/// <summary>
/// Gets r.
/// </summary>
public static readonly ParameterInfo R = new(typeof(LSTM), 2, "r");
/// <summary>
/// Gets b.
/// </summary>
public static readonly ParameterInfo B = new(typeof(LSTM), 3, "b");
/// <summary>
/// Gets initial_h.
/// </summary>
public static readonly ParameterInfo initial_h = new(typeof(LSTM), 4, "initial_h",
HasDataType(DataTypes.Float32));
/// <summary>
/// Gets initial_c.
/// </summary>
public static readonly ParameterInfo initial_c = new(typeof(LSTM), 5, "initial_c",
HasDataType(DataTypes.Float32));
/// <summary>
/// Gets has_static.
/// </summary>
public static readonly ParameterInfo has_static = new(typeof(LSTM), 6, "has_static", IsBool());
}
}
#endif

View File

@ -16,7 +16,7 @@ namespace Nncase.IR.NN;
/// Hardmax expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record LayerNorm(int Axis, float Epsilon) : Op
public sealed partial class LayerNorm : Op
{
/// <summary>
/// Gets input.
@ -32,4 +32,8 @@ public sealed record LayerNorm(int Axis, float Epsilon) : Op
/// Gets bias.
/// </summary>
public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias");
public int Axis { get; }
public float Epsilon { get; }
}

View File

@ -8,7 +8,7 @@ using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.NN;
[PatternFunctionalGenerator]
public sealed record L2Normalization() : Op
public sealed partial class L2Normalization : Op
{
/// <summary>
/// Gets input.
@ -17,7 +17,7 @@ public sealed record L2Normalization() : Op
}
[PatternFunctionalGenerator]
public sealed record BatchNormalization() : Op
public sealed partial class BatchNormalization : Op
{
/// <summary>
/// Gets input.
@ -56,7 +56,7 @@ public sealed record BatchNormalization() : Op
}
[PatternFunctionalGenerator]
public sealed record InstanceNormalization() : Op
public sealed partial class InstanceNormalization : Op
{
/// <summary>
/// Gets input.
@ -80,7 +80,7 @@ public sealed record InstanceNormalization() : Op
}
[PatternFunctionalGenerator]
public sealed record LpNormalization() : Op
public sealed partial class LpNormalization : Op
{
/// <summary>
/// Gets input.
@ -99,7 +99,7 @@ public sealed record LpNormalization() : Op
}
[PatternFunctionalGenerator]
public sealed record LRN() : Op
public sealed partial class LRN : Op
{
/// <summary>
/// Gets input.

View File

@ -15,7 +15,7 @@ namespace Nncase.IR.NN;
/// OneHot expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed record OneHot(OneHotMode OneHotMode) : Op
public sealed partial class OneHot : Op
{
/// <summary>
/// Gets input.
@ -37,6 +37,8 @@ public sealed record OneHot(OneHotMode OneHotMode) : Op
/// </summary>
public static readonly ParameterInfo Axis = new(typeof(OneHot), 3, "axis");
public OneHotMode OneHotMode { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"OneHotMode.{OneHotMode}";
}

8
src/Nncase.Core/IR/NN/Pad.cs Executable file → Normal file
View File

@ -9,9 +9,8 @@ namespace Nncase.IR.NN;
/// <summary>
/// Pad tensor, a little difference with pytorch pad.
/// </summary>
/// <param name="PadMode">Pad mode.</param>
[PatternFunctionalGenerator]
public sealed record Pad(PadMode PadMode) : Op
public sealed partial class Pad : Op
{
/// <summary>
/// input.
@ -28,6 +27,11 @@ public sealed record Pad(PadMode PadMode) : Op
/// </summary>
public static readonly ParameterInfo Value = new(typeof(Pad), 2, "value", IsScalar());
/// <summary>
/// Gets pad mode.
/// </summary>
public PadMode PadMode { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"PadMode.{PadMode}";
}

Some files were not shown because too many files have changed in this diff Show More