mirror of https://github.com/kendryte/nncase.git
Refactor usage of IoC (#767)
* Use DryIoC * Fix all usage of ioc * Fix merged tests * Use MSBuildThisFileDirectory instead of SolutionDir * Fix pytest * Fix build * Fix build * Fix os.makedirs with exist_ok=True * Fix merge * Apply code-format changes * Add dump test * Fix dump test * Resolve review & Add more dump tests Co-authored-by: sunnycase <sunnycase@users.noreply.github.com>pull/774/head
parent
8f5fe4d477
commit
4b93ce843a
|
@ -533,7 +533,7 @@ dotnet_diagnostic.MA0097.severity = warning # MA0097: A class that implements
|
|||
dotnet_diagnostic.MA0098.severity = suggestion # MA0098: Use indexer instead of extension methods
|
||||
dotnet_diagnostic.MA0099.severity = warning # MA0099: Use Explicit enum value instead of 0
|
||||
MA0053.public_class_should_be_sealed = false
|
||||
end_of_line = lf
|
||||
end_of_line = crlf
|
||||
tab_width = 4
|
||||
dotnet_style_allow_multiple_blank_lines_experimental = true:silent
|
||||
dotnet_style_allow_statement_immediately_after_block_experimental = true:silent
|
||||
|
|
|
@ -1,63 +1,62 @@
|
|||
###############################################################################
|
||||
# Set default behavior to automatically normalize line endings.
|
||||
###############################################################################
|
||||
* text=auto
|
||||
|
||||
###############################################################################
|
||||
# Set default behavior for command prompt diff.
|
||||
#
|
||||
# This is need for earlier builds of msysgit that does not have it on by
|
||||
# default for csharp files.
|
||||
# Note: This is only used by command line
|
||||
###############################################################################
|
||||
#*.cs diff=csharp
|
||||
*.doc diff=astextplain
|
||||
*.DOC diff=astextplain
|
||||
*.docx diff=astextplain
|
||||
*.DOCX diff=astextplain
|
||||
*.dot diff=astextplain
|
||||
*.DOT diff=astextplain
|
||||
*.pdf diff=astextplain
|
||||
*.PDF diff=astextplain
|
||||
*.rtf diff=astextplain
|
||||
*.RTF diff=astextplain
|
||||
|
||||
###############################################################################
|
||||
# Set the merge driver for project and solution files
|
||||
#
|
||||
# Merging from the command prompt will add diff markers to the files if there
|
||||
# are conflicts (Merging from VS is not affected by the settings below, in VS
|
||||
# the diff markers are never inserted). Diff markers may cause the following
|
||||
# file extensions to fail to load in VS. An alternative would be to treat
|
||||
# these files as binary and thus will always conflict and require user
|
||||
# intervention with every merge. To do so, just uncomment the entries below
|
||||
###############################################################################
|
||||
#*.sln merge=binary
|
||||
#*.csproj merge=binary
|
||||
#*.vbproj merge=binary
|
||||
#*.vcxproj merge=binary
|
||||
#*.vcproj merge=binary
|
||||
#*.dbproj merge=binary
|
||||
#*.fsproj merge=binary
|
||||
#*.lsproj merge=binary
|
||||
#*.wixproj merge=binary
|
||||
#*.modelproj merge=binary
|
||||
#*.sqlproj merge=binary
|
||||
#*.wwaproj merge=binary
|
||||
*.jpg binary
|
||||
*.png binary
|
||||
*.gif binary
|
||||
|
||||
###############################################################################
|
||||
# behavior for image files
|
||||
#
|
||||
# image files are treated as binary by default.
|
||||
###############################################################################
|
||||
#*.jpg binary
|
||||
#*.png binary
|
||||
#*.gif binary
|
||||
# Force bash scripts to always use lf line endings so that if a repo is accessed
|
||||
# in Unix via a file share from Windows, the scripts will work.
|
||||
*.in text eol=lf
|
||||
*.sh text eol=lf
|
||||
|
||||
###############################################################################
|
||||
# diff behavior for common document formats
|
||||
#
|
||||
# Convert binary document formats to text before diffing them. This feature
|
||||
# is only available from the command line. Turn it on by uncommenting the
|
||||
# entries below.
|
||||
###############################################################################
|
||||
#*.doc diff=astextplain
|
||||
#*.DOC diff=astextplain
|
||||
#*.docx diff=astextplain
|
||||
#*.DOCX diff=astextplain
|
||||
#*.dot diff=astextplain
|
||||
#*.DOT diff=astextplain
|
||||
#*.pdf diff=astextplain
|
||||
#*.PDF diff=astextplain
|
||||
#*.rtf diff=astextplain
|
||||
#*.RTF diff=astextplain
|
||||
# Likewise, force cmd and batch scripts to always use crlf
|
||||
*.cmd text eol=crlf
|
||||
*.bat text eol=crlf
|
||||
|
||||
*.cs text=auto diff=csharp
|
||||
*.vb text=auto
|
||||
*.resx text=auto
|
||||
*.c text=auto
|
||||
*.cpp text=auto
|
||||
*.cxx text=auto
|
||||
*.h text=auto
|
||||
*.hxx text=auto
|
||||
*.py text=auto
|
||||
*.rb text=auto
|
||||
*.java text=auto
|
||||
*.html text=auto
|
||||
*.htm text=auto
|
||||
*.css text=auto
|
||||
*.scss text=auto
|
||||
*.sass text=auto
|
||||
*.less text=auto
|
||||
*.js text=auto
|
||||
*.lisp text=auto
|
||||
*.clj text=auto
|
||||
*.sql text=auto
|
||||
*.php text=auto
|
||||
*.lua text=auto
|
||||
*.m text=auto
|
||||
*.asm text=auto
|
||||
*.erl text=auto
|
||||
*.fs text=auto
|
||||
*.fsx text=auto
|
||||
*.hs text=auto
|
||||
|
||||
*.csproj text=auto
|
||||
*.vbproj text=auto
|
||||
*.fsproj text=auto
|
||||
*.dbproj text=auto
|
||||
*.sln text=auto eol=crlf
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
<Project>
|
||||
<PropertyGroup>
|
||||
<ManagePackageVersionsCentrally>true</ManagePackageVersionsCentrally>
|
||||
<CodeAnalysisRuleSet>$(SolutionDir)/tools/StyleCopAnalyzers.ruleset</CodeAnalysisRuleSet>
|
||||
<CodeAnalysisRuleSet>$(MSBuildThisFileDirectory)/tools/StyleCopAnalyzers.ruleset</CodeAnalysisRuleSet>
|
||||
<Nullable>enable</Nullable>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<PackageVersion Include="AnyTensorFlow.NET" Version="0.70.1" />
|
||||
<PackageVersion Include="Autofac" Version="6.3.0" />
|
||||
<PackageVersion Include="Autofac.Extensions.DependencyInjection" Version="7.2.0" />
|
||||
<PackageVersion Include="BitFields" Version="0.1.0" />
|
||||
<PackageVersion Include="coverlet.collector" Version="3.0.2" />
|
||||
<PackageVersion Include="DryIoc.dll" Version="5.3.1" />
|
||||
<PackageVersion Include="DryIoc.Microsoft.DependencyInjection" Version="6.1.0" />
|
||||
<PackageVersion Include="Extension.Mathematics" Version="1.2.12" />
|
||||
<PackageVersion Include="Fody">
|
||||
<Version>6.6.4</Version>
|
||||
</PackageVersion>
|
||||
<PackageVersion Include="GiGraph.Dot" Version="2.0.0" />
|
||||
<PackageVersion Include="Google.Protobuf" Version="3.19.1" />
|
||||
<PackageVersion Include="Grpc.Tools" Version="2.42.0" />
|
||||
|
@ -25,10 +28,17 @@
|
|||
<!-- <PackageVersion Include="libtorch-cuda-11.3-win-x64" Version="1.11.0.1" /> -->
|
||||
<PackageVersion Include="MagicalTensorflowLib" Version="0.0.2" />
|
||||
<PackageVersion Include="MagicalTensorflowLibOSX-ARM64" Version="0.0.4" />
|
||||
<PackageVersion Include="MethodBoundaryAspect.Fody">
|
||||
<Version>2.0.148</Version>
|
||||
</PackageVersion>
|
||||
<PackageVersion Include="MethodDecorator.Fody">
|
||||
<Version>1.1.1</Version>
|
||||
</PackageVersion>
|
||||
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" />
|
||||
<PackageVersion Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.3" />
|
||||
<PackageVersion Include="Microsoft.Extensions.Hosting" Version="6.0.0" />
|
||||
<PackageVersion Include="Microsoft.Extensions.Hosting.Abstractions" Version="6.0.0" />
|
||||
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="6.0.0" />
|
||||
<PackageVersion Include="Microsoft.Extensions.Options" Version="6.0.0" />
|
||||
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="16.9.4" />
|
||||
<PackageVersion Include="Microsoft.Toolkit.HighPerformance" Version="7.1.1" />
|
||||
|
@ -44,6 +54,9 @@
|
|||
<PackageVersion Include="TorchSharp" Version="0.99.0" />
|
||||
<PackageVersion Include="xunit" Version="2.4.1" />
|
||||
<PackageVersion Include="xunit.assert" Version="2.4.1" />
|
||||
<PackageVersion Include="Xunit.Combinatorial">
|
||||
<Version>1.5.25</Version>
|
||||
</PackageVersion>
|
||||
<PackageVersion Include="Xunit.DependencyInjection" Version="8.3.0" />
|
||||
<PackageVersion Include="xunit.runner.visualstudio" Version="2.4.3" />
|
||||
</ItemGroup>
|
||||
|
@ -54,6 +67,6 @@
|
|||
</PackageReference>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<AdditionalFiles Include="$(SolutionDir)/tools/stylecop.json" />
|
||||
<AdditionalFiles Include="$(MSBuildThisFileDirectory)/tools/stylecop.json" />
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -7,6 +7,8 @@ using System.Linq;
|
|||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
|
@ -18,11 +20,10 @@ public static class K210ApplicationPart
|
|||
/// <summary>
|
||||
/// Add k210 assembly.
|
||||
/// </summary>
|
||||
/// <param name="assemblies">Assembly collection.</param>
|
||||
/// <returns>Updated assembly collection.</returns>
|
||||
public static IList<Assembly> AddK210(this IList<Assembly> assemblies)
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
/// <returns>Configured service registrator.</returns>
|
||||
public static IRegistrator AddK210(this IRegistrator registrator)
|
||||
{
|
||||
assemblies.Add(typeof(K210ApplicationPart).Assembly);
|
||||
return assemblies;
|
||||
return registrator.RegisterModule<K210Module>();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
using Nncase.Evaluator.K210;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.Targets;
|
||||
|
||||
namespace Nncase;
|
||||
|
@ -10,14 +11,13 @@ namespace Nncase;
|
|||
/// <summary>
|
||||
/// K210 module.
|
||||
/// </summary>
|
||||
public class K210Module : Module
|
||||
internal sealed class K210Module : IApplicationPart
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override void Load(ContainerBuilder builder)
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
builder.RegisterType<K210Target>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<FakeKPUConv2DEvaluator>().AsImplementedInterfaces();
|
||||
builder.RegisterType<FakeKPUDownloadEvaluator>().AsImplementedInterfaces();
|
||||
builder.RegisterType<FakeKPUUploadEvaluator>().AsImplementedInterfaces();
|
||||
registrator.Register<ITarget, K210Target>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<FakeKPUConv2DEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<FakeKPUDownloadEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<FakeKPUUploadEvaluator>(reuse: Reuse.Singleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="$(SolutionDir)/tools/Nncase.SourceGenerator/Nncase.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
<ProjectReference Include="../../tools/Nncase.SourceGenerator/Nncase.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
@ -32,36 +32,37 @@ public class K210Target : ITarget
|
|||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void RegisterTargetDependentPass(PassManager passManager, CompileOptions options)
|
||||
public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options)
|
||||
{
|
||||
if (options.ModelQuantMode == ModelQuantMode.UsePTQ)
|
||||
if (options.QuantizeOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
|
||||
{
|
||||
passManager.Add(new EGraphPassWithQuantize("lowering_kpu", options.QuantizeOptions!)
|
||||
passManager.Add<EGraphPassWithQuantize>().Configure(p =>
|
||||
{
|
||||
new LowerConv2D(),
|
||||
p.Name = "lowering_kpu";
|
||||
p.Add<LowerConv2D>();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions)
|
||||
public Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions)
|
||||
{
|
||||
return null;
|
||||
return Task.FromResult(new Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>());
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions)
|
||||
public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions)
|
||||
{
|
||||
return null;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void RegisterQuantizePass(PassManager passManager, CompileOptions options)
|
||||
public void RegisterQuantizePass(IPassManager passManager, CompileOptions options)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void RegisterTargetDependentAfterQuantPass(PassManager passManager, CompileOptions options)
|
||||
public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options)
|
||||
{
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="$(SolutionDir)/tools/Nncase.SourceGenerator/Nncase.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
<ProjectReference Include="../../tools/Nncase.SourceGenerator/Nncase.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
@ -7,6 +7,8 @@ using System.Linq;
|
|||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
|
@ -18,11 +20,10 @@ public static class StackVMApplicationPart
|
|||
/// <summary>
|
||||
/// Add stackVM assembly.
|
||||
/// </summary>
|
||||
/// <param name="assemblies">Assembly collection.</param>
|
||||
/// <returns>Updated assembly collection.</returns>
|
||||
public static IList<Assembly> AddStackVM(this IList<Assembly> assemblies)
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
/// <returns>Configured service registrator.</returns>
|
||||
public static IRegistrator AddStackVM(this IRegistrator registrator)
|
||||
{
|
||||
assemblies.Add(typeof(StackVMApplicationPart).Assembly);
|
||||
return assemblies;
|
||||
return registrator.RegisterModule<StackVMModule>();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.Targets;
|
||||
|
||||
namespace Nncase;
|
||||
|
@ -9,11 +10,10 @@ namespace Nncase;
|
|||
/// <summary>
|
||||
/// StackVM module.
|
||||
/// </summary>
|
||||
public class StackVMModule : Module
|
||||
internal class StackVMModule : IApplicationPart
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override void Load(ContainerBuilder builder)
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
builder.RegisterType<CPUTarget>().AsImplementedInterfaces().SingleInstance();
|
||||
registrator.Register<ITarget, CPUTarget>(reuse: Reuse.Singleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,8 +21,12 @@ namespace Nncase.Targets;
|
|||
/// </summary>
|
||||
public class CPUTarget : ITarget
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
public string Kind => "cpu";
|
||||
/// <summary>
|
||||
/// Gets kind.
|
||||
/// </summary>
|
||||
public static readonly string Kind = "cpu";
|
||||
|
||||
string ITarget.Kind => Kind;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void ParseTargetDependentOptions(IConfigurationSection configure)
|
||||
|
@ -30,30 +34,30 @@ public class CPUTarget : ITarget
|
|||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void RegisterTargetDependentPass(PassManager passManager, CompileOptions options)
|
||||
public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions)
|
||||
public Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions)
|
||||
{
|
||||
var enodeQuantCosineDict = new Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>();
|
||||
return Task.FromResult(enodeQuantCosineDict);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions)
|
||||
public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions)
|
||||
{
|
||||
return null;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void RegisterQuantizePass(PassManager passManager, CompileOptions options)
|
||||
public void RegisterQuantizePass(IPassManager passManager, CompileOptions options)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void RegisterTargetDependentAfterQuantPass(PassManager passManager, CompileOptions options)
|
||||
public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options)
|
||||
{
|
||||
}
|
||||
|
||||
|
|
24
nncase.sln
24
nncase.sln
|
@ -67,14 +67,17 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Modules.K210", "modu
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Quantization", "src\Nncase.Quantization\Nncase.Quantization.csproj", "{317C0D8F-75B3-4248-83E8-17AADDCF247A}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.TestFixture", "src\Nncase.TestFixture\Nncase.TestFixture.csproj", "{FF67E6C1-205B-49D6-BA83-2DC536F0D414}"
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{EF1F8779-2B98-4E6F-A3DC-CA3FD2CADAD8}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
.editorconfig = .editorconfig
|
||||
Directory.Packages.props = Directory.Packages.props
|
||||
NuGet.Config = NuGet.Config
|
||||
EndProjectSection
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Diagnostics", "src\Nncase.Diagnostics\Nncase.Diagnostics.csproj", "{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Tests.TestFixture", "src\Nncase.Tests.TestFixture\Nncase.Tests.TestFixture.csproj", "{98A03405-CA53-4EC4-9B18-94D1C8DF9453}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
|
@ -165,10 +168,14 @@ Global
|
|||
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{FF67E6C1-205B-49D6-BA83-2DC536F0D414}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{FF67E6C1-205B-49D6-BA83-2DC536F0D414}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{FF67E6C1-205B-49D6-BA83-2DC536F0D414}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{FF67E6C1-205B-49D6-BA83-2DC536F0D414}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
|
@ -198,9 +205,10 @@ Global
|
|||
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66}
|
||||
{7390DF41-E804-4F12-B441-6EB5119C81BF} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66}
|
||||
{317C0D8F-75B3-4248-83E8-17AADDCF247A} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{FF67E6C1-205B-49D6-BA83-2DC536F0D414} = {E5A4516C-4080-4346-991D-57A7AA76ADA6}
|
||||
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
|
||||
{98A03405-CA53-4EC4-9B18-94D1C8DF9453} = {E5A4516C-4080-4346-991D-57A7AA76ADA6}
|
||||
EndGlobalSection
|
||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||
SolutionGuid = {B251C48E-7D05-4662-8DFF-20843CB2DB99}
|
||||
SolutionGuid = {9492E141-292E-4D60-9C6E-3738AB234DB2}
|
||||
EndGlobalSection
|
||||
EndGlobal
|
||||
|
|
|
@ -97,7 +97,7 @@ class GraphEvaluator:
|
|||
return self._outputs[index]
|
||||
|
||||
def run(self):
|
||||
self._outputs = self._func.body.evaluate(self._inputs, self._params).to_runtime_tensors()
|
||||
self._outputs = self._func.body.evaluate(self._params, self._inputs).to_runtime_tensors()
|
||||
|
||||
@ property
|
||||
def outputs_size(self) -> int:
|
||||
|
@ -117,18 +117,19 @@ class IRModule():
|
|||
|
||||
|
||||
class Compiler:
|
||||
_target: _nncase.Target
|
||||
_session: _nncase.CompileSession
|
||||
_compiler: _nncase.Compiler
|
||||
_compile_options: _nncase.CompileOptions
|
||||
_quantize_options: _nncase.QuantizeOptions
|
||||
_module: IRModule
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, compile_options: CompileOptions) -> None:
|
||||
self._compile_options = _nncase.CompileOptions()
|
||||
self._compiler = _nncase.Compiler(self._compile_options)
|
||||
self._quantize_options = None
|
||||
|
||||
def set_compile_options(self, compile_options: CompileOptions):
|
||||
self.__process_compile_options(compile_options)
|
||||
self._session = _nncase.CompileSession(self._target, self._compile_options)
|
||||
self._compiler = self._session.compiler
|
||||
self._quantize_options = None
|
||||
|
||||
def compile(self) -> None:
|
||||
self._compiler.compile()
|
||||
|
@ -168,14 +169,14 @@ class Compiler:
|
|||
self._quantize_options = _nncase.QuantizeOptions()
|
||||
self._compile_options.quantize_options = self._quantize_options
|
||||
self._quantize_options.calibration_dataset = provider
|
||||
self._compile_options.model_quant_mode = _nncase.ModelQuantMode.UsePTQ
|
||||
self._quantize_options.model_quant_mode = _nncase.ModelQuantMode.UsePTQ
|
||||
|
||||
def dump_range_options(self) -> DumpRangeTensorOptions:
|
||||
raise NotImplementedError("dump_range_options")
|
||||
|
||||
def __process_compile_options(self, compile_options: CompileOptions) -> ClCompileOptions:
|
||||
self._compile_options.target = compile_options.target
|
||||
self._compile_options.dump_level = 3 if compile_options.dump_ir == True else 0
|
||||
self._target = _nncase.Target(compile_options.target)
|
||||
self._compile_options.dump_flags = _nncase.DumpFlags.Nothing
|
||||
self._compile_options.dump_dir = compile_options.dump_dir
|
||||
|
||||
def _import_module(self, model_content: bytes | io.RawIOBase) -> None:
|
||||
|
@ -188,7 +189,7 @@ def check_target(target: str):
|
|||
return target in ["cpu", "k510", "k230"]
|
||||
|
||||
def target_exists(target: str):
|
||||
return _nncase.target_exists(target)
|
||||
return _nncase.Target.exists(target)
|
||||
|
||||
return test_target(target) and target_exists(target)
|
||||
|
||||
|
|
|
@ -48,11 +48,7 @@ PYBIND11_MODULE(_nncase, m) {
|
|||
true, std::memory_order_release);
|
||||
}));
|
||||
m.def("initialize", nncase_clr_initialize);
|
||||
m.def("launch_debugger", nncase_clr_launch_debugger);
|
||||
m.def("target_exists", [](std::string_view target_name) {
|
||||
return nncase_clr_target_exists(target_name.data(),
|
||||
target_name.length());
|
||||
});
|
||||
m.def("launch_debugger", []() { nncase_clr_api()->luanch_debugger(); });
|
||||
|
||||
#include "runtime_tensor.inl"
|
||||
|
||||
|
@ -61,28 +57,29 @@ PYBIND11_MODULE(_nncase, m) {
|
|||
.value("UsePTQ", nncase_mqm_use_ptq)
|
||||
.value("UseQAT", nncase_mqm_use_qat);
|
||||
|
||||
py::enum_<nncase_dump_flags_t>(m, "DumpFlags")
|
||||
.value("Nothing", nncase_dump_flags_none);
|
||||
|
||||
py::class_<compile_options>(m, "CompileOptions")
|
||||
.def(py::init())
|
||||
.def_property(
|
||||
"input_format", py::overload_cast<>(&compile_options::input_format),
|
||||
py::overload_cast<std::string_view>(&compile_options::input_format))
|
||||
.def_property(
|
||||
"target", py::overload_cast<>(&compile_options::target),
|
||||
py::overload_cast<std::string_view>(&compile_options::target))
|
||||
.def_property("dump_level",
|
||||
py::overload_cast<>(&compile_options::dump_level),
|
||||
py::overload_cast<int32_t>(&compile_options::dump_level))
|
||||
.def_property(
|
||||
"dump_dir", py::overload_cast<>(&compile_options::dump_dir),
|
||||
py::overload_cast<std::string_view>(&compile_options::dump_dir))
|
||||
.def_property("dump_flags",
|
||||
py::overload_cast<>(&compile_options::dump_flags),
|
||||
py::overload_cast<nncase_dump_flags_t>(
|
||||
&compile_options::dump_flags))
|
||||
.def_property("quantize_options",
|
||||
py::overload_cast<>(&compile_options::quantize_options),
|
||||
py::overload_cast<const quantize_options &>(
|
||||
&compile_options::quantize_options))
|
||||
.def_property("model_quant_mode",
|
||||
py::overload_cast<>(&compile_options::model_quant_mode),
|
||||
py::overload_cast<nncase_model_quant_mode_t>(
|
||||
&compile_options::model_quant_mode));
|
||||
&compile_options::quantize_options));
|
||||
|
||||
py::class_<target>(m, "Target")
|
||||
.def(py::init<std::string_view>())
|
||||
.def_static("exists", &target::exists);
|
||||
|
||||
py::class_<quantize_options>(m, "QuantizeOptions")
|
||||
.def(py::init())
|
||||
|
@ -90,7 +87,11 @@ PYBIND11_MODULE(_nncase, m) {
|
|||
"calibration_dataset",
|
||||
py::overload_cast<>(&quantize_options::calibration_dataset),
|
||||
py::overload_cast<const calibration_dataset_provider &>(
|
||||
&quantize_options::calibration_dataset));
|
||||
&quantize_options::calibration_dataset))
|
||||
.def_property("model_quant_mode",
|
||||
py::overload_cast<>(&quantize_options::model_quant_mode),
|
||||
py::overload_cast<nncase_model_quant_mode_t>(
|
||||
&quantize_options::model_quant_mode));
|
||||
|
||||
py::class_<calibration_dataset_provider>(m, "CalibrationDatasetProvider")
|
||||
.def(py::init([](py::list dataset, size_t samples_count,
|
||||
|
@ -138,21 +139,21 @@ PYBIND11_MODULE(_nncase, m) {
|
|||
}
|
||||
});
|
||||
|
||||
py::class_<expr>(m, "Expr").def("evaluate", [](expr &expr, py::list inputs,
|
||||
py::list params) {
|
||||
std::vector<clr_object_handle_t> input_handles(inputs.size());
|
||||
py::class_<expr>(m, "Expr").def("evaluate", [](expr &expr, py::list params,
|
||||
py::list inputs) {
|
||||
std::vector<clr_object_handle_t> param_handles(params.size());
|
||||
for (size_t i = 0; i < input_handles.size(); i++) {
|
||||
input_handles[i] = inputs[i].cast<rtvalue &>().get();
|
||||
}
|
||||
std::vector<clr_object_handle_t> input_handles(inputs.size());
|
||||
for (size_t i = 0; i < param_handles.size(); i++) {
|
||||
param_handles[i] = params[i].cast<var &>().get();
|
||||
}
|
||||
for (size_t i = 0; i < input_handles.size(); i++) {
|
||||
input_handles[i] = inputs[i].cast<rtvalue &>().get();
|
||||
}
|
||||
|
||||
array params_arr(nncase_array_var, param_handles.data(), inputs.size());
|
||||
array inputs_arr(nncase_array_rtvalue, input_handles.data(),
|
||||
inputs.size());
|
||||
array params_arr(nncase_array_var, param_handles.data(), inputs.size());
|
||||
return expr.evaluate(inputs_arr, params_arr);
|
||||
return expr.evaluate(params_arr, inputs_arr);
|
||||
});
|
||||
|
||||
py::class_<var, expr>(m, "Var");
|
||||
|
@ -167,11 +168,14 @@ PYBIND11_MODULE(_nncase, m) {
|
|||
.def_property_readonly("entry", &ir_module::entry);
|
||||
|
||||
py::class_<compiler>(m, "Compiler")
|
||||
.def(py::init<const compile_options &>())
|
||||
.def("import_module", &compiler::import_module)
|
||||
.def("compile", &compiler::compile)
|
||||
.def("gencode", &compiler::gencode);
|
||||
|
||||
py::class_<compile_session>(m, "CompileSession")
|
||||
.def(py::init<const target &, const compile_options &>())
|
||||
.def_property_readonly("compiler", &compile_session::compiler);
|
||||
|
||||
py::class_<interpreter>(m, "Simulator")
|
||||
.def(py::init())
|
||||
.def("load_model",
|
||||
|
|
|
@ -40,7 +40,8 @@ struct nncase_buffer_slice {
|
|||
uint32_t size_bytes;
|
||||
};
|
||||
|
||||
NNCASE_API int nncase_object_free(nncase::object_node *node);
|
||||
NNCASE_API int nncase_object_add_ref(nncase::object_node *node);
|
||||
NNCASE_API int nncase_object_release(nncase::object_node *node);
|
||||
|
||||
NNCASE_API int nncase_interp_create(nncase::runtime::interpreter **interp);
|
||||
NNCASE_API int nncase_interp_free(nncase::runtime::interpreter *interp);
|
||||
|
|
|
@ -43,6 +43,8 @@ typedef enum {
|
|||
nncase_calib_kld = 1
|
||||
} nncase_calib_method_t;
|
||||
|
||||
typedef enum { nncase_dump_flags_none = 0 } nncase_dump_flags_t;
|
||||
|
||||
typedef struct {
|
||||
void (*add_ref)(nncase_stream_handle_t handle);
|
||||
void (*release)(nncase_stream_handle_t handle);
|
||||
|
@ -60,101 +62,77 @@ typedef struct {
|
|||
size_t length);
|
||||
} nncase_stream_mt_t;
|
||||
|
||||
typedef struct {
|
||||
clr_object_handle_t (*array_create)(nncase_array_element_kind_t kind,
|
||||
const clr_object_handle_t *elements,
|
||||
size_t count);
|
||||
clr_object_handle_t (*array_get_item)(clr_object_handle_t array,
|
||||
size_t index);
|
||||
size_t (*array_get_length)(clr_object_handle_t array);
|
||||
clr_object_handle_t (*calibration_dataset_provider_create)(
|
||||
clr_object_handle_t dataset, size_t samplesCount,
|
||||
clr_object_handle_t fn_params);
|
||||
void (*handle_free)(clr_object_handle_t handle);
|
||||
clr_object_handle_t (*compile_options_create)();
|
||||
void (*compile_options_set_input_file)(clr_object_handle_t compile_options,
|
||||
const char *input_file,
|
||||
size_t input_file_length);
|
||||
void (*compile_options_set_input_format)(
|
||||
clr_object_handle_t compile_options, const char *input_format,
|
||||
size_t input_format_length);
|
||||
void (*compile_options_set_dump_dir)(clr_object_handle_t compile_options,
|
||||
const char *dump_dir,
|
||||
size_t dump_dir_length);
|
||||
void (*compile_options_set_dump_flags)(clr_object_handle_t compile_options,
|
||||
nncase_dump_flags_t dump_flags);
|
||||
void (*compile_options_set_quantize_options)(
|
||||
clr_object_handle_t compile_options,
|
||||
clr_object_handle_t quantize_options);
|
||||
clr_object_handle_t (*compile_session_create)(
|
||||
clr_object_handle_t target, clr_object_handle_t compile_options);
|
||||
clr_object_handle_t (*compile_session_get_compiler)(
|
||||
clr_object_handle_t compile_session);
|
||||
void (*compiler_initialize)();
|
||||
clr_object_handle_t (*compiler_import_module)(clr_object_handle_t compiler,
|
||||
clr_object_handle_t stream);
|
||||
void (*compiler_compile)(clr_object_handle_t compiler);
|
||||
void (*compiler_gencode)(clr_object_handle_t compiler,
|
||||
clr_object_handle_t stream);
|
||||
clr_object_handle_t (*datatype_from_typecode)(nncase::typecode_t typecode);
|
||||
clr_object_handle_t (*expr_evaluate)(clr_object_handle_t expr,
|
||||
clr_object_handle_t parameters,
|
||||
clr_object_handle_t inputs);
|
||||
clr_object_handle_t (*function_get_body)(clr_object_handle_t function);
|
||||
clr_object_handle_t (*function_get_parameters)(
|
||||
clr_object_handle_t function);
|
||||
clr_object_handle_t (*ir_module_get_entry)(clr_object_handle_t module);
|
||||
void (*luanch_debugger)();
|
||||
clr_object_handle_t (*quantize_options_create)();
|
||||
void (*quantize_options_set_calibration_dataset)(
|
||||
clr_object_handle_t quantize_options, clr_object_handle_t dataset);
|
||||
void (*quantize_options_set_calibration_method)(
|
||||
clr_object_handle_t quantize_options, nncase_calib_method_t method);
|
||||
void (*quantize_options_set_model_quant_mode)(
|
||||
clr_object_handle_t quantize_options,
|
||||
nncase_model_quant_mode_t model_quant_mode);
|
||||
void (*quantize_options_set_quant_type)(
|
||||
clr_object_handle_t quantize_options, clr_object_handle_t quant_type);
|
||||
clr_object_handle_t (*rtvalue_from_handle)(nncase::value_node *value);
|
||||
nncase::value_node *(*rtvalue_get_handle)(clr_object_handle_t rtvalue);
|
||||
clr_object_handle_t (*stream_create)(const nncase_stream_mt_t *mt,
|
||||
void *handle);
|
||||
clr_object_handle_t (*target_create)(const char *target_name,
|
||||
size_t target_name_length);
|
||||
bool (*target_exists)(const char *target_name, size_t target_name_length);
|
||||
} nncase_api_mt_t;
|
||||
|
||||
NNCASE_API nncase_api_mt_t *nncase_clr_api();
|
||||
NNCASE_API int nncase_clr_initialize(const char *root_assembly_path);
|
||||
NNCASE_API int nncase_clr_uninitialize();
|
||||
|
||||
NNCASE_API int nncase_clr_array_create(nncase_array_element_kind_t kind,
|
||||
const clr_object_handle_t *elements,
|
||||
size_t count,
|
||||
clr_object_handle_t *array);
|
||||
NNCASE_API int nncase_clr_array_get_item(clr_object_handle_t array,
|
||||
size_t index,
|
||||
clr_object_handle_t *item);
|
||||
NNCASE_API int nncase_clr_array_get_length(clr_object_handle_t array,
|
||||
size_t *length);
|
||||
|
||||
NNCASE_API int nncase_clr_calibration_dataset_provider_create(
|
||||
clr_object_handle_t dataset, size_t samplesCount,
|
||||
clr_object_handle_t fn_params, clr_object_handle_t *provider);
|
||||
|
||||
NNCASE_API int nncase_clr_handle_free(clr_object_handle_t handle);
|
||||
|
||||
NNCASE_API int
|
||||
nncase_clr_compile_options_create(clr_object_handle_t *compile_options);
|
||||
NNCASE_API int
|
||||
nncase_clr_compile_options_set_input_file(clr_object_handle_t compile_options,
|
||||
const char *input_file,
|
||||
size_t input_file_length);
|
||||
NNCASE_API int
|
||||
nncase_clr_compile_options_set_input_format(clr_object_handle_t compile_options,
|
||||
const char *input_format,
|
||||
size_t input_format_length);
|
||||
NNCASE_API int
|
||||
nncase_clr_compile_options_set_target(clr_object_handle_t compile_options,
|
||||
const char *target, size_t target_length);
|
||||
NNCASE_API int
|
||||
nncase_clr_compile_options_set_dump_level(clr_object_handle_t compile_options,
|
||||
int32_t dump_level);
|
||||
NNCASE_API int
|
||||
nncase_clr_compile_options_set_dump_dir(clr_object_handle_t compile_options,
|
||||
const char *dump_dir,
|
||||
size_t dump_dir_length);
|
||||
NNCASE_API int nncase_clr_compile_options_set_quantize_options(
|
||||
clr_object_handle_t compile_options, clr_object_handle_t quantize_options);
|
||||
NNCASE_API int
|
||||
nncase_clr_compile_options_set_quant_type(clr_object_handle_t compile_options,
|
||||
clr_object_handle_t quant_type);
|
||||
NNCASE_API int nncase_clr_compile_options_set_model_quant_mode(
|
||||
clr_object_handle_t compile_options,
|
||||
nncase_model_quant_mode_t model_quant_mode);
|
||||
|
||||
NNCASE_API int nncase_clr_compiler_create(clr_object_handle_t compile_options,
|
||||
clr_object_handle_t *compiler);
|
||||
NNCASE_API int nncase_clr_compiler_import_module(clr_object_handle_t compiler,
|
||||
clr_object_handle_t stream,
|
||||
clr_object_handle_t *module);
|
||||
NNCASE_API int nncase_clr_compiler_compile(clr_object_handle_t compiler);
|
||||
NNCASE_API int nncase_clr_compiler_gencode(clr_object_handle_t compiler,
|
||||
clr_object_handle_t stream);
|
||||
|
||||
NNCASE_API int nncase_clr_datatype_from_typecode(nncase::typecode_t typecode,
|
||||
clr_object_handle_t *datatype);
|
||||
|
||||
NNCASE_API int nncase_clr_expr_evaluate(clr_object_handle_t expr,
|
||||
clr_object_handle_t parameters,
|
||||
clr_object_handle_t inputs,
|
||||
clr_object_handle_t *result);
|
||||
|
||||
NNCASE_API int nncase_clr_function_get_body(clr_object_handle_t function,
|
||||
clr_object_handle_t *body);
|
||||
NNCASE_API int
|
||||
nncase_clr_function_get_parameters(clr_object_handle_t function,
|
||||
clr_object_handle_t *parameters);
|
||||
NNCASE_API int nncase_clr_ir_module_get_entry(clr_object_handle_t module,
|
||||
clr_object_handle_t *entry);
|
||||
|
||||
NNCASE_API int nncase_clr_launch_debugger();
|
||||
|
||||
NNCASE_API int
|
||||
nncase_clr_quantize_options_create(clr_object_handle_t *quantize_options);
|
||||
NNCASE_API int nncase_clr_quantize_options_set_calibration_dataset(
|
||||
clr_object_handle_t quantize_options, clr_object_handle_t dataset);
|
||||
NNCASE_API int nncase_clr_quantize_options_set_calibration_method(
|
||||
clr_object_handle_t quantize_options, nncase_calib_method_t method);
|
||||
|
||||
NNCASE_API int nncase_clr_rtvalue_from_handle(nncase::value_node *value,
|
||||
clr_object_handle_t *rtvalue);
|
||||
NNCASE_API int nncase_clr_rtvalue_get_handle(clr_object_handle_t rtvalue,
|
||||
nncase::value_node **value);
|
||||
|
||||
NNCASE_API int nncase_clr_stream_create(const nncase_stream_mt_t *mt,
|
||||
void *handle,
|
||||
clr_object_handle_t *stream);
|
||||
|
||||
NNCASE_API bool nncase_clr_target_exists(const char *target_name,
|
||||
size_t target_name_length);
|
||||
}
|
||||
|
||||
DEFINE_ENUM_BITMASK_OPERATORS(nncase_dump_flags_t)
|
||||
|
||||
namespace nncase::clr {
|
||||
class clr_object_ptr {
|
||||
public:
|
||||
|
@ -196,7 +174,7 @@ class clr_object_ptr {
|
|||
void release() {
|
||||
if (auto handle = handle_) {
|
||||
handle_ = nullptr;
|
||||
nncase_clr_handle_free(handle);
|
||||
nncase_clr_api()->handle_free(handle);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -204,16 +182,16 @@ class clr_object_ptr {
|
|||
clr_object_handle_t handle_;
|
||||
};
|
||||
|
||||
#define CHECK_CLR(x) \
|
||||
if (x) { \
|
||||
throw std::runtime_error(#x); \
|
||||
}
|
||||
#define CHECK_CLR(x) x
|
||||
|
||||
class clr_object_base {
|
||||
public:
|
||||
constexpr clr_object_base(std::nullptr_t = nullptr) noexcept
|
||||
: obj_(nullptr) {}
|
||||
|
||||
clr_object_base(std::in_place_t, clr_object_ptr ptr) noexcept
|
||||
: obj_(std::move(ptr)) {}
|
||||
|
||||
clr_object_base(clr_object_base &&) = default;
|
||||
clr_object_base &operator=(clr_object_base &&) = default;
|
||||
|
||||
|
@ -237,22 +215,15 @@ class array : public clr_object_base {
|
|||
|
||||
array(nncase_array_element_kind_t kind, const clr_object_handle_t *elements,
|
||||
size_t length) {
|
||||
CHECK_CLR(nncase_clr_array_create(kind, elements, length,
|
||||
obj_.release_and_addressof()));
|
||||
obj_ = nncase_clr_api()->array_create(kind, elements, length);
|
||||
}
|
||||
|
||||
template <class T = clr_object_base> T at(size_t index) {
|
||||
T value(nullptr);
|
||||
CHECK_CLR(nncase_clr_array_get_item(obj_.get(), index,
|
||||
value.release_and_addressof()));
|
||||
return value;
|
||||
return {std::in_place,
|
||||
nncase_clr_api()->array_get_item(obj_.get(), index)};
|
||||
}
|
||||
|
||||
size_t length() {
|
||||
size_t length;
|
||||
CHECK_CLR(nncase_clr_array_get_length(obj_.get(), &length));
|
||||
return length;
|
||||
}
|
||||
size_t length() { return nncase_clr_api()->array_get_length(obj_.get()); }
|
||||
|
||||
template <class T = clr_object_base> std::vector<T> to_vector() {
|
||||
std::vector<T> vector(length());
|
||||
|
@ -269,9 +240,8 @@ class calibration_dataset_provider : public clr_object_base {
|
|||
|
||||
calibration_dataset_provider(array dataset, size_t samples_count,
|
||||
array fn_params) {
|
||||
CHECK_CLR(nncase_clr_calibration_dataset_provider_create(
|
||||
dataset.get(), samples_count, fn_params.get(),
|
||||
obj_.release_and_addressof()));
|
||||
obj_ = nncase_clr_api()->calibration_dataset_provider_create(
|
||||
dataset.get(), samples_count, fn_params.get());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -279,15 +249,18 @@ class quantize_options : public clr_object_base {
|
|||
public:
|
||||
using clr_object_base::clr_object_base;
|
||||
|
||||
quantize_options() {
|
||||
CHECK_CLR(
|
||||
nncase_clr_quantize_options_create(obj_.release_and_addressof()));
|
||||
}
|
||||
quantize_options() { obj_ = nncase_clr_api()->quantize_options_create(); }
|
||||
|
||||
calibration_dataset_provider calibration_dataset() { return nullptr; }
|
||||
void calibration_dataset(const calibration_dataset_provider &value) {
|
||||
CHECK_CLR(nncase_clr_quantize_options_set_calibration_dataset(
|
||||
obj_.get(), value.get()));
|
||||
nncase_clr_api()->quantize_options_set_calibration_dataset(obj_.get(),
|
||||
value.get());
|
||||
}
|
||||
|
||||
nncase_model_quant_mode_t model_quant_mode() { return nncase_mqm_no_quant; }
|
||||
void model_quant_mode(nncase_model_quant_mode_t value) {
|
||||
nncase_clr_api()->quantize_options_set_model_quant_mode(obj_.get(),
|
||||
value);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -296,8 +269,7 @@ class cstream : public clr_object_base {
|
|||
using clr_object_base::clr_object_base;
|
||||
|
||||
cstream(const nncase_stream_mt_t *mt, void *handle) {
|
||||
CHECK_CLR(
|
||||
nncase_clr_stream_create(mt, handle, obj_.release_and_addressof()));
|
||||
obj_ = nncase_clr_api()->stream_create(mt, handle);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -305,44 +277,42 @@ class compile_options : public clr_object_base {
|
|||
public:
|
||||
using clr_object_base::clr_object_base;
|
||||
|
||||
compile_options() {
|
||||
CHECK_CLR(
|
||||
nncase_clr_compile_options_create(obj_.release_and_addressof()));
|
||||
}
|
||||
compile_options() { obj_ = nncase_clr_api()->compile_options_create(); }
|
||||
|
||||
std::string input_format() { return "cpu"; }
|
||||
void input_format(std::string_view value) {
|
||||
CHECK_CLR(nncase_clr_compile_options_set_input_format(
|
||||
obj_.get(), value.data(), value.length()));
|
||||
}
|
||||
|
||||
std::string target() { return "cpu"; }
|
||||
void target(std::string_view value) {
|
||||
CHECK_CLR(nncase_clr_compile_options_set_target(
|
||||
obj_.get(), value.data(), value.length()));
|
||||
}
|
||||
|
||||
int32_t dump_level() { return 0; }
|
||||
void dump_level(int32_t value) {
|
||||
CHECK_CLR(nncase_clr_compile_options_set_dump_level(obj_.get(), value));
|
||||
nncase_clr_api()->compile_options_set_input_format(
|
||||
obj_.get(), value.data(), value.length());
|
||||
}
|
||||
|
||||
std::string dump_dir() { return "cpu"; }
|
||||
void dump_dir(std::string_view value) {
|
||||
CHECK_CLR(nncase_clr_compile_options_set_dump_dir(
|
||||
obj_.get(), value.data(), value.length()));
|
||||
nncase_clr_api()->compile_options_set_dump_dir(obj_.get(), value.data(),
|
||||
value.length());
|
||||
}
|
||||
|
||||
nncase_dump_flags_t dump_flags() { return nncase_dump_flags_none; }
|
||||
void dump_flags(nncase_dump_flags_t value) {
|
||||
nncase_clr_api()->compile_options_set_dump_flags(obj_.get(), value);
|
||||
}
|
||||
|
||||
clr::quantize_options quantize_options() { return nullptr; }
|
||||
void quantize_options(const clr::quantize_options &value) {
|
||||
CHECK_CLR(nncase_clr_compile_options_set_quantize_options(obj_.get(),
|
||||
value.get()));
|
||||
nncase_clr_api()->compile_options_set_quantize_options(obj_.get(),
|
||||
value.get());
|
||||
}
|
||||
};
|
||||
|
||||
class target : public clr_object_base {
|
||||
public:
|
||||
using clr_object_base::clr_object_base;
|
||||
|
||||
static bool exists(std::string_view name) {
|
||||
return nncase_clr_api()->target_exists(name.data(), name.length());
|
||||
}
|
||||
|
||||
nncase_model_quant_mode_t model_quant_mode() { return nncase_mqm_no_quant; }
|
||||
void model_quant_mode(nncase_model_quant_mode_t value) {
|
||||
CHECK_CLR(
|
||||
nncase_clr_compile_options_set_model_quant_mode(obj_.get(), value));
|
||||
target(std::string_view name) {
|
||||
obj_ = nncase_clr_api()->target_create(name.data(), name.length());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -351,15 +321,12 @@ class rtvalue : public clr_object_base {
|
|||
using clr_object_base::clr_object_base;
|
||||
|
||||
rtvalue(nncase::value_t value) {
|
||||
CHECK_CLR(nncase_clr_rtvalue_from_handle(value.detach(),
|
||||
obj_.release_and_addressof()));
|
||||
obj_ = nncase_clr_api()->rtvalue_from_handle(value.get());
|
||||
}
|
||||
|
||||
value_t to_value() const {
|
||||
value_t value(nullptr);
|
||||
CHECK_CLR(nncase_clr_rtvalue_get_handle(obj_.get(),
|
||||
value.release_and_addressof()));
|
||||
return value;
|
||||
auto ptr = nncase_clr_api()->rtvalue_get_handle(obj_.get());
|
||||
return ptr;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -368,10 +335,8 @@ class expr : public clr_object_base {
|
|||
using clr_object_base::clr_object_base;
|
||||
|
||||
rtvalue evaluate(const array ¶ms, const array &inputs) {
|
||||
rtvalue value(nullptr);
|
||||
CHECK_CLR(nncase_clr_expr_evaluate(get(), params.get(), inputs.get(),
|
||||
value.release_and_addressof()));
|
||||
return value;
|
||||
return {std::in_place, nncase_clr_api()->expr_evaluate(
|
||||
get(), params.get(), inputs.get())};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -390,17 +355,12 @@ class function : public base_function {
|
|||
using base_function::base_function;
|
||||
|
||||
expr body() {
|
||||
expr body(nullptr);
|
||||
CHECK_CLR(
|
||||
nncase_clr_function_get_body(get(), body.release_and_addressof()));
|
||||
return body;
|
||||
return {std::in_place, nncase_clr_api()->function_get_body(get())};
|
||||
}
|
||||
|
||||
array parameters() {
|
||||
array params(nullptr);
|
||||
CHECK_CLR(nncase_clr_function_get_parameters(
|
||||
get(), params.release_and_addressof()));
|
||||
return params;
|
||||
return {std::in_place,
|
||||
nncase_clr_api()->function_get_parameters(get())};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -409,10 +369,7 @@ class ir_module : public clr_object_base {
|
|||
using clr_object_base::clr_object_base;
|
||||
|
||||
function entry() {
|
||||
function entry(nullptr);
|
||||
CHECK_CLR(nncase_clr_ir_module_get_entry(
|
||||
get(), entry.release_and_addressof()));
|
||||
return entry;
|
||||
return {std::in_place, nncase_clr_api()->ir_module_get_entry(get())};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -420,21 +377,29 @@ class compiler : public clr_object_base {
|
|||
public:
|
||||
using clr_object_base::clr_object_base;
|
||||
|
||||
compiler(const compile_options &options) {
|
||||
CHECK_CLR(nncase_clr_compiler_create(options.get(),
|
||||
obj_.release_and_addressof()));
|
||||
}
|
||||
|
||||
ir_module import_module(cstream &stream) {
|
||||
ir_module module(nullptr);
|
||||
CHECK_CLR(nncase_clr_compiler_import_module(
|
||||
get(), stream.get(), module.release_and_addressof()));
|
||||
return module;
|
||||
return {std::in_place,
|
||||
nncase_clr_api()->compiler_import_module(get(), stream.get())};
|
||||
}
|
||||
|
||||
void compile() { CHECK_CLR(nncase_clr_compiler_compile(obj_.get())); }
|
||||
void compile() { nncase_clr_api()->compiler_compile(obj_.get()); }
|
||||
void gencode(cstream &stream) {
|
||||
CHECK_CLR(nncase_clr_compiler_gencode(obj_.get(), stream.get()));
|
||||
nncase_clr_api()->compiler_gencode(obj_.get(), stream.get());
|
||||
}
|
||||
};
|
||||
|
||||
class compile_session : public clr_object_base {
|
||||
public:
|
||||
using clr_object_base::clr_object_base;
|
||||
|
||||
compile_session(const target &target, const compile_options &options) {
|
||||
obj_ = nncase_clr_api()->compile_session_create(target.get(),
|
||||
options.get());
|
||||
}
|
||||
|
||||
clr::compiler compiler() {
|
||||
return {std::in_place,
|
||||
nncase_clr_api()->compile_session_get_compiler(obj_.get())};
|
||||
}
|
||||
};
|
||||
} // namespace nncase::clr
|
||||
|
|
|
@ -59,7 +59,8 @@ class NNCASE_API object_node {
|
|||
uint32_t release() const noexcept;
|
||||
|
||||
template <class T> friend class object_t;
|
||||
friend int ::nncase_object_free(nncase::object_node *node);
|
||||
friend int ::nncase_object_add_ref(nncase::object_node *node);
|
||||
friend int ::nncase_object_release(nncase::object_node *node);
|
||||
|
||||
private:
|
||||
mutable std::atomic<uint32_t> ref_count_;
|
||||
|
|
|
@ -68,7 +68,13 @@ result<strides_t> to_strides(const uint32_t *strides, uint32_t length) {
|
|||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
int nncase_object_free(nncase::object_node *node) {
|
||||
int nncase_object_add_ref(nncase::object_node *node) {
|
||||
if (node)
|
||||
node->add_ref();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_object_release(nncase::object_node *node) {
|
||||
if (node)
|
||||
node->release();
|
||||
return 0;
|
||||
|
@ -178,10 +184,6 @@ int nncase_buffer_as_host(nncase::runtime::buffer_node *buffer,
|
|||
return -EINVAL;
|
||||
}
|
||||
|
||||
int nncase_buffer_free(nncase::runtime::buffer_node *buffer) {
|
||||
return nncase_object_free(buffer);
|
||||
}
|
||||
|
||||
int nncase_host_buffer_map(nncase::runtime::host_buffer_node *host_buffer,
|
||||
nncase::runtime::map_access_t access, void **data,
|
||||
uint32_t *bytes) {
|
||||
|
@ -205,10 +207,6 @@ int nncase_host_buffer_unmap(nncase::runtime::host_buffer_node *host_buffer) {
|
|||
return -EINVAL;
|
||||
}
|
||||
|
||||
int nncase_host_buffer_free(nncase::runtime::host_buffer_node *host_buffer) {
|
||||
return nncase_object_free(host_buffer);
|
||||
}
|
||||
|
||||
int nncase_dtype_create_prime(nncase::typecode_t typecode,
|
||||
nncase::datatype_node **dtype) {
|
||||
if (dtype) {
|
||||
|
@ -223,10 +221,6 @@ int nncase_dtype_get_typecode(nncase::datatype_node *dtype) {
|
|||
return dtype->typecode();
|
||||
}
|
||||
|
||||
int nncase_dtype_free(nncase::datatype_node *dtype) {
|
||||
return nncase_object_free(dtype);
|
||||
}
|
||||
|
||||
int nncase_value_is_tensor(nncase::value_node *value, bool *is_tensor) {
|
||||
if (value && is_tensor) {
|
||||
*is_tensor = value_t(value).is_a<tensor>();
|
||||
|
|
|
@ -38,67 +38,6 @@
|
|||
#define UNMANAGEDCALLERSONLY_METHOD ((const char_t *)-1)
|
||||
|
||||
namespace {
|
||||
struct c_api_mt {
|
||||
clr_object_handle_t (*array_create)(nncase_array_element_kind_t kind,
|
||||
const clr_object_handle_t *elements,
|
||||
size_t count);
|
||||
clr_object_handle_t (*array_get_item)(clr_object_handle_t array,
|
||||
size_t index);
|
||||
size_t (*array_get_length)(clr_object_handle_t array);
|
||||
clr_object_handle_t (*calibration_dataset_provider_create)(
|
||||
clr_object_handle_t dataset, size_t samplesCount,
|
||||
clr_object_handle_t fn_params);
|
||||
void (*handle_free)(clr_object_handle_t handle);
|
||||
clr_object_handle_t (*compile_options_create)();
|
||||
void (*compile_options_set_input_file)(clr_object_handle_t compile_options,
|
||||
const char *input_file,
|
||||
size_t input_file_length);
|
||||
void (*compile_options_set_input_format)(
|
||||
clr_object_handle_t compile_options, const char *input_format,
|
||||
size_t input_format_length);
|
||||
void (*compile_options_set_target)(clr_object_handle_t compile_options,
|
||||
const char *target,
|
||||
size_t target_length);
|
||||
void (*compile_options_set_dump_level)(clr_object_handle_t compile_options,
|
||||
int32_t dump_level);
|
||||
void (*compile_options_set_dump_dir)(clr_object_handle_t compile_options,
|
||||
const char *dump_dir,
|
||||
size_t dump_dir_length);
|
||||
void (*compile_options_set_quantize_options)(
|
||||
clr_object_handle_t compile_options,
|
||||
clr_object_handle_t quantize_options);
|
||||
void (*compile_options_set_quant_type)(clr_object_handle_t compile_options,
|
||||
clr_object_handle_t quant_type);
|
||||
void (*compile_options_set_model_quant_mode)(
|
||||
clr_object_handle_t compile_options,
|
||||
nncase_model_quant_mode_t model_quant_mode);
|
||||
void (*compiler_initialize)();
|
||||
clr_object_handle_t (*compiler_create)(clr_object_handle_t compile_options);
|
||||
clr_object_handle_t (*compiler_import_module)(clr_object_handle_t compiler,
|
||||
clr_object_handle_t stream);
|
||||
void (*compiler_compile)(clr_object_handle_t compiler);
|
||||
void (*compiler_gencode)(clr_object_handle_t compiler,
|
||||
clr_object_handle_t stream);
|
||||
clr_object_handle_t (*datatype_from_typecode)(nncase::typecode_t typecode);
|
||||
clr_object_handle_t (*expr_evaluate)(clr_object_handle_t expr,
|
||||
clr_object_handle_t parameters,
|
||||
clr_object_handle_t inputs);
|
||||
clr_object_handle_t (*function_get_body)(clr_object_handle_t function);
|
||||
clr_object_handle_t (*function_get_parameters)(
|
||||
clr_object_handle_t function);
|
||||
clr_object_handle_t (*ir_module_get_entry)(clr_object_handle_t module);
|
||||
void (*luanch_debugger)();
|
||||
clr_object_handle_t (*quantize_options_create)();
|
||||
void (*quantize_options_set_calibration_dataset)(
|
||||
clr_object_handle_t quantize_options, clr_object_handle_t dataset);
|
||||
void (*quantize_options_set_calibration_method)(
|
||||
clr_object_handle_t quantize_options, nncase_calib_method_t method);
|
||||
clr_object_handle_t (*rtvalue_from_handle)(nncase::value_node *value);
|
||||
nncase::value_node *(*rtvalue_get_handle)(clr_object_handle_t rtvalue);
|
||||
clr_object_handle_t (*stream_create)(const nncase_stream_mt_t *mt,
|
||||
void *handle);
|
||||
bool (*target_exists)(const char *target_name, size_t target_name_length);
|
||||
};
|
||||
|
||||
typedef int (*get_function_pointer_fn)(const char_t *type_name,
|
||||
const char_t *method_name,
|
||||
|
@ -106,7 +45,7 @@ typedef int (*get_function_pointer_fn)(const char_t *type_name,
|
|||
void *load_context, void *reserved,
|
||||
/*out*/ void **delegate);
|
||||
|
||||
typedef void (*c_api_initialize_fn)(c_api_mt *mt);
|
||||
typedef void (*c_api_initialize_fn)(nncase_api_mt_t *mt);
|
||||
|
||||
#ifdef WIN32
|
||||
#define THROW_WIN32_IF_NOT(x) \
|
||||
|
@ -295,14 +234,16 @@ load_compiler_c_api_initializer(const char *root_assembly_path) {
|
|||
return c_api_initialize;
|
||||
}
|
||||
|
||||
c_api_mt g_c_api_mt;
|
||||
nncase_api_mt_t g_nncase_api_mt;
|
||||
} // namespace
|
||||
|
||||
nncase_api_mt_t *nncase_clr_api() { return &g_nncase_api_mt; }
|
||||
|
||||
int nncase_clr_initialize(const char *root_assembly_path) {
|
||||
if (!g_c_api_mt.handle_free) {
|
||||
if (!g_nncase_api_mt.handle_free) {
|
||||
auto init = load_compiler_c_api_initializer(root_assembly_path);
|
||||
init(&g_c_api_mt);
|
||||
g_c_api_mt.compiler_initialize();
|
||||
init(&g_nncase_api_mt);
|
||||
g_nncase_api_mt.compiler_initialize();
|
||||
// SetUnhandledExceptionFilter(MyUnhandledExceptionFilter);
|
||||
}
|
||||
|
||||
|
@ -310,200 +251,6 @@ int nncase_clr_initialize(const char *root_assembly_path) {
|
|||
}
|
||||
|
||||
int nncase_clr_uninitialize() {
|
||||
g_c_api_mt = {};
|
||||
g_nncase_api_mt = {};
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_array_create(nncase_array_element_kind_t kind,
|
||||
const clr_object_handle_t *elements, size_t count,
|
||||
clr_object_handle_t *array) {
|
||||
*array = g_c_api_mt.array_create(kind, elements, count);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_array_get_item(clr_object_handle_t array, size_t index,
|
||||
clr_object_handle_t *item) {
|
||||
*item = g_c_api_mt.array_get_item(array, index);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_array_get_length(clr_object_handle_t array, size_t *length) {
|
||||
*length = g_c_api_mt.array_get_length(array);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_calibration_dataset_provider_create(
|
||||
clr_object_handle_t dataset, size_t samplesCount,
|
||||
clr_object_handle_t fn_params, clr_object_handle_t *provider) {
|
||||
*provider = g_c_api_mt.calibration_dataset_provider_create(
|
||||
dataset, samplesCount, fn_params);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_handle_free([[maybe_unused]] clr_object_handle_t handle) {
|
||||
if (g_c_api_mt.handle_free)
|
||||
g_c_api_mt.handle_free(handle);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compile_options_create(clr_object_handle_t *compile_options) {
|
||||
*compile_options = g_c_api_mt.compile_options_create();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compile_options_set_inputfile(
|
||||
clr_object_handle_t compile_options, const char *input_file,
|
||||
size_t input_file_length) {
|
||||
g_c_api_mt.compile_options_set_input_file(compile_options, input_file,
|
||||
input_file_length);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compile_options_set_input_format(
|
||||
clr_object_handle_t compile_options, const char *input_format,
|
||||
size_t input_format_length) {
|
||||
g_c_api_mt.compile_options_set_input_format(compile_options, input_format,
|
||||
input_format_length);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compile_options_set_target(clr_object_handle_t compile_options,
|
||||
const char *target,
|
||||
size_t target_length) {
|
||||
g_c_api_mt.compile_options_set_target(compile_options, target,
|
||||
target_length);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compile_options_set_dump_level(
|
||||
clr_object_handle_t compile_options, int32_t dump_level) {
|
||||
g_c_api_mt.compile_options_set_dump_level(compile_options, dump_level);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compile_options_set_dump_dir(clr_object_handle_t compile_options,
|
||||
const char *dump_dir,
|
||||
size_t dump_dir_length) {
|
||||
g_c_api_mt.compile_options_set_dump_dir(compile_options, dump_dir,
|
||||
dump_dir_length);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compile_options_set_quantize_options(
|
||||
clr_object_handle_t compile_options, clr_object_handle_t quantize_options) {
|
||||
g_c_api_mt.compile_options_set_quantize_options(compile_options,
|
||||
quantize_options);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compile_options_set_model_quant_mode(
|
||||
clr_object_handle_t compile_options,
|
||||
nncase_model_quant_mode_t model_quant_mode) {
|
||||
g_c_api_mt.compile_options_set_model_quant_mode(compile_options,
|
||||
model_quant_mode);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compiler_create(clr_object_handle_t compile_options,
|
||||
clr_object_handle_t *compiler) {
|
||||
*compiler = g_c_api_mt.compiler_create(compile_options);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compiler_import_module(clr_object_handle_t compiler,
|
||||
clr_object_handle_t stream,
|
||||
clr_object_handle_t *module) {
|
||||
*module = g_c_api_mt.compiler_import_module(compiler, stream);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compiler_compile(clr_object_handle_t compiler) {
|
||||
g_c_api_mt.compiler_compile(compiler);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_compiler_gencode(clr_object_handle_t compiler,
|
||||
clr_object_handle_t stream) {
|
||||
g_c_api_mt.compiler_gencode(compiler, stream);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_datatype_from_typecode(nncase::typecode_t typecode,
|
||||
clr_object_handle_t *datatype) {
|
||||
*datatype = g_c_api_mt.datatype_from_typecode(typecode);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_expr_evaluate(clr_object_handle_t expr,
|
||||
clr_object_handle_t inputs,
|
||||
clr_object_handle_t parameters,
|
||||
clr_object_handle_t *result) {
|
||||
*result = g_c_api_mt.expr_evaluate(expr, parameters, inputs);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_function_get_body(clr_object_handle_t function,
|
||||
clr_object_handle_t *body) {
|
||||
*body = g_c_api_mt.function_get_body(function);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_function_get_parameters(clr_object_handle_t function,
|
||||
clr_object_handle_t *parameters) {
|
||||
*parameters = g_c_api_mt.function_get_parameters(function);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_ir_module_get_entry(clr_object_handle_t module,
|
||||
clr_object_handle_t *entry) {
|
||||
*entry = g_c_api_mt.ir_module_get_entry(module);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_launch_debugger() {
|
||||
g_c_api_mt.luanch_debugger();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_quantize_options_create(clr_object_handle_t *quantize_options) {
|
||||
*quantize_options = g_c_api_mt.quantize_options_create();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_quantize_options_set_calibration_dataset(
|
||||
clr_object_handle_t quantize_options, clr_object_handle_t dataset) {
|
||||
g_c_api_mt.quantize_options_set_calibration_dataset(quantize_options,
|
||||
dataset);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_quantize_options_set_calibration_method(
|
||||
clr_object_handle_t quantize_options, nncase_calib_method_t method) {
|
||||
g_c_api_mt.quantize_options_set_calibration_method(quantize_options,
|
||||
method);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_rtvalue_from_handle(nncase::value_node *value,
|
||||
clr_object_handle_t *rtvalue) {
|
||||
*rtvalue = g_c_api_mt.rtvalue_from_handle(nncase::value_t(value).detach());
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_rtvalue_get_handle(clr_object_handle_t rtvalue,
|
||||
nncase::value_node **value) {
|
||||
nncase::value_t v(g_c_api_mt.rtvalue_get_handle(rtvalue));
|
||||
*value = v.detach();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int nncase_clr_stream_create(const nncase_stream_mt_t *mt, void *handle,
|
||||
clr_object_handle_t *stream) {
|
||||
*stream = g_c_api_mt.stream_create(mt, handle);
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool nncase_clr_target_exists(const char *target_name,
|
||||
size_t target_name_length) {
|
||||
return g_c_api_mt.target_exists(target_name, target_name_length);
|
||||
}
|
||||
|
|
|
@ -17,10 +17,12 @@
|
|||
#include <nncase/api.h>
|
||||
#include <nncase/compiler.h>
|
||||
#include <nncase/io_utils.h>
|
||||
#include <string_view>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::clr;
|
||||
using namespace nncase::runtime;
|
||||
using namespace std::string_view_literals;
|
||||
|
||||
#define TRY(x) \
|
||||
if (x) \
|
||||
|
@ -29,11 +31,13 @@ using namespace nncase::runtime;
|
|||
int main() {
|
||||
nncase_clr_initialize(
|
||||
R"(E:\Work\Repos\nncase\src\Nncase.Compiler\bin\Debug\net6.0\Nncase.Compiler.dll)");
|
||||
clr_object_ptr compiler, compile_options;
|
||||
TRY(nncase_clr_compile_options_create(
|
||||
compile_options.release_and_addressof()));
|
||||
TRY(nncase_clr_compiler_create(compile_options.get(),
|
||||
compiler.release_and_addressof()));
|
||||
auto target_name = "cpu"sv;
|
||||
auto nncapi = nncase_clr_api();
|
||||
clr_object_ptr target, compile_session, compiler, compile_options;
|
||||
compile_options = nncapi->compile_options_create();
|
||||
target = nncapi->target_create(target_name.data(), target_name.length());
|
||||
nncapi->compile_session_create(target.get(), compile_options.get());
|
||||
compiler = nncapi->compile_session_get_compiler(compile_session.get());
|
||||
|
||||
auto kmodel = read_file(
|
||||
R"(E:\Work\Repos\nncase\src\Nncase.Tests\bin\Debug\net6.0\TestCallFunction.kmodel)");
|
||||
|
@ -62,7 +66,7 @@ int main() {
|
|||
nullptr));
|
||||
memcpy(x_buf_data, x, sizeof(x));
|
||||
TRY(nncase_host_buffer_unmap(x_host_buf));
|
||||
TRY(nncase_object_free((object_node *)x_host_buf));
|
||||
TRY(nncase_object_release((object_node *)x_host_buf));
|
||||
}
|
||||
|
||||
tensor_node *x_tensor;
|
||||
|
@ -95,14 +99,14 @@ int main() {
|
|||
std::cout << *ret_float_data << std::endl;
|
||||
|
||||
TRY(nncase_host_buffer_unmap(ret_host_buf));
|
||||
TRY(nncase_object_free((object_node *)ret_host_buf));
|
||||
TRY(nncase_object_release((object_node *)ret_host_buf));
|
||||
}
|
||||
|
||||
TRY(nncase_object_free((object_node *)out_buffer_slice.buffer));
|
||||
TRY(nncase_object_free((object_node *)ret));
|
||||
TRY(nncase_object_free((object_node *)x_buf));
|
||||
TRY(nncase_object_free((object_node *)x_tensor));
|
||||
TRY(nncase_object_free((object_node *)dtype_float32));
|
||||
TRY(nncase_object_release((object_node *)out_buffer_slice.buffer));
|
||||
TRY(nncase_object_release((object_node *)ret));
|
||||
TRY(nncase_object_release((object_node *)x_buf));
|
||||
TRY(nncase_object_release((object_node *)x_tensor));
|
||||
TRY(nncase_object_release((object_node *)dtype_float32));
|
||||
TRY(nncase_interp_free(interp));
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -7,12 +7,14 @@ using System.CommandLine;
|
|||
using System.CommandLine.Invocation;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using Autofac;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Nncase.CodeGen;
|
||||
using Nncase.Compiler;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.Quantization;
|
||||
using Nncase.Transform;
|
||||
|
||||
namespace Nncase.Cli.Commands;
|
||||
|
@ -69,66 +71,83 @@ public class Compile : Command
|
|||
description: "model quant options, default is Random",
|
||||
getDefaultValue: () => Quantization.CalibMethod.Random));
|
||||
|
||||
Handler = CommandHandler.Create<CliCompileOptions, IHost>(Run);
|
||||
Handler = CommandHandler.Create<CliCompileOptions, IHost>(RunAsync);
|
||||
}
|
||||
|
||||
private void Run(CliCompileOptions cliOptions, IHost host)
|
||||
private static DumpFlags DumpLevelToFlags(int dumpLevel)
|
||||
{
|
||||
var provider = host.Services.GetRequiredService<ICompilerServicesProvider>();
|
||||
CompilerServices.Configure(provider);
|
||||
return dumpLevel switch
|
||||
{
|
||||
0 => DumpFlags.None,
|
||||
1 => DumpLevelToFlags(0) | DumpFlags.Compile,
|
||||
2 => DumpLevelToFlags(1) | DumpFlags.PassIR,
|
||||
3 => DumpLevelToFlags(2) | DumpFlags.Rewrite,
|
||||
4 => DumpLevelToFlags(3) | DumpFlags.EGraphCost,
|
||||
5 => DumpLevelToFlags(4) | DumpFlags.Evaluator,
|
||||
6 => DumpLevelToFlags(5) | DumpFlags.Calibration,
|
||||
7 => DumpLevelToFlags(6) | DumpFlags.Tiling,
|
||||
8 => DumpLevelToFlags(7) | DumpFlags.Schedule,
|
||||
>= 9 => DumpLevelToFlags(8) | DumpFlags.CodeGen,
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(dumpLevel)),
|
||||
};
|
||||
}
|
||||
|
||||
private async Task RunAsync(CliCompileOptions cliOptions, IHost host)
|
||||
{
|
||||
CompilerServices.Configure(host.Services);
|
||||
|
||||
// 1. setup the options
|
||||
Quantization.QuantizeOptions quant_options = new() { CalibrationMethod = cliOptions.CalibMethod };
|
||||
var compileOptions = new CompileOptions()
|
||||
var compileOptions = new CompileOptions
|
||||
{
|
||||
InputFile = cliOptions.InputFile,
|
||||
OutputFile = cliOptions.OutputFile,
|
||||
Target = cliOptions.Target,
|
||||
InputFormat = cliOptions.InputFormat,
|
||||
DumpLevel = cliOptions.DumpLevel,
|
||||
DumpFlags = DumpLevelToFlags(cliOptions.DumpLevel),
|
||||
DumpDir = cliOptions.DumpDir,
|
||||
QuantType = cliOptions.QuantType switch
|
||||
QuantizeOptions = new()
|
||||
{
|
||||
QuantType.UInt8 => DataTypes.UInt8,
|
||||
QuantType.Int8 => DataTypes.Int8,
|
||||
QuantType.Int16 => DataTypes.Int16,
|
||||
_ => throw new ArgumentOutOfRangeException(),
|
||||
CalibrationMethod = cliOptions.CalibMethod,
|
||||
QuantType = cliOptions.QuantType switch
|
||||
{
|
||||
QuantType.UInt8 => DataTypes.UInt8,
|
||||
QuantType.Int8 => DataTypes.Int8,
|
||||
QuantType.Int16 => DataTypes.Int16,
|
||||
_ => throw new ArgumentException("Invalid quant type"),
|
||||
},
|
||||
WQuantType = cliOptions.WQuantType switch
|
||||
{
|
||||
QuantType.UInt8 => DataTypes.UInt8,
|
||||
QuantType.Int8 => DataTypes.Int8,
|
||||
QuantType.Int16 => DataTypes.Int16,
|
||||
_ => throw new ArgumentException("Invalid weights quant type"),
|
||||
},
|
||||
ModelQuantMode = cliOptions.ModelQuantMode,
|
||||
},
|
||||
WQuantType = cliOptions.WQuantType switch
|
||||
{
|
||||
QuantType.UInt8 => DataTypes.UInt8,
|
||||
QuantType.Int8 => DataTypes.Int8,
|
||||
QuantType.Int16 => DataTypes.Int16,
|
||||
_ => throw new ArgumentOutOfRangeException(),
|
||||
},
|
||||
ModelQuantMode = cliOptions.ModelQuantMode,
|
||||
|
||||
// todo add the quant options parser
|
||||
QuantizeOptions = quant_options,
|
||||
};
|
||||
|
||||
// 2. import the model
|
||||
var compiler = new Compiler.Compiler(compileOptions);
|
||||
var target = CompilerServices.GetTarget(cliOptions.Target);
|
||||
using var compileSession = CompileSession.Create(target, compileOptions);
|
||||
var compiler = compileSession.Compiler;
|
||||
IRModule module;
|
||||
using (var model_stream = File.OpenRead(compileOptions.InputFile))
|
||||
{
|
||||
module = compiler.ImportModule(model_stream);
|
||||
module = await compiler.ImportModuleAsync(model_stream);
|
||||
}
|
||||
|
||||
// 3. create the calib dataset
|
||||
if (compileOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
|
||||
if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
|
||||
{
|
||||
if (quant_options.CalibrationMethod == Quantization.CalibMethod.Random)
|
||||
if (compileOptions.QuantizeOptions.CalibrationMethod == Quantization.CalibMethod.Random)
|
||||
{
|
||||
quant_options.CalibrationDataset = new RandCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray());
|
||||
compileOptions.QuantizeOptions.CalibrationDataset = new RandomCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), 5);
|
||||
}
|
||||
}
|
||||
|
||||
// 4. compile
|
||||
compiler.Compile();
|
||||
await compiler.CompileAsync();
|
||||
|
||||
// 5. code gen
|
||||
using (var os = File.OpenWrite(compileOptions.OutputFile))
|
||||
using (var os = File.OpenWrite(cliOptions.OutputFile))
|
||||
{
|
||||
compiler.Gencode(os);
|
||||
}
|
||||
|
@ -167,31 +186,3 @@ internal sealed class CliCompileOptions
|
|||
/// <inheritdoc/>
|
||||
public Quantization.CalibMethod CalibMethod { get; set; }
|
||||
}
|
||||
|
||||
internal sealed class RandCalibrationDatasetProvider : Quantization.ICalibrationDatasetProvider
|
||||
{
|
||||
private const int CountValue = 5;
|
||||
|
||||
private readonly IReadOnlyDictionary<Var, IValue>[] _samples;
|
||||
|
||||
public RandCalibrationDatasetProvider(IEnumerable<Var> vars)
|
||||
{
|
||||
_samples = Enumerable.Range(0, CountValue).Select(i =>
|
||||
{
|
||||
var values = new Dictionary<Var, IValue>();
|
||||
foreach (var var in vars)
|
||||
{
|
||||
CompilerServices.InferenceType(var);
|
||||
var shape = var.CheckedShape.Select(d => d.IsUnknown ? 1 : d.FixedValue).ToArray();
|
||||
var value = IR.F.Random.Normal(var.CheckedDataType, 0, 1, 0, shape).Evaluate();
|
||||
values.Add(var, value);
|
||||
}
|
||||
|
||||
return values;
|
||||
}).ToArray();
|
||||
}
|
||||
|
||||
public int? Count => CountValue;
|
||||
|
||||
public IAsyncEnumerable<IReadOnlyDictionary<Var, IValue>> Samples => _samples.ToAsyncEnumerable();
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Autofac.Extensions.DependencyInjection" />
|
||||
<PackageReference Include="Microsoft.Extensions.Hosting" />
|
||||
<PackageReference Include="System.CommandLine.Hosting" />
|
||||
</ItemGroup>
|
||||
|
|
|
@ -1,35 +1,38 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.CommandLine.Builder;
|
||||
using System.CommandLine.Hosting;
|
||||
using System.CommandLine.Parsing;
|
||||
using System.IO;
|
||||
using System.Threading.Tasks;
|
||||
using Autofac;
|
||||
using Autofac.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase.Cli
|
||||
namespace Nncase.Cli;
|
||||
|
||||
internal partial class Program
|
||||
{
|
||||
internal partial class Program
|
||||
public static async Task<int> Main(string[] args)
|
||||
{
|
||||
public static async Task<int> Main(string[] args)
|
||||
{
|
||||
return await BuildCommandLine()
|
||||
.UseHost(
|
||||
_ => CompilerHost.CreateHostBuilder(args),
|
||||
host =>
|
||||
{
|
||||
host.UseConsoleLifetime();
|
||||
})
|
||||
.UseDefaults()
|
||||
.Build().InvokeAsync(args);
|
||||
}
|
||||
return await BuildCommandLine()
|
||||
.UseHost(ConfigureHost)
|
||||
.UseDefaults()
|
||||
.Build().InvokeAsync(args);
|
||||
}
|
||||
|
||||
private static void ConfigureHost(IHostBuilder hostBuilder)
|
||||
{
|
||||
hostBuilder.ConfigureAppConfiguration(ConfigureAppConfiguration)
|
||||
.UseConsoleLifetime()
|
||||
.ConfigureCompiler();
|
||||
}
|
||||
|
||||
private static void ConfigureAppConfiguration(HostBuilderContext context, IConfigurationBuilder builder)
|
||||
{
|
||||
var baseDirectory = Path.GetDirectoryName(typeof(Program).Assembly.Location);
|
||||
builder.SetBasePath(baseDirectory)
|
||||
.AddJsonFile("config.json", true, false);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,5 @@
|
|||
"Default": "Information",
|
||||
"Microsoft.Hosting.Lifetime": "Warning"
|
||||
}
|
||||
},
|
||||
"Testing": {
|
||||
"LogDir": "tests_output",
|
||||
"LogLevel": 4
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using DryIoc;
|
||||
using Nncase.CodeGen;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase.Diagnostics;
|
||||
|
||||
/// <summary>
|
||||
/// CodeGen module.
|
||||
/// </summary>
|
||||
internal class CodeGenModule : IApplicationPart
|
||||
{
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
registrator.Register<IModelBuilder, ModelBuilder>(reuse: Reuse.Scoped);
|
||||
}
|
||||
}
|
|
@ -11,7 +11,7 @@ using Extension.Mathematics;
|
|||
|
||||
namespace Nncase.CodeGen;
|
||||
|
||||
public sealed class LinkedModel
|
||||
internal sealed class LinkedModel : ILinkedModel
|
||||
{
|
||||
private const int _minAlignmnet = 8;
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ namespace Nncase.CodeGen;
|
|||
/// <summary>
|
||||
/// The Kmodel Builder.
|
||||
/// </summary>
|
||||
public sealed class ModelBuilder
|
||||
public sealed class ModelBuilder : IModelBuilder
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ModelBuilder"/> class.
|
||||
|
@ -27,16 +27,6 @@ public sealed class ModelBuilder
|
|||
CompileOptions = compileOptions;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ModelBuilder"/> class.
|
||||
/// ctor from the global compile options.
|
||||
/// </summary>
|
||||
/// <param name="target"></param>
|
||||
public ModelBuilder(ITarget target)
|
||||
: this(target, CompilerServices.CompileOptions)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets get the Target.
|
||||
/// </summary>
|
||||
|
@ -47,7 +37,7 @@ public sealed class ModelBuilder
|
|||
/// </summary>
|
||||
public CompileOptions CompileOptions { get; }
|
||||
|
||||
public LinkedModel Build(IRModule module)
|
||||
public ILinkedModel Build(IRModule module)
|
||||
{
|
||||
var functionsByKind = module.Functions.GroupBy(x => x.ModuleKind).ToList();
|
||||
var functionIds = MakeFunctionsIds(functionsByKind);
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using Nncase.CodeGen;
|
||||
using Nncase.Converters;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
/// <summary>
|
||||
/// CodeGen application part extensions.
|
||||
/// </summary>
|
||||
public static class CodeGenApplicationPart
|
||||
{
|
||||
/// <summary>
|
||||
/// Add diagnostics assembly.
|
||||
/// </summary>
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
/// <returns>Configured service registrator.</returns>
|
||||
public static IRegistrator AddCodeGen(this IRegistrator registrator)
|
||||
{
|
||||
return registrator.RegisterModule<CodeGenModule>();
|
||||
}
|
||||
}
|
|
@ -1,13 +1,12 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using Autofac.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Nncase.CodeGen;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Evaluator;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.IR;
|
||||
|
@ -16,152 +15,131 @@ using Nncase.Transform;
|
|||
using Nncase.Transform.Passes;
|
||||
using Nncase.Transform.Rules.Lower;
|
||||
using Nncase.Utilities;
|
||||
using OrtKISharp;
|
||||
|
||||
namespace Nncase.Compiler;
|
||||
|
||||
public class Compiler
|
||||
internal class Compiler : ICompiler
|
||||
{
|
||||
private readonly CompileOptions _compileOptions;
|
||||
private readonly CompileSession _compileSession;
|
||||
private readonly IModelBuilder _modelBuilder;
|
||||
private readonly IDumpper _dumpper;
|
||||
private IRModule? _module;
|
||||
|
||||
public Compiler(CompileOptions compileOptions)
|
||||
public Compiler(CompileSession compileSession, IModelBuilder modelBuilder, IDumpperFactory dumpperFactory)
|
||||
{
|
||||
_compileOptions = compileOptions;
|
||||
_compileSession = compileSession;
|
||||
_modelBuilder = modelBuilder;
|
||||
_dumpper = dumpperFactory.Root;
|
||||
}
|
||||
|
||||
public static void Initialize()
|
||||
{
|
||||
var iHost = CompilerHost.CreateHostBuilder().Build();
|
||||
var provider = iHost.Services.GetRequiredService<ICompilerServicesProvider>();
|
||||
CompilerServices.Configure(provider);
|
||||
}
|
||||
public IRModule Module => _module ?? throw new InvalidOperationException("Module has not been imported");
|
||||
|
||||
public IRModule ImportModule(Stream content)
|
||||
public async Task<IRModule> ImportModuleAsync(Stream content)
|
||||
{
|
||||
CompilerServices.CompileOptions = _compileOptions;
|
||||
|
||||
// Console.WriteLine($"Target: {options.Target}");
|
||||
var module = ImportModel(content);
|
||||
DumpModule(module, "ir_import");
|
||||
if (_dumpper.IsEnabled(DumpFlags.Compile))
|
||||
{
|
||||
_dumpper.DumpModule(module, "IRImport");
|
||||
}
|
||||
|
||||
// Console.WriteLine("Infer Shape...");
|
||||
if (_compileOptions.DumpLevel > 4)
|
||||
{
|
||||
DumpManager.RunWithDump("EvaluatorInShapeInfer", () => RunPass(pmg => pmg.Add(new ShapeInferPass()), "ShapeInferAfterImport"));
|
||||
}
|
||||
else
|
||||
{
|
||||
RunPass(pmg => pmg.Add(new ShapeInferPass()), "ShapeInferAfterImport");
|
||||
}
|
||||
await RunPassAsync(pmg => pmg.Add<ShapeInferPass>(), "ShapeInferAfterImport");
|
||||
|
||||
var inferSucc = CompilerServices.InferenceType(module.Entry!);
|
||||
DumpModule(module, "ir_infertype");
|
||||
if (!inferSucc)
|
||||
{
|
||||
throw new InvalidOperationException("InferShape Failed For This Model!");
|
||||
}
|
||||
|
||||
// Console.WriteLine("ImportModule successful!");
|
||||
return module;
|
||||
}
|
||||
|
||||
public void TargetIndependentPass(PassManager passManager)
|
||||
public void TargetIndependentPass(IPassManager passManager)
|
||||
{
|
||||
if (_compileOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
|
||||
var quantMode = _compileSession.CompileOptions.QuantizeOptions.ModelQuantMode;
|
||||
if (quantMode == ModelQuantMode.UsePTQ)
|
||||
{
|
||||
passManager.Add(new EGraphPass("1_NeutralOptimize")
|
||||
passManager.AddWithName<EGraphPass>("NeutralOptimize").Configure(p =>
|
||||
{
|
||||
new Transform.Rules.Neutral.FoldConstCall(),
|
||||
new Transform.Rules.Neutral.FoldNopTranspose(),
|
||||
new Transform.Rules.Neutral.FoldTwoTransposes(),
|
||||
new Transform.Rules.Neutral.CombineTransposeUnary(),
|
||||
new Transform.Rules.Neutral.CombineTransposePad(),
|
||||
new Transform.Rules.Neutral.CombinePadTranspose(),
|
||||
new Transform.Rules.Neutral.CombineTransposeBinary(),
|
||||
new Transform.Rules.Neutral.CombineTransposeConstBinary(),
|
||||
new Transform.Rules.Neutral.CombineTransposeReduce(),
|
||||
new Transform.Rules.Neutral.CombineTransposeActivations(),
|
||||
new Transform.Rules.Neutral.CombineActivationsTranspose(),
|
||||
new Transform.Rules.Neutral.FoldNopPad(),
|
||||
new Transform.Rules.Neutral.FoldConv2DPads(),
|
||||
new Transform.Rules.Neutral.FoldReduceWindow2DPads(),
|
||||
p.Add<Transform.Rules.Neutral.FoldConstCall>();
|
||||
p.Add<Transform.Rules.Neutral.FoldNopTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.FoldTwoTransposes>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeUnary>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposePad>();
|
||||
p.Add<Transform.Rules.Neutral.CombinePadTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeBinary>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeConstBinary>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeReduce>();
|
||||
p.Add<Transform.Rules.Neutral.CombineTransposeActivations>();
|
||||
p.Add<Transform.Rules.Neutral.CombineActivationsTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.FoldNopPad>();
|
||||
p.Add<Transform.Rules.Neutral.FoldConv2DPads>();
|
||||
p.Add<Transform.Rules.Neutral.FoldReduceWindow2DPads>();
|
||||
});
|
||||
}
|
||||
|
||||
if (_compileOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
|
||||
if (quantMode == ModelQuantMode.UsePTQ)
|
||||
{
|
||||
passManager.Add(new DataflowPass("2_AddRangeOfMarker")
|
||||
passManager.AddWithName<DataflowPass>("AddRangeOfMarker").Configure(p =>
|
||||
{
|
||||
new Transform.Rules.Neutral.AddRangeOfAndMarkerToConv2D(),
|
||||
new Transform.Rules.Neutral.AddRangeOfAndMarkerToMatMul(),
|
||||
new Transform.Rules.Neutral.AddRangeOfAndMarkerToReduceWindow2D(),
|
||||
new Transform.Rules.Neutral.AddRangeOfAndMarkerToConv2DTranspose(),
|
||||
new Transform.Rules.Neutral.AddRangeOfAndMarkerToBinary(),
|
||||
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToConv2D>();
|
||||
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToMatMul>();
|
||||
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToReduceWindow2D>();
|
||||
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToConv2DTranspose>();
|
||||
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToBinary>();
|
||||
});
|
||||
passManager.Add(new Quantization.EGraphPassWithQuantize("3_AssignRanges", _compileOptions.QuantizeOptions!));
|
||||
passManager.AddWithName<EGraphPassWithQuantize>("AssignRanges");
|
||||
}
|
||||
}
|
||||
|
||||
public void Compile()
|
||||
public async Task CompileAsync()
|
||||
{
|
||||
var t = CompilerServices.GetTarget(_compileOptions.Target);
|
||||
if (_compileOptions.DumpLevel > 4)
|
||||
{
|
||||
DumpManager.RunWithDump("TargetIndependentEval", () => RunPass(p => TargetIndependentPass(p), "TargetIndependentPass"));
|
||||
}
|
||||
else
|
||||
{
|
||||
RunPass(p => TargetIndependentPass(p), "TargetIndependentPass");
|
||||
}
|
||||
var target = _compileSession.Target;
|
||||
await RunPassAsync(p => TargetIndependentPass(p), "TargetIndependentPass");
|
||||
await RunPassAsync(p => target.RegisterTargetDependentPass(p, _compileSession.CompileOptions), "TargetIndependentPass");
|
||||
|
||||
RunPass(p => t.RegisterTargetDependentPass(p, _compileOptions), "TargetDependentPass");
|
||||
|
||||
// RunPass(p => p.Add(new Quantization.EGraphPassWithBindQuantizeConfig("2.5_BindQuantizeConfig", options.QuantizeOptions!)));
|
||||
if (_compileOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
|
||||
if (_compileSession.CompileOptions.QuantizeOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
|
||||
{
|
||||
RunPass(p => t.RegisterQuantizePass(p, _compileOptions), "QuantizePass");
|
||||
RunPass(p => t.RegisterTargetDependentAfterQuantPass(p, _compileOptions), "TargetDependentAfterQuantPass");
|
||||
RunPass(t => t.Add(new DataflowPass("ClearMarker") { new RemoveMarker() }), "RemoveMarker");
|
||||
await RunPassAsync(p => target.RegisterQuantizePass(p, _compileSession.CompileOptions), "QuantizePass");
|
||||
await RunPassAsync(p => target.RegisterTargetDependentAfterQuantPass(p, _compileSession.CompileOptions), "TargetDependentAfterQuantPass");
|
||||
await RunPassAsync(
|
||||
pmgr => pmgr.Add<DataflowPass>().Configure(p =>
|
||||
{
|
||||
p.Name = "ClearMarker";
|
||||
p.Add<RemoveMarker>();
|
||||
}),
|
||||
"RemoveMarker");
|
||||
}
|
||||
|
||||
// fold constant
|
||||
RunPass(p => p.Add(new Transform.Passes.ShapeInferPass()), "ShapeInferAfterCompile");
|
||||
|
||||
// Console.WriteLine("Compile successful");
|
||||
await RunPassAsync(p => p.Add<ShapeInferPass>(), "ShapeInferAfterCompile");
|
||||
}
|
||||
|
||||
public void Gencode(Stream output)
|
||||
{
|
||||
var target = CompilerServices.GetTarget(_compileOptions.Target);
|
||||
var moduleBuilder = new ModelBuilder(target, _compileOptions);
|
||||
var linkedModel = moduleBuilder.Build(_module);
|
||||
var linkedModel = _modelBuilder.Build(Module);
|
||||
linkedModel.Serialize(output);
|
||||
|
||||
// Console.WriteLine("Gencode successful");
|
||||
}
|
||||
|
||||
private IRModule ImportModel(Stream content)
|
||||
{
|
||||
_module = _compileOptions.InputFormat switch
|
||||
_module = _compileSession.CompileOptions.InputFormat switch
|
||||
{
|
||||
"tflite" => Importers.ImportTFLite(content, _compileOptions),
|
||||
"onnx" => Importers.ImportOnnx(content, _compileOptions),
|
||||
_ => throw new NotImplementedException($"Not Implement {_compileOptions.InputFormat} Impoter!"),
|
||||
"tflite" => Importers.ImportTFLite(content, _compileSession),
|
||||
"onnx" => Importers.ImportOnnx(content, _compileSession),
|
||||
var inputFormat => throw new NotImplementedException($"Not Implement {inputFormat} Importer!"),
|
||||
};
|
||||
return _module;
|
||||
}
|
||||
|
||||
private void DumpModule(IRModule module, string prefix)
|
||||
private async Task RunPassAsync(Action<IPassManager> register, string name)
|
||||
{
|
||||
var dumpPath = Path.Combine(_compileOptions.DumpDir, "dump", prefix);
|
||||
CompilerServices.DumpIR(module.Entry!, prefix, dumpPath);
|
||||
}
|
||||
|
||||
private void RunPass(Action<PassManager> register, string dirName)
|
||||
{
|
||||
var pmgr = new PassManager(_module, new RunPassOptions(CompilerServices.GetTarget(_compileOptions.Target), _compileOptions.DumpLevel, Path.Join(_compileOptions.DumpDir, dirName), _compileOptions));
|
||||
var pmgr = _compileSession.CreatePassManager(name);
|
||||
register(pmgr);
|
||||
pmgr.RunAsync().Wait();
|
||||
_module = await pmgr.RunAsync(Module).ConfigureAwait(false);
|
||||
|
||||
if (_dumpper.IsEnabled(DumpFlags.Compile))
|
||||
{
|
||||
_dumpper.DumpModule(_module, name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
|
||||
namespace Nncase.Compiler.Hosting;
|
||||
|
||||
/// <summary>
|
||||
/// Compiler builder.
|
||||
/// </summary>
|
||||
public interface ICompilerBuilder
|
||||
{
|
||||
/// <summary>
|
||||
/// Configure modules.
|
||||
/// </summary>
|
||||
/// <param name="configureModules">Configure modules action.</param>
|
||||
/// <returns>Compiler builder.</returns>
|
||||
ICompilerBuilder ConfigureModules(Action<IRegistrator> configureModules);
|
||||
}
|
||||
|
||||
internal sealed class CompilerBuilder : ICompilerBuilder
|
||||
{
|
||||
private readonly IContainer _container;
|
||||
|
||||
public CompilerBuilder(IContainer container)
|
||||
{
|
||||
_container = container;
|
||||
}
|
||||
|
||||
public ICompilerBuilder ConfigureModules(Action<IRegistrator> configureModules)
|
||||
{
|
||||
configureModules(_container);
|
||||
return this;
|
||||
}
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Autofac;
|
||||
using Autofac.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase.Hosting;
|
||||
|
||||
/// <summary>
|
||||
/// Compiler host helper.
|
||||
/// </summary>
|
||||
public static class CompilerHost
|
||||
{
|
||||
/// <summary>
|
||||
/// Create compiler host builder.
|
||||
/// </summary>
|
||||
/// <param name="args">Commandline arguments.</param>
|
||||
/// <returns>Created host builder.</returns>
|
||||
public static IHostBuilder CreateHostBuilder(string[]? args = null)
|
||||
{
|
||||
var host = Host.CreateDefaultBuilder(args);
|
||||
host.ConfigureAppConfiguration(ConfigureAppConfiguration)
|
||||
.UseServiceProviderFactory(new AutofacServiceProviderFactory())
|
||||
.ConfigureContainer<ContainerBuilder>(ConfigureContainer)
|
||||
.ConfigureServices(ConfigureServices)
|
||||
.ConfigureLogging(ConfigureLogging);
|
||||
return host;
|
||||
}
|
||||
|
||||
private static void ConfigureContainer(ContainerBuilder builder)
|
||||
{
|
||||
var assemblies = ApplicationParts.LoadApplicationParts(c =>
|
||||
{
|
||||
c.AddCore()
|
||||
.AddEvaluator()
|
||||
.AddGraph()
|
||||
.AddEGraph()
|
||||
.AddStackVM();
|
||||
});
|
||||
builder.RegisterAssemblyModules(assemblies);
|
||||
}
|
||||
|
||||
private static void ConfigureServices(HostBuilderContext context, IServiceCollection services)
|
||||
{
|
||||
services.AddLogging();
|
||||
}
|
||||
|
||||
private static void ConfigureAppConfiguration(HostBuilderContext context, IConfigurationBuilder builder)
|
||||
{
|
||||
builder.SetBasePath(Directory.GetCurrentDirectory())
|
||||
.AddJsonFile("config.json", true, false);
|
||||
}
|
||||
|
||||
private static void ConfigureLogging(ILoggingBuilder loggingBuilder)
|
||||
{
|
||||
loggingBuilder.ClearProviders();
|
||||
loggingBuilder.AddConsole();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using DryIoc.Microsoft.DependencyInjection;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Nncase;
|
||||
using Nncase.Compiler;
|
||||
using Nncase.Compiler.Hosting;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Microsoft.Extensions.Hosting;
|
||||
|
||||
/// <summary>
|
||||
/// Compiler host builder extentensions.
|
||||
/// </summary>
|
||||
public static class CompilerHostBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Configure compiler host builder.
|
||||
/// </summary>
|
||||
/// <param name="hostBuilder">Host builder.</param>
|
||||
/// <param name="configureCompiler">Configure compiler builder.</param>
|
||||
/// <returns>Created host builder.</returns>
|
||||
public static IHostBuilder ConfigureCompiler(this IHostBuilder hostBuilder, Action<ICompilerBuilder>? configureCompiler = null)
|
||||
{
|
||||
hostBuilder.UseServiceProviderFactory(new DryIocServiceProviderFactory())
|
||||
.ConfigureContainer<Container>(ConfigureBuiltinModules)
|
||||
.ConfigureServices(ConfigureServices)
|
||||
.ConfigureLogging(ConfigureLogging);
|
||||
|
||||
hostBuilder.ConfigureContainer<Container>(x => configureCompiler?.Invoke(new CompilerBuilder(x)));
|
||||
hostBuilder.ConfigureContainer<Container>(ConfigurePlugins);
|
||||
return hostBuilder;
|
||||
}
|
||||
|
||||
private static void ConfigureBuiltinModules(Container builder)
|
||||
{
|
||||
builder.AddCore()
|
||||
.AddDiagnostics()
|
||||
.AddEvaluator()
|
||||
.AddGraph()
|
||||
.AddEGraph()
|
||||
.AddCodeGen()
|
||||
.AddStackVM()
|
||||
.AddK210();
|
||||
}
|
||||
|
||||
private static void ConfigureServices(HostBuilderContext context, IServiceCollection services)
|
||||
{
|
||||
services.AddLogging();
|
||||
|
||||
services.AddSingleton<PluginLoader>();
|
||||
services.AddScoped<ICompiler, Compiler>();
|
||||
}
|
||||
|
||||
private static void ConfigureLogging(ILoggingBuilder loggingBuilder)
|
||||
{
|
||||
loggingBuilder.ClearProviders();
|
||||
loggingBuilder.AddConsole();
|
||||
}
|
||||
|
||||
private static void ConfigurePlugins(Container builder)
|
||||
{
|
||||
var pluginLoader = builder.Resolve<PluginLoader>();
|
||||
var plugins = pluginLoader.LoadPlugins();
|
||||
foreach (var plugin in plugins)
|
||||
{
|
||||
plugin.ConfigureServices(builder);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Reflection.Metadata;
|
||||
using System.Reflection.PortableExecutable;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.Loader;
|
||||
using LanguageExt;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Nncase.Compiler;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Hosting;
|
||||
|
||||
/// <summary>
|
||||
/// Plugin loader helper.
|
||||
/// </summary>
|
||||
public sealed class PluginLoader
|
||||
{
|
||||
private const string _modulesDllPattern = "Nncase.Modules.*.dll";
|
||||
private const string _pluginPathEnvName = "NNCASE_PLUGIN_PATH";
|
||||
|
||||
private static readonly string[] _builtinModules = new[]
|
||||
{
|
||||
"Nncase.Modules.StackVM.dll",
|
||||
"Nncase.Modules.K210.dll",
|
||||
};
|
||||
|
||||
private readonly ILogger<PluginLoader> _logger;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PluginLoader"/> class.
|
||||
/// </summary>
|
||||
/// <param name="logger">Logger.</param>
|
||||
public PluginLoader(ILogger<PluginLoader> logger)
|
||||
{
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load plugins.
|
||||
/// </summary>
|
||||
/// <returns>Plugins.</returns>
|
||||
public IReadOnlyList<IPlugin> LoadPlugins()
|
||||
{
|
||||
var pluginAsms = GetPluginsSearchDirectories().Select(GetPluginAssemblies).SelectMany(x => x)
|
||||
.DistinctBy(Path.GetFileName).Select(Assembly.LoadFrom).Distinct().ToList();
|
||||
var plugins = (from asm in pluginAsms
|
||||
from t in asm.ExportedTypes
|
||||
where t.IsClass
|
||||
&& t.IsAssignableTo(typeof(IPlugin))
|
||||
let ctor = t.GetConstructor(Type.EmptyTypes)
|
||||
where ctor != null
|
||||
select (IPlugin)ctor.Invoke(null)).ToList();
|
||||
return plugins;
|
||||
}
|
||||
|
||||
private static bool IsLoadableAssembly(string filePath)
|
||||
{
|
||||
using var fs = File.OpenRead(filePath);
|
||||
using var peReader = new PEReader(fs);
|
||||
|
||||
if (!peReader.HasMetadata)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var metaReader = peReader.GetMetadataReader();
|
||||
if (!metaReader.IsAssembly)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Is reference assembly
|
||||
if ((from cah in metaReader.CustomAttributes
|
||||
let ca = metaReader.GetCustomAttribute(cah)
|
||||
let ctor = metaReader.GetMemberReference((MemberReferenceHandle)ca.Constructor)
|
||||
let attrType = metaReader.GetTypeReference((TypeReferenceHandle)ctor.Parent)
|
||||
where metaReader.GetString(attrType.Namespace) == nameof(System.Runtime.CompilerServices)
|
||||
&& metaReader.GetString(attrType.Name) == nameof(ReferenceAssemblyAttribute)
|
||||
select cah).Any())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private IEnumerable<string> GetPluginAssemblies(string basePath)
|
||||
{
|
||||
return (from filePath in Directory.GetFiles(basePath, _modulesDllPattern, SearchOption.AllDirectories)
|
||||
where !_builtinModules.Contains(Path.GetFileName(filePath))
|
||||
&& IsLoadableAssembly(filePath)
|
||||
select filePath).Distinct();
|
||||
}
|
||||
|
||||
private IEnumerable<string> GetPluginsSearchDirectories()
|
||||
{
|
||||
var directories = new List<string>();
|
||||
|
||||
// 1. Environment variable
|
||||
var targetPathEnv = Environment.GetEnvironmentVariable(_pluginPathEnvName);
|
||||
if (string.IsNullOrWhiteSpace(targetPathEnv))
|
||||
{
|
||||
_logger.LogWarning($"{_pluginPathEnvName} is not set.");
|
||||
}
|
||||
else
|
||||
{
|
||||
var targetPaths = from path in targetPathEnv.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
|
||||
select Environment.ExpandEnvironmentVariables(path);
|
||||
directories.AddRange(targetPaths);
|
||||
}
|
||||
|
||||
if (_logger.IsEnabled(LogLevel.Trace))
|
||||
{
|
||||
_logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}.");
|
||||
}
|
||||
|
||||
return directories.Distinct();
|
||||
}
|
||||
}
|
|
@ -10,11 +10,14 @@ using System.Runtime.CompilerServices;
|
|||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.IR;
|
||||
using Nncase.Quantization;
|
||||
using Nncase.Runtime;
|
||||
using Nncase.Runtime.Interop;
|
||||
using static Nncase.Compiler.PythonHelper;
|
||||
|
||||
namespace Nncase.Compiler.Interop;
|
||||
|
||||
|
@ -41,14 +44,12 @@ public unsafe struct CApiMT
|
|||
public delegate* unmanaged<IntPtr> CompileOptionsCreatePtr;
|
||||
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetInputFilePtr;
|
||||
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetInputFormatPtr;
|
||||
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetTargetPtr;
|
||||
public delegate* unmanaged<IntPtr, int, void> CompileOptionsSetDumpLevelPtr;
|
||||
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetDumpDirPtr;
|
||||
public delegate* unmanaged<IntPtr, DumpFlags, void> CompileOptionsSetDumpFlagsPtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr, void> CompileOptionsSetQuantizeOptionsPtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr, void> CompileOptionsSetQuantTypePtr;
|
||||
public delegate* unmanaged<IntPtr, ModelQuantMode, void> CompileOptionsSetModelQuantModePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr, IntPtr> CompileSessionCreatePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr> CompileSessionGetCompilerPtr;
|
||||
public delegate* unmanaged<void> CompilerInitializePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr> CompilerCreatePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr, IntPtr> CompilerImportModulePtr;
|
||||
public delegate* unmanaged<IntPtr, void> CompilerCompilePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr, void> CompilerGencodePtr;
|
||||
|
@ -61,9 +62,12 @@ public unsafe struct CApiMT
|
|||
public delegate* unmanaged<IntPtr> QuantizeOptionsCreatePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr, void> QuantizeOptionsSetCalibrationDatasetPtr;
|
||||
public delegate* unmanaged<IntPtr, CalibMethod, void> QuantizeOptionsSetCalibrationMethodPtr;
|
||||
public delegate* unmanaged<IntPtr, ModelQuantMode, void> QuantizeOptionsSetModelQuantModePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr, void> QuantizeOptionsSetQuantTypePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr> RTValueFromHandlePtr;
|
||||
public delegate* unmanaged<IntPtr, IntPtr> RTValueGetHandlePtr;
|
||||
public delegate* unmanaged<CStreamMT*, IntPtr, IntPtr> StreamCreatePtr;
|
||||
public delegate* unmanaged<byte*, nuint, IntPtr> TargetCreatePtr;
|
||||
public delegate* unmanaged<byte*, nuint, byte> TargetExistsPtr;
|
||||
}
|
||||
|
||||
|
@ -83,14 +87,12 @@ public static unsafe class CApi
|
|||
mt->CompileOptionsCreatePtr = &CompileOptionsCreate;
|
||||
mt->CompileOptionsSetInputFilePtr = &CompileOptionsSetInputFile;
|
||||
mt->CompileOptionsSetInputFormatPtr = &CompileOptionsSetInputFormat;
|
||||
mt->CompileOptionsSetTargetPtr = &CompileOptionsSetTarget;
|
||||
mt->CompileOptionsSetDumpLevelPtr = &CompileOptionsSetDumpLevel;
|
||||
mt->CompileOptionsSetDumpDirPtr = &CompileOptionsSetDumpDir;
|
||||
mt->CompileOptionsSetDumpFlagsPtr = &CompileOptionsSetDumpFlags;
|
||||
mt->CompileOptionsSetQuantizeOptionsPtr = &CompileOptionsSetQuantizeOptions;
|
||||
mt->CompileOptionsSetQuantTypePtr = &CompileOptionsSetQuantType;
|
||||
mt->CompileOptionsSetModelQuantModePtr = &CompileOptionsSetModelQuantMode;
|
||||
mt->CompileSessionCreatePtr = &CompileSessionCreate;
|
||||
mt->CompileSessionGetCompilerPtr = &CompileSessionGetCompiler;
|
||||
mt->CompilerInitializePtr = &CompilerInitialize;
|
||||
mt->CompilerCreatePtr = &CompilerCreate;
|
||||
mt->CompilerImportModulePtr = &CompilerImportModule;
|
||||
mt->CompilerCompilePtr = &CompilerCompile;
|
||||
mt->CompilerGencodePtr = &CompilerGencode;
|
||||
|
@ -103,9 +105,12 @@ public static unsafe class CApi
|
|||
mt->QuantizeOptionsCreatePtr = &QuantizeOptionsCreate;
|
||||
mt->QuantizeOptionsSetCalibrationDatasetPtr = &QuantizeOptionsSetCalibrationDataset;
|
||||
mt->QuantizeOptionsSetCalibrationMethodPtr = &QuantizeOptionsSetCalibrationMethod;
|
||||
mt->QuantizeOptionsSetModelQuantModePtr = &QuantizeOptionsSetModelQuantMode;
|
||||
mt->QuantizeOptionsSetQuantTypePtr = &QuantizeOptionsSetQuantType;
|
||||
mt->RTValueFromHandlePtr = &RTValueFromHandle;
|
||||
mt->RTValueGetHandlePtr = &RTValueGetHandle;
|
||||
mt->StreamCreatePtr = &StreamCreate;
|
||||
mt->TargetCreatePtr = &TargetCreate;
|
||||
mt->TargetExistsPtr = &TargetExists;
|
||||
}
|
||||
|
||||
|
@ -191,15 +196,9 @@ public static unsafe class CApi
|
|||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static void CompileOptionsSetTarget(IntPtr compileOptionsHandle, byte* targetPtr, nuint targetLength)
|
||||
private static void CompileOptionsSetDumpFlags(IntPtr compileOptionsHandle, DumpFlags dumpFlags)
|
||||
{
|
||||
Get<CompileOptions>(compileOptionsHandle).Target = ToString(targetPtr, targetLength);
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static void CompileOptionsSetDumpLevel(IntPtr compileOptionsHandle, int dumpLevel)
|
||||
{
|
||||
Get<CompileOptions>(compileOptionsHandle).DumpLevel = dumpLevel;
|
||||
Get<CompileOptions>(compileOptionsHandle).DumpFlags = dumpFlags;
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
|
@ -215,27 +214,27 @@ public static unsafe class CApi
|
|||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static void CompileOptionsSetQuantType(IntPtr compileOptionsHandle, IntPtr quantTypeHandle)
|
||||
private static IntPtr CompileSessionCreate(IntPtr targetHandle, IntPtr compileOptionsHandle)
|
||||
{
|
||||
Get<CompileOptions>(compileOptionsHandle).QuantType = Get<DataType>(quantTypeHandle);
|
||||
var target = Get<ITarget>(targetHandle);
|
||||
var compileOptions = Get<CompileOptions>(compileOptionsHandle);
|
||||
return GCHandle.ToIntPtr(GCHandle.Alloc(CompileSession.Create(target, compileOptions)));
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static void CompileOptionsSetModelQuantMode(IntPtr compileOptionsHandle, ModelQuantMode quantMode)
|
||||
private static IntPtr CompileSessionGetCompiler(IntPtr compileSessionHandle)
|
||||
{
|
||||
Get<CompileOptions>(compileOptionsHandle).ModelQuantMode = quantMode;
|
||||
var compileSession = Get<CompileSession>(compileSessionHandle);
|
||||
return GCHandle.ToIntPtr(GCHandle.Alloc(compileSession.Compiler));
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static void CompilerInitialize()
|
||||
{
|
||||
Compiler.Initialize();
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static IntPtr CompilerCreate(IntPtr compileOptionsHandle)
|
||||
{
|
||||
return GCHandle.ToIntPtr(GCHandle.Alloc(new Compiler(Get<CompileOptions>(compileOptionsHandle))));
|
||||
var host = Host.CreateDefaultBuilder()
|
||||
.ConfigureCompiler()
|
||||
.Build();
|
||||
CompilerServices.Configure(host.Services);
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
|
@ -243,7 +242,7 @@ public static unsafe class CApi
|
|||
{
|
||||
var compiler = Get<Compiler>(compilerHandle);
|
||||
var stream = Get<CStream>(streamHandle);
|
||||
var module = compiler.ImportModule(stream);
|
||||
var module = compiler.ImportModuleAsync(stream).Result;
|
||||
return GCHandle.ToIntPtr(GCHandle.Alloc(module));
|
||||
}
|
||||
|
||||
|
@ -251,7 +250,7 @@ public static unsafe class CApi
|
|||
private static void CompilerCompile(IntPtr compilerHandle)
|
||||
{
|
||||
var compiler = Get<Compiler>(compilerHandle);
|
||||
compiler.Compile();
|
||||
compiler.CompileAsync().Wait();
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
|
@ -330,10 +329,22 @@ public static unsafe class CApi
|
|||
Get<QuantizeOptions>(quantizeOptionsHandle).CalibrationMethod = calibMethod;
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static void QuantizeOptionsSetModelQuantMode(IntPtr quantizeOptionsHandle, ModelQuantMode quantMode)
|
||||
{
|
||||
Get<QuantizeOptions>(quantizeOptionsHandle).ModelQuantMode = quantMode;
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static void QuantizeOptionsSetQuantType(IntPtr quantizeOptionsHandle, IntPtr quantTypeHandle)
|
||||
{
|
||||
Get<QuantizeOptions>(quantizeOptionsHandle).QuantType = Get<DataType>(quantTypeHandle);
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static IntPtr RTValueFromHandle(IntPtr handle)
|
||||
{
|
||||
var rtValue = RTValue.FromHandle(handle);
|
||||
var rtValue = RTValue.FromHandle(handle, true);
|
||||
return GCHandle.ToIntPtr(GCHandle.Alloc(rtValue));
|
||||
}
|
||||
|
||||
|
@ -350,6 +361,13 @@ public static unsafe class CApi
|
|||
return GCHandle.ToIntPtr(GCHandle.Alloc(new CStream(mt, handle)));
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static IntPtr TargetCreate(byte* targetNamePtr, nuint targetNameLength)
|
||||
{
|
||||
var targetName = ToString(targetNamePtr, targetNameLength);
|
||||
return GCHandle.ToIntPtr(GCHandle.Alloc(CompilerServices.GetTarget(targetName)));
|
||||
}
|
||||
|
||||
[UnmanagedCallersOnly]
|
||||
private static byte TargetExists(byte* targetNamePtr, nuint targetNameLength)
|
||||
{
|
||||
|
|
|
@ -9,13 +9,16 @@
|
|||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Autofac.Extensions.DependencyInjection" />
|
||||
<PackageReference Include="DryIoc.dll" />
|
||||
<PackageReference Include="DryIoc.Microsoft.DependencyInjection" />
|
||||
<PackageReference Include="Microsoft.Extensions.Hosting" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\modules\Nncase.Modules.K210\Nncase.Modules.K210.csproj" />
|
||||
<ProjectReference Include="..\Nncase.CodeGen\Nncase.CodeGen.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Core\Nncase.Core.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Diagnostics\Nncase.Diagnostics.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Graph\Nncase.Graph.csproj" />
|
||||
<ProjectReference Include="..\Nncase.EGraph\Nncase.EGraph.csproj" />
|
||||
<ProjectReference Include="..\Nncase.Evaluator\Nncase.Evaluator.csproj" />
|
||||
|
|
|
@ -1,130 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System.Diagnostics;
|
||||
using System.Runtime.InteropServices;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.Evaluator;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.Math;
|
||||
using Nncase.Quantization;
|
||||
using Nncase.Runtime.Interop;
|
||||
using Nncase.Utilities;
|
||||
|
||||
namespace Nncase.Compiler;
|
||||
|
||||
public static class PythonHelper
|
||||
{
|
||||
public static void LaunchDebugger()
|
||||
{
|
||||
Console.WriteLine(System.Environment.Version.ToString());
|
||||
Debugger.Launch();
|
||||
}
|
||||
|
||||
public static IValue TensorValueFromBytes(DataType type, byte[] span, int[] dimensions)
|
||||
{
|
||||
return Value.FromTensor(Tensor.FromBytes(type, span, dimensions));
|
||||
}
|
||||
|
||||
public static Tensor TensorFromBytes(DataType type, byte[] span, int[] dimensions)
|
||||
{
|
||||
return Tensor.FromBytes(type, span, dimensions);
|
||||
}
|
||||
|
||||
public static byte[] BytesBufferFromTensor(Tensor value)
|
||||
{
|
||||
return value.BytesBuffer.ToArray();
|
||||
}
|
||||
|
||||
public static Memory<byte> ToMemory(byte[] bytes) => new(bytes);
|
||||
|
||||
public static byte[] GetRTTensorBytes(RTTensor tensor)
|
||||
{
|
||||
var buffer = tensor.Buffer.Buffer.AsHost()!;
|
||||
using (var mmOwner = buffer.Map(RTMapAccess.Read))
|
||||
{
|
||||
return mmOwner.Memory.Span.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
public static uint[] GetRTTensorDims(RTTensor tensor)
|
||||
{
|
||||
return tensor.Dimensions.ToArray();
|
||||
}
|
||||
|
||||
public static IValue Evaluate(Expr expr, IReadOnlyDictionary<Var, IValue>? varsValues = null)
|
||||
{
|
||||
if (CompilerServices.CompileOptions.DumpLevel > 4)
|
||||
{
|
||||
return DumpManager.RunWithDump("Evaluator", () => CompilerServices.Evaluate(expr, varsValues));
|
||||
}
|
||||
else
|
||||
{
|
||||
return CompilerServices.Evaluate(expr, varsValues);
|
||||
}
|
||||
}
|
||||
|
||||
public static RTTensor[] RunSimulator(RTInterpreter interp, RTValue[] input)
|
||||
{
|
||||
interp.SetDumpRoot(CompilerServices.CompileOptions.DumpDir);
|
||||
var entry = interp.Entry;
|
||||
var result = entry.Invoke(input);
|
||||
if (result is RTTensor tensor)
|
||||
{
|
||||
return new[] { tensor };
|
||||
}
|
||||
else if (result is RTTuple tuple)
|
||||
{
|
||||
// todo: field maybe a tuple, but not process in this
|
||||
return tuple.Fields.Select(x => (RTTensor)x).ToArray();
|
||||
}
|
||||
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public static bool TargetExist(string target)
|
||||
{
|
||||
try
|
||||
{
|
||||
CompilerServices.GetTarget(target);
|
||||
return true;
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Console.WriteLine(e);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Tensor[sample_count * input_count] dataSet
|
||||
public static PytestCalibrationDatasetProvider MakeDatasetProvider(Tensor[] dataSet, int sampleCount, Var[] fnParams)
|
||||
{
|
||||
var inputCount = dataSet[0].Length / sampleCount;
|
||||
|
||||
var samples = dataSet.Chunk(inputCount).Select(inputs => inputs.Zip(fnParams).ToDictionary(
|
||||
item => item.Item2,
|
||||
item => (IValue)Value.FromTensor(item.Item1))).ToArray().ToAsyncEnumerable();
|
||||
return new PytestCalibrationDatasetProvider(samples, sampleCount);
|
||||
}
|
||||
|
||||
public static QuantizeOptions MakeQuantizeOptions(ICalibrationDatasetProvider datasetProvider)
|
||||
{
|
||||
return new QuantizeOptions
|
||||
{ BindQuantMethod = false, CalibrationDataset = datasetProvider, CalibrationMethod = CalibMethod.NoClip };
|
||||
}
|
||||
|
||||
public class PytestCalibrationDatasetProvider : ICalibrationDatasetProvider
|
||||
{
|
||||
private readonly int _sampleCount;
|
||||
|
||||
public PytestCalibrationDatasetProvider(IAsyncEnumerable<IReadOnlyDictionary<Var, IValue>> samples, int sampleCount)
|
||||
{
|
||||
Samples = samples;
|
||||
_sampleCount = sampleCount;
|
||||
}
|
||||
|
||||
public int? Count => _sampleCount;
|
||||
|
||||
public IAsyncEnumerable<IReadOnlyDictionary<Var, IValue>> Samples { get; }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.CodeGen;
|
||||
|
||||
/// <summary>
|
||||
/// Linked model.
|
||||
/// </summary>
|
||||
public interface ILinkedModel
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets entry function id.
|
||||
/// </summary>
|
||||
FunctionId? Entry { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets linked modules.
|
||||
/// </summary>
|
||||
IReadOnlyList<ILinkedModule> Modules { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Serialize model to stream.
|
||||
/// </summary>
|
||||
/// <param name="output">Stream to be written.</param>
|
||||
void Serialize(Stream output);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Model builder.
|
||||
/// </summary>
|
||||
public interface IModelBuilder
|
||||
{
|
||||
/// <summary>
|
||||
/// Build linked model.
|
||||
/// </summary>
|
||||
/// <param name="module">Module.</param>
|
||||
/// <returns>Linked model.</returns>
|
||||
ILinkedModel Build(IRModule module);
|
||||
}
|
|
@ -2,146 +2,38 @@
|
|||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Quantization;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
/// <summary>
|
||||
/// CompileOptions.
|
||||
/// Compile options.
|
||||
/// </summary>
|
||||
public sealed class CompileOptions
|
||||
public sealed record CompileOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="CompileOptions"/> class.
|
||||
/// copy ctor.
|
||||
/// Gets or sets input file.
|
||||
/// </summary>
|
||||
public CompileOptions(CompileOptions other)
|
||||
{
|
||||
InputFile = other.InputFile;
|
||||
InputFormat = other.InputFormat;
|
||||
Target = other.Target;
|
||||
DumpLevel = other.DumpLevel;
|
||||
DumpDir = other.DumpDir;
|
||||
ModelQuantMode = other.ModelQuantMode;
|
||||
QuantType = other.QuantType;
|
||||
WQuantType = other.WQuantType;
|
||||
OutputFile = other.OutputFile;
|
||||
}
|
||||
public string InputFile { get; set; } = "<stream>";
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="CompileOptions"/> class.
|
||||
/// CompileOptions.
|
||||
/// Gets or sets the import model format.
|
||||
/// </summary>
|
||||
public CompileOptions()
|
||||
{
|
||||
InputFile = string.Empty;
|
||||
InputFormat = string.Empty;
|
||||
Target = string.Empty;
|
||||
DumpLevel = -1;
|
||||
DumpDir = string.Empty;
|
||||
ModelQuantMode = ModelQuantMode.NoQuant;
|
||||
QuantType = DataTypes.UInt8;
|
||||
WQuantType = DataTypes.UInt8;
|
||||
OutputFile = string.Empty;
|
||||
}
|
||||
public string InputFormat { get; set; } = "onnx";
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="CompileOptions"/> class.
|
||||
/// init.
|
||||
/// Gets or sets the dump flags.
|
||||
/// </summary>
|
||||
/// <param name="modelQuantMode"></param>
|
||||
public CompileOptions(ModelQuantMode modelQuantMode)
|
||||
{
|
||||
InputFile = string.Empty;
|
||||
InputFormat = string.Empty;
|
||||
Target = string.Empty;
|
||||
DumpLevel = -1;
|
||||
DumpDir = string.Empty;
|
||||
ModelQuantMode = modelQuantMode;
|
||||
QuantType = DataTypes.UInt8;
|
||||
WQuantType = DataTypes.UInt8;
|
||||
OutputFile = string.Empty;
|
||||
QuantizeOptions = QuantizeOptions;
|
||||
}
|
||||
public DumpFlags DumpFlags { get; set; } = DumpFlags.None;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public string InputFile { get; set; }
|
||||
/// <summary>
|
||||
/// Gets or sets the dump directory.
|
||||
/// </summary>
|
||||
public string DumpDir { get; set; } = string.Empty;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public string InputFormat { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public string Target { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public int DumpLevel { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public string DumpDir { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public DataType QuantType { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public DataType WQuantType { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public string OutputFile { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public ModelQuantMode ModelQuantMode { get; set; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public QuantizeOptions? QuantizeOptions { get; set; }
|
||||
/// <summary>
|
||||
/// Gets or sets quant options.
|
||||
/// </summary>
|
||||
public QuantizeOptions QuantizeOptions { get; set; } = QuantizeOptions.CreateNoQuant();
|
||||
}
|
||||
|
||||
// /// <summary>
|
||||
// /// Options of compile command.
|
||||
// /// </summary>
|
||||
// public interface CompileOptions
|
||||
// {
|
||||
// /// <summary>
|
||||
// /// Gets or sets input file.
|
||||
// /// </summary>
|
||||
// public string InputFile { get; set; }
|
||||
|
||||
// /// <summary>
|
||||
// /// Gets or sets output file.
|
||||
// /// </summary>
|
||||
// public string OutputFile { get; set; }
|
||||
|
||||
// /// <summary>
|
||||
// /// Gets or sets the import model format.
|
||||
// /// </summary>
|
||||
// public string InputFormat { get; set; }
|
||||
|
||||
// /// <summary>
|
||||
// /// Gets or sets target.
|
||||
// /// </summary>
|
||||
// public string Target { get; set; }
|
||||
|
||||
// /// <summary>
|
||||
// /// Gets or sets the dump level.
|
||||
// /// </summary>
|
||||
// public int DumpLevel { get; set; }
|
||||
|
||||
// /// <summary>
|
||||
// /// Gets or sets the dump directory.
|
||||
// /// </summary>
|
||||
// public string DumpDir { get; set; }
|
||||
|
||||
// /// <summary>
|
||||
// /// weather use ptq
|
||||
// /// </summary>
|
||||
// public bool UsePTQ { get; set; }
|
||||
|
||||
// /// <summary>
|
||||
// /// Gets or sets quant type
|
||||
// /// </summary>
|
||||
// public DataType QuantType { get; set; }
|
||||
|
||||
// /// <summary>
|
||||
// /// Gets or sets quant mode
|
||||
// /// </summary>
|
||||
// public QuantMode QuantMode { get; set; }
|
||||
// }
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Nncase.CodeGen;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.Transform;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
/// <summary>
|
||||
/// Compile session.
|
||||
/// </summary>
|
||||
public sealed class CompileSession : IServiceProvider, IDisposable
|
||||
{
|
||||
private readonly IResolverContext _serviceProvider;
|
||||
|
||||
private bool _disposedValue;
|
||||
private ICompiler? _compiler;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="CompileSession"/> class.
|
||||
/// </summary>
|
||||
/// <param name="serviceProvider">Service provider.</param>
|
||||
/// <param name="target">Target.</param>
|
||||
/// <param name="compileOptions">Compile options.</param>
|
||||
internal CompileSession(IResolverContext serviceProvider, ITarget target, CompileOptions compileOptions)
|
||||
{
|
||||
_serviceProvider = serviceProvider;
|
||||
Target = target;
|
||||
CompileOptions = compileOptions;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets target.
|
||||
/// </summary>
|
||||
public ITarget Target { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets compile options.
|
||||
/// </summary>
|
||||
public CompileOptions CompileOptions { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets compiler.
|
||||
/// </summary>
|
||||
public ICompiler Compiler => _compiler ??= this.GetRequiredService<ICompiler>();
|
||||
|
||||
/// <summary>
|
||||
/// Create new compile session.
|
||||
/// </summary>
|
||||
/// <param name="target">Compile target.</param>
|
||||
/// <param name="compileOptions">Compile options.</param>
|
||||
/// <returns>Created compile session.</returns>
|
||||
public static CompileSession Create(ITarget target, CompileOptions compileOptions)
|
||||
{
|
||||
var childContainer = CompilerServices.CreateScope();
|
||||
childContainer.RegisterInstance(target);
|
||||
childContainer.RegisterInstance(compileOptions);
|
||||
|
||||
var session = new CompileSession(childContainer, target, compileOptions);
|
||||
childContainer.RegisterInstance(session);
|
||||
return session;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public object? GetService(Type serviceType) => _serviceProvider.GetService(serviceType);
|
||||
|
||||
/// <summary>
|
||||
/// Create new pass manager.
|
||||
/// </summary>
|
||||
/// <param name="name">Name.</param>
|
||||
/// <returns>Created pass manager.</returns>
|
||||
public IPassManager CreatePassManager(string name)
|
||||
{
|
||||
return new PassManager(name, this);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void Dispose()
|
||||
{
|
||||
Dispose(disposing: true);
|
||||
}
|
||||
|
||||
private void Dispose(bool disposing)
|
||||
{
|
||||
if (!_disposedValue)
|
||||
{
|
||||
if (disposing)
|
||||
{
|
||||
_serviceProvider.Dispose();
|
||||
}
|
||||
|
||||
_disposedValue = true;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
internal struct CompileSessionScope : IDisposable
|
||||
{
|
||||
private static readonly AsyncLocal<CompileSession?> _compileSession = new AsyncLocal<CompileSession?>();
|
||||
|
||||
private readonly bool _initialized;
|
||||
private readonly CompileSession? _originalCompileSession;
|
||||
|
||||
public CompileSessionScope(CompileSession compileSession)
|
||||
{
|
||||
_initialized = true;
|
||||
_originalCompileSession = _compileSession.Value;
|
||||
_compileSession.Value = compileSession;
|
||||
}
|
||||
|
||||
public static CompileSession? Current => _compileSession.Value;
|
||||
|
||||
public static CompileSession GetCurrentThrowIfNull() => Current ?? throw new InvalidOperationException($"Current {nameof(CompileSession)} is not set");
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
if (_initialized)
|
||||
{
|
||||
_compileSession.Value = _originalCompileSession;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,12 +3,14 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.ComponentModel;
|
||||
using System.Data;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Nncase.CostModel;
|
||||
using Nncase.Evaluator;
|
||||
using Nncase.IR;
|
||||
|
@ -23,12 +25,6 @@ namespace Nncase;
|
|||
/// </summary>
|
||||
public interface ICompilerServicesProvider
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets get CompileOptions.
|
||||
/// </summary>
|
||||
/// <returns>CompileOptions.</returns>
|
||||
CompileOptions CompileOptions { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Inference type of the expression tree.
|
||||
/// </summary>
|
||||
|
@ -151,7 +147,7 @@ public interface ICompilerServicesProvider
|
|||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options);
|
||||
Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
|
||||
|
||||
/// <summary>
|
||||
/// Match enodes as root.
|
||||
|
@ -185,7 +181,7 @@ public interface ICompilerServicesProvider
|
|||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options);
|
||||
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
|
||||
}
|
||||
|
||||
internal interface ICompilerServicesProviderInternal
|
||||
|
@ -198,21 +194,13 @@ internal interface ICompilerServicesProviderInternal
|
|||
/// </summary>
|
||||
public static class CompilerServices
|
||||
{
|
||||
private static IServiceProvider? _serviceProvider;
|
||||
private static ICompilerServicesProvider? _provider;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets get the compile options.
|
||||
/// Gets root services.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public static CompileOptions CompileOptions
|
||||
{
|
||||
get { return Provider.CompileOptions; }
|
||||
set { Provider.CompileOptions = value; }
|
||||
}
|
||||
|
||||
public static string CompileTarget => CompileOptions.Target;
|
||||
|
||||
public static ITarget GetCompileTarget => GetTarget(CompileTarget);
|
||||
internal static IServiceProvider ServiceProvider => _serviceProvider ?? throw new InvalidOperationException("Compiler services provider must be set.");
|
||||
|
||||
internal static IDataTypeServiceProvider DataTypeService => ((ICompilerServicesProviderInternal)Provider).DataTypeService;
|
||||
|
||||
|
@ -221,10 +209,11 @@ public static class CompilerServices
|
|||
/// <summary>
|
||||
/// Configure compiler services.
|
||||
/// </summary>
|
||||
/// <param name="provider">Service provider.</param>
|
||||
public static void Configure(ICompilerServicesProvider provider)
|
||||
/// <param name="serviceProvider">Root service provider.</param>
|
||||
public static void Configure(IServiceProvider serviceProvider)
|
||||
{
|
||||
_provider = provider;
|
||||
_serviceProvider = serviceProvider;
|
||||
_provider = serviceProvider.GetRequiredService<ICompilerServicesProvider>();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -353,7 +342,7 @@ public static class CompilerServices
|
|||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
public static Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
|
||||
public static Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
{
|
||||
return Provider.Rewrite(expr, rules, options);
|
||||
}
|
||||
|
@ -365,7 +354,7 @@ public static class CompilerServices
|
|||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
|
||||
public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
{
|
||||
return Provider.ERewrite(expr, rules, options);
|
||||
}
|
||||
|
@ -444,6 +433,22 @@ public static class CompilerServices
|
|||
/// <param name="name">Target name.</param>
|
||||
/// <returns>Target.</returns>
|
||||
public static ITarget GetTarget(string name) => Provider.GetTarget(name);
|
||||
|
||||
internal static DryIoc.IContainer CreateScope()
|
||||
{
|
||||
var container = (DryIoc.IContainer)_serviceProvider!;
|
||||
var childDefaultServiceKey = new object();
|
||||
var rules = container.Rules
|
||||
.WithDefaultRegistrationServiceKey(childDefaultServiceKey)
|
||||
.WithFactorySelector(Rules.SelectKeyedOverDefaultFactory(childDefaultServiceKey));
|
||||
return container!.With(
|
||||
container.Parent,
|
||||
rules,
|
||||
container.ScopeContext,
|
||||
RegistrySharing.CloneButKeepCache,
|
||||
container.SingletonScope.Clone(false),
|
||||
Scope.Of(container.OwnCurrentScope));
|
||||
}
|
||||
}
|
||||
|
||||
internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerServicesProviderInternal
|
||||
|
@ -457,11 +462,8 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
|
|||
private readonly IEGraphMatchProvider _eGraphMatchProvider;
|
||||
private readonly IEGraphRewriteProvider _eGraphrewriteProvider;
|
||||
private readonly ITargetProvider _targetProvider;
|
||||
private CompileOptions _compileOptions;
|
||||
|
||||
public CompilerServicesProvider(
|
||||
|
||||
// IOptions<CompileOptions> compileOptions,
|
||||
IEvaluateProvider evaluateProvider,
|
||||
ITypeInferenceProvider typeInferenceProvider,
|
||||
IIRPrinterProvider irprinterProvider,
|
||||
|
@ -488,12 +490,6 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
|
|||
|
||||
public IDataTypeServiceProvider DataTypeService { get; }
|
||||
|
||||
public CompileOptions CompileOptions
|
||||
{
|
||||
get => _compileOptions;
|
||||
set => _compileOptions = value;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public IValue Evaluate(Expr expr, IReadOnlyDictionary<Var, IValue>? varsValues = null, Dictionary<Type, IEvaluator>? evaluator_cache = null)
|
||||
{
|
||||
|
@ -551,19 +547,19 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
|
|||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
|
||||
public Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
{
|
||||
return _rewriteProvider.Rewrite(expr, rules, options);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Cost EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null)
|
||||
public Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null)
|
||||
{
|
||||
return _costEvaluateProvider.EvaluateCost(expr, varsValues);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Cost EvaluateOpCost(Op op, ICostEvaluateContext context)
|
||||
public Cost? EvaluateOpCost(Op op, ICostEvaluateContext context)
|
||||
{
|
||||
return _costEvaluateProvider.EvaluateOpCost(op, context);
|
||||
}
|
||||
|
@ -583,7 +579,7 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
|
|||
return _targetProvider.GetTarget(name);
|
||||
}
|
||||
|
||||
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
|
||||
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
{
|
||||
return _eGraphrewriteProvider.ERewrite(expr, rules, options);
|
||||
}
|
||||
|
|
|
@ -1,31 +1,32 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase.Converters;
|
||||
|
||||
/// <summary>
|
||||
/// Converters module.
|
||||
/// </summary>
|
||||
public class ConvertersModule : Module
|
||||
internal class ConvertersModule : IApplicationPart
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override void Load(ContainerBuilder builder)
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
builder.RegisterType<BFloat16Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<BooleanConverters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<DoubleConverters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<HalfConverters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<Int16Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<Int32Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<Int64Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<Int8Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<SingleConverters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<UInt16Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<UInt32Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<UInt64Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<UInt8Converters>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<PointerConverters>().AsImplementedInterfaces().SingleInstance();
|
||||
registrator.RegisterManyInterface<BFloat16Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<BooleanConverters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<DoubleConverters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<HalfConverters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<Int16Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<Int32Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<Int64Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<Int8Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<SingleConverters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<UInt16Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<UInt32Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<UInt64Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<UInt8Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<PointerConverters>(reuse: Reuse.Singleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,10 @@ using System.Linq;
|
|||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using Nncase.Converters;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.Targets;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
|
@ -18,11 +22,12 @@ public static class CoreApplicationPart
|
|||
/// <summary>
|
||||
/// Add core assembly.
|
||||
/// </summary>
|
||||
/// <param name="assemblies">Assembly collection.</param>
|
||||
/// <returns>Updated assembly collection.</returns>
|
||||
public static IList<Assembly> AddCore(this IList<Assembly> assemblies)
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
/// <returns>Configured service registrator.</returns>
|
||||
public static IRegistrator AddCore(this IRegistrator registrator)
|
||||
{
|
||||
assemblies.Add(typeof(CoreApplicationPart).Assembly);
|
||||
return assemblies;
|
||||
return registrator.RegisterModule<CoreModule>()
|
||||
.RegisterModule<ConvertersModule>()
|
||||
.RegisterModule<TargetsModule>();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,38 +1,40 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
/// <summary>
|
||||
/// Core module.
|
||||
/// </summary>
|
||||
public class CoreModule : Module
|
||||
internal class CoreModule : IApplicationPart
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override void Load(ContainerBuilder builder)
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
builder.RegisterType<CompilerServicesProvider>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<DataTypeServiceProvider>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<IR.IRPrinterProvider>().AsImplementedInterfaces().SingleInstance();
|
||||
builder.RegisterType<CompileOptions>().AsImplementedInterfaces().SingleInstance();
|
||||
registrator.RegisterManyInterface<CompilerServicesProvider>(reuse: Reuse.Singleton);
|
||||
registrator.Register<IDataTypeServiceProvider, DataTypeServiceProvider>(reuse: Reuse.Singleton);
|
||||
registrator.Register<IIRPrinterProvider, IRPrinterProvider>(reuse: Reuse.Singleton);
|
||||
|
||||
// Prim types
|
||||
builder.RegisterType<BooleanType>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<Utf8CharType>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<Int8Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<Int16Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<Int32Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<Int64Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<UInt8Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<UInt16Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<UInt32Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<UInt64Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<Float16Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<Float32Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<Float64Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<BFloat16Type>().As<PrimType>().SingleInstance();
|
||||
builder.RegisterType<QuantParamType>().As<ValueType>().SingleInstance();
|
||||
registrator.Register<PrimType, BooleanType>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, Utf8CharType>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, Int8Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, Int16Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, Int32Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, Int64Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, UInt8Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, UInt16Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, UInt32Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, UInt64Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, Float16Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, Float32Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, Float64Type>(reuse: Reuse.Singleton);
|
||||
registrator.Register<PrimType, BFloat16Type>(reuse: Reuse.Singleton);
|
||||
|
||||
// Value types
|
||||
registrator.Register<ValueType, QuantParamType>(reuse: Reuse.Singleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
|
@ -28,27 +28,27 @@ internal class DataTypeServiceProvider : IDataTypeServiceProvider
|
|||
private readonly Dictionary<RuntimeTypeHandle, PrimType> _primTypes = new();
|
||||
private readonly Dictionary<Runtime.TypeCode, PrimType> _typeCodeToPrimTypes = new();
|
||||
private readonly Dictionary<RuntimeTypeHandle, ValueType> _valueTypes = new();
|
||||
private readonly IComponentContext _componentContext;
|
||||
private readonly IResolver _resolver;
|
||||
|
||||
public DataTypeServiceProvider(PrimType[] primTypes, ValueType[] valueTypes, IComponentContext componentContext)
|
||||
public DataTypeServiceProvider(PrimType[] primTypes, ValueType[] valueTypes, IResolver resolver)
|
||||
{
|
||||
_primTypes = primTypes.ToDictionary(x => x.CLRType.TypeHandle);
|
||||
_typeCodeToPrimTypes = primTypes.Where(x => x.TypeCode < Runtime.TypeCode.ValueType).ToDictionary(x => x.TypeCode);
|
||||
_valueTypes = valueTypes.ToDictionary(x => x.CLRType.TypeHandle);
|
||||
_componentContext = componentContext;
|
||||
_resolver = resolver;
|
||||
}
|
||||
|
||||
public ISpanConverter GetConverter(Type fromType, Type toType)
|
||||
{
|
||||
if (fromType.IsGenericType && fromType.GetGenericTypeDefinition() == typeof(Pointer<>))
|
||||
{
|
||||
var converter = _componentContext.Resolve(typeof(IPointerSpanConverter<>).MakeGenericType(toType));
|
||||
var converter = _resolver.Resolve(typeof(IPointerSpanConverter<>).MakeGenericType(toType));
|
||||
var wrapperType = typeof(PointerSpanConverter<,>).MakeGenericType(fromType.GenericTypeArguments[0], toType);
|
||||
return (ISpanConverter)Activator.CreateInstance(wrapperType, converter)!;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (ISpanConverter)_componentContext.Resolve(typeof(ISpanConverter<,>).MakeGenericType(fromType, toType));
|
||||
return (ISpanConverter)_resolver.Resolve(typeof(ISpanConverter<,>).MakeGenericType(fromType, toType));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.Diagnostics;
|
||||
|
||||
/// <summary>
|
||||
/// Dump flags.
|
||||
/// </summary>
|
||||
[Flags]
|
||||
public enum DumpFlags
|
||||
{
|
||||
/// <summary>
|
||||
/// Nothing need to be dump.
|
||||
/// </summary>
|
||||
None = 0,
|
||||
|
||||
/// <summary>
|
||||
/// Dump import ops.
|
||||
/// </summary>
|
||||
ImportOps = 1 << 1,
|
||||
|
||||
/// <summary>
|
||||
/// Dump pass pre and post ir.
|
||||
/// </summary>
|
||||
PassIR = 1 << 2,
|
||||
|
||||
/// <summary>
|
||||
/// Dump egraph costs.
|
||||
/// </summary>
|
||||
EGraphCost = 1 << 3,
|
||||
|
||||
/// <summary>
|
||||
/// Dump rewrite.
|
||||
/// </summary>
|
||||
Rewrite = 1 << 4,
|
||||
|
||||
/// <summary>
|
||||
/// Dump calibration.
|
||||
/// </summary>
|
||||
Calibration = 1 << 5,
|
||||
|
||||
/// <summary>
|
||||
/// Dump evaluator values.
|
||||
/// </summary>
|
||||
Evaluator = 1 << 6,
|
||||
|
||||
/// <summary>
|
||||
/// Dump compile stages.
|
||||
/// </summary>
|
||||
Compile = 1 << 7,
|
||||
|
||||
/// <summary>
|
||||
/// Dump tiling.
|
||||
/// </summary>
|
||||
Tiling = 1 << 8,
|
||||
|
||||
/// <summary>
|
||||
/// Dump schedule.
|
||||
/// </summary>
|
||||
Schedule = 1 << 9,
|
||||
|
||||
/// <summary>
|
||||
/// Dump codegen.
|
||||
/// </summary>
|
||||
CodeGen = 1 << 10,
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace Nncase.Diagnostics;
|
||||
|
||||
/// <summary>
|
||||
/// <see cref="IDumpper"/> scope.
|
||||
/// </summary>
|
||||
public struct DumpScope : IDisposable
|
||||
{
|
||||
private static readonly AsyncLocal<IDumpper> _dumpper = new AsyncLocal<IDumpper>();
|
||||
|
||||
private readonly bool _initialized;
|
||||
private readonly IDumpper _originalDumpper;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="DumpScope"/> struct.
|
||||
/// </summary>
|
||||
/// <param name="subDirectory">Sub directory.</param>
|
||||
/// <param name="serviceProvider">Service provider.</param>
|
||||
public DumpScope(string subDirectory, IServiceProvider? serviceProvider = null)
|
||||
{
|
||||
_initialized = true;
|
||||
_originalDumpper = GetCurrent(serviceProvider);
|
||||
_dumpper.Value = _originalDumpper.CreateSubDummper(subDirectory);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="DumpScope"/> struct.
|
||||
/// </summary>
|
||||
/// <param name="newDumpper">New dumpper.</param>
|
||||
/// <param name="serviceProvider">Service provider.</param>
|
||||
public DumpScope(IDumpper newDumpper, IServiceProvider? serviceProvider = null)
|
||||
{
|
||||
_initialized = true;
|
||||
_originalDumpper = GetCurrent(serviceProvider);
|
||||
_dumpper.Value = newDumpper;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets current dumpper.
|
||||
/// </summary>
|
||||
public static IDumpper Current => GetCurrent(null);
|
||||
|
||||
/// <summary>
|
||||
/// Gets current <see cref="IDumpper"/> or use root of scope.
|
||||
/// </summary>
|
||||
/// <param name="serviceProvider">Service provider.</param>
|
||||
/// <returns>Current dumpper.</returns>
|
||||
public static IDumpper GetCurrent(IServiceProvider? serviceProvider) =>
|
||||
_dumpper.Value ??= (serviceProvider ?? CompileSessionScope.Current)?.GetRequiredService<IDumpperFactory>().Root ?? NullDumpper.Instance;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void Dispose()
|
||||
{
|
||||
if (_initialized)
|
||||
{
|
||||
_dumpper.Value = _originalDumpper;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
using Nncase.Utilities;
|
||||
|
||||
namespace Nncase.Diagnostics;
|
||||
|
||||
/// <summary>
|
||||
/// Data dumpper.
|
||||
/// </summary>
|
||||
public interface IDumpper
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets dump directory.
|
||||
/// </summary>
|
||||
string Directory { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether dump is enabled.
|
||||
/// </summary>
|
||||
/// <param name="dumpFlags">Dump flags.</param>
|
||||
/// <returns>Whether dump is enabled.</returns>
|
||||
bool IsEnabled(DumpFlags dumpFlags);
|
||||
|
||||
/// <summary>
|
||||
/// Create sub dummper.
|
||||
/// </summary>
|
||||
/// <param name="subDirectory">Sub directory.</param>
|
||||
/// <returns>Sub dummper.</returns>
|
||||
IDumpper CreateSubDummper(string subDirectory);
|
||||
|
||||
void DumpIR(Expr expr, string prefix, string? reletivePath = null);
|
||||
|
||||
void DumpDotIR(Expr expr, string prefix, string? reletivePath = null);
|
||||
|
||||
void DumpModule(IRModule module, string? reletivePath = null);
|
||||
|
||||
Stream OpenFile(string reletivePath, FileMode fileMode = FileMode.Create);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Dumpper factory.
|
||||
/// </summary>
|
||||
public interface IDumpperFactory
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets root dummper.
|
||||
/// </summary>
|
||||
IDumpper Root { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Creat dumpper.
|
||||
/// </summary>
|
||||
/// <param name="relativePath">Sub directory.</param>
|
||||
/// <returns>Dumpper.</returns>
|
||||
IDumpper CreateDummper(string relativePath);
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Diagnostics;
|
||||
|
||||
/// <summary>
|
||||
/// <see cref="IDumpper"/> with no backing store.
|
||||
/// </summary>
|
||||
public sealed class NullDumpper : IDumpper
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets instance.
|
||||
/// </summary>
|
||||
public static NullDumpper Instance { get; } = new NullDumpper();
|
||||
|
||||
/// <inheritdoc/>
|
||||
public string Directory => System.IO.Directory.GetCurrentDirectory();
|
||||
|
||||
/// <inheritdoc/>
|
||||
public IDumpper CreateSubDummper(string subDirectory) => this;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void DumpIR(Expr expr, string prefix, string? reletivePath = null)
|
||||
{
|
||||
}
|
||||
|
||||
public void DumpDotIR(Expr expr, string prefix, string? reletivePath = null)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void DumpModule(IRModule module, string? reletivePath = null)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public bool IsEnabled(DumpFlags dumpFlags) => false;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public Stream OpenFile(string reletivePath, FileMode fileMode = FileMode.Create) => Stream.Null;
|
||||
}
|
|
@ -1,96 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Hosting;
|
||||
|
||||
/// <summary>
|
||||
/// Application parts helper.
|
||||
/// </summary>
|
||||
public class ApplicationParts
|
||||
{
|
||||
private const string _appPartsDllPattern = "Nncase.Modules.*.dll";
|
||||
private const string _targetPathEnvName = "NNCASE_TARGET_PATH";
|
||||
|
||||
/// <summary>
|
||||
/// Load application parts.
|
||||
/// </summary>
|
||||
/// <param name="configureAction">Configure action.</param>
|
||||
/// <returns>Application parts assemblies.</returns>
|
||||
public static Assembly[] LoadApplicationParts(Action<IList<Assembly>> configureAction)
|
||||
{
|
||||
var defaultAssemblies = new List<Assembly>() { Assembly.GetCallingAssembly() };
|
||||
configureAction(defaultAssemblies);
|
||||
|
||||
return defaultAssemblies.Concat(
|
||||
GetApplicationPartsSearchDirectories().Select(LoadApplicationParts).SelectMany(x => x)
|
||||
.DistinctBy(Path.GetFileName).Select(Assembly.LoadFrom))
|
||||
.Distinct().ToArray();
|
||||
}
|
||||
|
||||
private static IEnumerable<string> LoadApplicationParts(string basePath)
|
||||
{
|
||||
return Directory.GetFiles(basePath, _appPartsDllPattern, SearchOption.AllDirectories)
|
||||
.Where(x => !Path.GetDirectoryName(x)!.EndsWith("ref"));
|
||||
}
|
||||
|
||||
private static IEnumerable<string> GetApplicationPartsSearchDirectories()
|
||||
{
|
||||
var directories = new List<string>();
|
||||
|
||||
// 1. Executable base
|
||||
var exePath = Path.GetDirectoryName(Assembly.GetCallingAssembly().Location);
|
||||
if (!string.IsNullOrWhiteSpace(exePath))
|
||||
{
|
||||
directories.Add(exePath);
|
||||
}
|
||||
|
||||
// 2. Environment variable
|
||||
var targetPathEnv = Environment.GetEnvironmentVariable(_targetPathEnvName);
|
||||
if (string.IsNullOrWhiteSpace(targetPathEnv))
|
||||
{
|
||||
// todo:log
|
||||
Console.WriteLine("NNCASE_TARGET_PATH is not set.");
|
||||
}
|
||||
else
|
||||
{
|
||||
var targetPaths = from path in targetPathEnv.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
|
||||
select Environment.ExpandEnvironmentVariables(path);
|
||||
directories.AddRange(targetPaths);
|
||||
}
|
||||
|
||||
foreach (var directory in directories)
|
||||
{
|
||||
Console.WriteLine(directory);
|
||||
}
|
||||
|
||||
return directories.Distinct();
|
||||
}
|
||||
}
|
||||
|
||||
// Custom comparer for the Product class
|
||||
internal class PathComparer : IEqualityComparer<string>
|
||||
{
|
||||
// Products are equal if their names and product numbers are equal.
|
||||
public bool Equals(string x, string y)
|
||||
{
|
||||
Console.WriteLine("------------------");
|
||||
Console.WriteLine(x);
|
||||
Console.WriteLine(y);
|
||||
Console.WriteLine("------------------");
|
||||
return Path.GetFileName(x) == Path.GetFileName(y) || x == y;
|
||||
}
|
||||
|
||||
// If Equals() returns true for a pair of objects
|
||||
// then GetHashCode() must return the same value for these objects.
|
||||
public int GetHashCode(string s)
|
||||
{
|
||||
return s.GetHashCode();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
|
||||
namespace Nncase.Hosting;
|
||||
|
||||
/// <summary>
|
||||
/// Application part.
|
||||
/// </summary>
|
||||
public interface IApplicationPart
|
||||
{
|
||||
/// <summary>
|
||||
/// Configure services.
|
||||
/// </summary>
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
void ConfigureServices(IRegistrator registrator);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Application part extensions.
|
||||
/// </summary>
|
||||
public static class ApplicationPartExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Register module.
|
||||
/// </summary>
|
||||
/// <typeparam name="TModule">Module type.</typeparam>
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
/// <returns>Configured service registrator.</returns>
|
||||
public static IRegistrator RegisterModule<TModule>(this IRegistrator registrator)
|
||||
where TModule : class, IApplicationPart, new()
|
||||
{
|
||||
var module = new TModule();
|
||||
module.ConfigureServices(registrator);
|
||||
return registrator;
|
||||
}
|
||||
|
||||
/// <summary>Registers single registration for all implemented public interfaces and base classes.</summary>
|
||||
/// <typeparam name="TImplementation">Implementation type.</typeparam>
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
/// <param name="reuse">Reuse strategy.</param>
|
||||
public static void RegisterManyInterface<TImplementation>(this IRegistrator registrator, IReuse? reuse = null)
|
||||
where TImplementation : class
|
||||
{
|
||||
registrator.RegisterMany<TImplementation>(reuse, serviceTypeCondition: t => t.IsInterface);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.Hosting;
|
||||
|
||||
/// <summary>
|
||||
/// Plugin.
|
||||
/// </summary>
|
||||
public interface IPlugin : IApplicationPart
|
||||
{
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
/// <summary>
|
||||
/// Compiler.
|
||||
/// </summary>
|
||||
public interface ICompiler
|
||||
{
|
||||
/// <summary>
|
||||
/// Import DL model as ir module.
|
||||
/// </summary>
|
||||
/// <param name="content">Model content.</param>
|
||||
/// <returns>Imported ir module.</returns>
|
||||
Task<IRModule> ImportModuleAsync(Stream content);
|
||||
|
||||
/// <summary>
|
||||
/// Compile module.
|
||||
/// </summary>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
Task CompileAsync();
|
||||
|
||||
/// <summary>
|
||||
/// Generate code to stream.
|
||||
/// </summary>
|
||||
/// <param name="output">Stream to be written.</param>
|
||||
void Gencode(Stream output);
|
||||
}
|
|
@ -8,82 +8,81 @@ using System.Linq;
|
|||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// the interface that we can use parameterinfo the parameter.
|
||||
/// </summary>
|
||||
/// <typeparam name="T"></typeparam>
|
||||
public interface IParameterList<T>
|
||||
{
|
||||
/// <summary>
|
||||
/// the interface that we can use parameterinfo the parameter.
|
||||
/// get parameter info.
|
||||
/// </summary>
|
||||
/// <typeparam name="T"></typeparam>
|
||||
public interface IParameterList<T>
|
||||
/// <param name="parameter"></param>
|
||||
/// <returns></returns>
|
||||
public T this[ParameterInfo parameter] { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Call expression.
|
||||
/// </summary>
|
||||
public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParameterList<Expr>
|
||||
{
|
||||
// /// <summary>
|
||||
// /// used by fake ir, represents that whether this op permit int 16 quant.
|
||||
// /// </summary>
|
||||
// public bool PermitInt16Quant = false;
|
||||
|
||||
/// <summary>
|
||||
/// quant config with cosine, List of DataType represents data types for each input might be quantized, List of QuantParam represents quant params for each input.
|
||||
/// may be deleted in the future since there is EnodeBestQuantConfigWithCosine, reserve it now for debug and for unexpected usage when EnodeBestQuantConfigWithCosine is not enough.
|
||||
/// </summary>
|
||||
public List<Tuple<List<DataType>, List<List<QuantParam>>, float>> EnodeQuantConfigWithCosine;
|
||||
|
||||
/// <summary>
|
||||
/// quant config with cosine, List of DataType represents data types for each input might be quantized, List of QuantParam represents quant params for each input.
|
||||
/// </summary>
|
||||
public Tuple<List<DataType>, List<List<QuantParam>>, float> EnodeBestQuantConfigWithCosine;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Call"/> class.
|
||||
/// </summary>
|
||||
/// <param name="target">Call target.</param>
|
||||
/// <param name="parameters">Parameters.</param>
|
||||
public Call(Expr target, params Expr[] parameters)
|
||||
: this(target, new IRArray<Expr>(parameters.ToImmutableArray()))
|
||||
{
|
||||
/// <summary>
|
||||
/// get parameter info.
|
||||
/// </summary>
|
||||
/// <param name="parameter"></param>
|
||||
/// <returns></returns>
|
||||
public T this[ParameterInfo parameter] { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Call expression.
|
||||
/// get param expr.
|
||||
/// </summary>
|
||||
public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParameterList<Expr>
|
||||
/// <param name="parameter"></param>
|
||||
/// <returns></returns>
|
||||
/// <exception cref="ArgumentOutOfRangeException"></exception>
|
||||
public Expr this[ParameterInfo parameter]
|
||||
{
|
||||
// /// <summary>
|
||||
// /// used by fake ir, represents that whether this op permit int 16 quant.
|
||||
// /// </summary>
|
||||
// public bool PermitInt16Quant = false;
|
||||
|
||||
/// <summary>
|
||||
/// quant config with cosine, List of DataType represents data types for each input might be quantized, List of QuantParam represents quant params for each input.
|
||||
/// may be deleted in the future since there is EnodeBestQuantConfigWithCosine, reserve it now for debug and for unexpected usage when EnodeBestQuantConfigWithCosine is not enough.
|
||||
/// </summary>
|
||||
public List<Tuple<List<DataType>, List<List<QuantParam>>, float>> EnodeQuantConfigWithCosine;
|
||||
|
||||
/// <summary>
|
||||
/// quant config with cosine, List of DataType represents data types for each input might be quantized, List of QuantParam represents quant params for each input.
|
||||
/// </summary>
|
||||
public Tuple<List<DataType>, List<List<QuantParam>>, float> EnodeBestQuantConfigWithCosine;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Call"/> class.
|
||||
/// </summary>
|
||||
/// <param name="target">Call target.</param>
|
||||
/// <param name="parameters">Parameters.</param>
|
||||
public Call(Expr target, params Expr[] parameters)
|
||||
: this(target, new IRArray<Expr>(parameters.ToImmutableArray()))
|
||||
get
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// get param expr.
|
||||
/// </summary>
|
||||
/// <param name="parameter"></param>
|
||||
/// <returns></returns>
|
||||
/// <exception cref="ArgumentOutOfRangeException"></exception>
|
||||
public Expr this[ParameterInfo parameter]
|
||||
{
|
||||
get
|
||||
var type = Target.GetType();
|
||||
if (type == parameter.OwnerType)
|
||||
{
|
||||
var type = Target.GetType();
|
||||
if (type == parameter.OwnerType)
|
||||
{
|
||||
return Parameters[parameter.Index];
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentOutOfRangeException($"Target {Target} doesn't have parameter: {parameter.Name}.");
|
||||
}
|
||||
return Parameters[parameter.Index];
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentOutOfRangeException($"Target {Target} doesn't have parameter: {parameter.Name}.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void ParametersForeach(Action<Expr, ParameterInfo> f)
|
||||
public void ParametersForeach(Action<Expr, ParameterInfo> f)
|
||||
{
|
||||
var parameterInfos = ((Op)Target).Parameters.ToArray();
|
||||
for (int i = 0; i < Parameters.Count; i++)
|
||||
{
|
||||
var parameterInfos = ((Op)Target).Parameters.ToArray();
|
||||
for (int i = 0; i < Parameters.Count; i++)
|
||||
{
|
||||
f(Parameters[i], parameterInfos[i]);
|
||||
}
|
||||
f(Parameters[i], parameterInfos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
/// <summary>
|
||||
/// Module.
|
||||
/// </summary>
|
||||
public sealed class IRModule
|
||||
{
|
||||
private readonly List<BaseFunction> _functions;
|
||||
private int? _entryIndex;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="IRModule"/> class.
|
||||
/// </summary>
|
||||
/// <param name="main">main func.</param>
|
||||
public IRModule(BaseFunction main)
|
||||
{
|
||||
_functions = new() { main };
|
||||
_entryIndex = 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="IRModule"/> class.
|
||||
/// the default IrModule ctor.
|
||||
/// </summary>
|
||||
public IRModule()
|
||||
{
|
||||
_functions = new();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets functions.
|
||||
/// </summary>
|
||||
public IReadOnlyList<BaseFunction> Functions => _functions;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets entry function.
|
||||
/// </summary>
|
||||
public BaseFunction? Entry
|
||||
{
|
||||
get => _entryIndex.HasValue ? _functions[_entryIndex.Value] : null;
|
||||
set => _entryIndex = value != null ? _functions.FindIndex(x => object.ReferenceEquals(x, value)) : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Add function.
|
||||
/// </summary>
|
||||
/// <param name="function">Callable to add.</param>
|
||||
public void Add(BaseFunction function)
|
||||
{
|
||||
_functions.Add(function);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Replace the function defination.
|
||||
/// </summary>
|
||||
/// <param name="index">function index.</param>
|
||||
/// <param name="function">the entry function defination.</param>
|
||||
public void Replace(int index, BaseFunction function)
|
||||
{
|
||||
var old = _functions[index];
|
||||
|
||||
var replacer = new FunctionReplacer(old, function);
|
||||
for (int i = 0; i < _functions.Count; i++)
|
||||
{
|
||||
replacer.Visit(_functions[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < _functions.Count; i++)
|
||||
{
|
||||
var originFunc = _functions[i];
|
||||
if (replacer.ExpressionMemo.TryGetValue(originFunc, out var replace))
|
||||
{
|
||||
_functions[i] = (BaseFunction)replace;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Replace the function call dependencer.
|
||||
/// </summary>
|
||||
private sealed class FunctionReplacer : DeepExprMutator
|
||||
{
|
||||
private readonly BaseFunction _original;
|
||||
private readonly BaseFunction _replace;
|
||||
|
||||
public FunctionReplacer(BaseFunction original, BaseFunction replace)
|
||||
{
|
||||
_original = original;
|
||||
_replace = replace;
|
||||
}
|
||||
|
||||
public override Expr DefaultMutateLeaf(Expr expr)
|
||||
{
|
||||
if (expr is BaseFunction baseFunction
|
||||
&& object.ReferenceEquals(baseFunction, _original))
|
||||
{
|
||||
return _replace;
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,94 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Nncase.IR
|
||||
{
|
||||
/// <summary>
|
||||
/// Module.
|
||||
/// </summary>
|
||||
public sealed class IRModule
|
||||
{
|
||||
private readonly List<BaseFunction> _functions;
|
||||
|
||||
/// <summary>
|
||||
/// the index of the entry function.
|
||||
/// </summary>
|
||||
private int _entryIndex;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="IRModule"/> class.
|
||||
/// </summary>
|
||||
/// <param name="main"> main func.</param>
|
||||
public IRModule(BaseFunction main)
|
||||
{
|
||||
_functions = new();
|
||||
_functions.Add(main);
|
||||
_entryIndex = 0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="IRModule"/> class.
|
||||
/// the default IrModule ctor.
|
||||
/// </summary>
|
||||
public IRModule()
|
||||
{
|
||||
_functions = new();
|
||||
_entryIndex = -1;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets functions.
|
||||
/// </summary>
|
||||
public IReadOnlyList<BaseFunction> Functions => _functions;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets entry function.
|
||||
/// </summary>
|
||||
public BaseFunction? Entry
|
||||
{
|
||||
get => _entryIndex == -1 ? null : Functions[_entryIndex];
|
||||
set
|
||||
{
|
||||
if (value is null)
|
||||
{
|
||||
_entryIndex = -1;
|
||||
}
|
||||
else
|
||||
{
|
||||
_entryIndex = _functions.IndexOf(value);
|
||||
if (_entryIndex == -1)
|
||||
{
|
||||
_functions.Add(value);
|
||||
_entryIndex = _functions.Count - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Add function.
|
||||
/// </summary>
|
||||
/// <param name="function">Callable to add.</param>
|
||||
public void Add(BaseFunction function)
|
||||
{
|
||||
_functions.Add(function);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// update the entry function defination.
|
||||
/// </summary>
|
||||
/// <param name="i">function index.</param>
|
||||
/// <param name="function">the entry function defination.</param>
|
||||
public void Update(int i, BaseFunction function)
|
||||
{
|
||||
_functions[i] = function;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,6 +6,7 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using System.Xml.Linq;
|
||||
|
||||
namespace Nncase.IR
|
||||
{
|
||||
|
@ -52,7 +53,7 @@ namespace Nncase.IR
|
|||
/// <summary>
|
||||
/// Variable expression.
|
||||
/// </summary>
|
||||
public record Var : Expr
|
||||
public record Var : Expr, IEquatable<Var>
|
||||
{
|
||||
private static int _globalVarIndex;
|
||||
|
||||
|
@ -64,9 +65,10 @@ namespace Nncase.IR
|
|||
/// <param name="typeAnnotation"></param>
|
||||
public Var(string name, IRType typeAnnotation)
|
||||
{
|
||||
Name = name;
|
||||
TypeAnnotation = typeAnnotation;
|
||||
CheckedType = TypeAnnotation;
|
||||
GlobalVarIndex = GetNextId();
|
||||
Name = name;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -74,8 +76,11 @@ namespace Nncase.IR
|
|||
/// </summary>
|
||||
/// <param name="typeAnnotation">Type annotation.</param>
|
||||
public Var(IRType typeAnnotation)
|
||||
: this($"var_{_globalVarIndex++}", typeAnnotation)
|
||||
{
|
||||
TypeAnnotation = typeAnnotation;
|
||||
CheckedType = TypeAnnotation;
|
||||
GlobalVarIndex = GetNextId();
|
||||
Name = $"var_{GlobalVarIndex}";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -92,14 +97,14 @@ namespace Nncase.IR
|
|||
/// Initializes a new instance of the <see cref="Var"/> class.
|
||||
/// </summary>
|
||||
public Var()
|
||||
: this($"var_{_globalVarIndex++}", AnyType.Default)
|
||||
: this(AnyType.Default)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets get the global var index.
|
||||
/// Gets the global var index.
|
||||
/// </summary>
|
||||
private int GlobalVarIndex => _globalVarIndex;
|
||||
public int GlobalVarIndex { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets name.
|
||||
|
@ -140,5 +145,24 @@ namespace Nncase.IR
|
|||
/// <param name="name"></param>
|
||||
/// <returns></returns>
|
||||
public static Var SizeVar(string name) => Scalar(name, DataTypes.Int32);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public virtual bool Equals(Var? other)
|
||||
{
|
||||
if (other == null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GlobalVarIndex == other.GlobalVarIndex;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override int GetHashCode() => GlobalVarIndex.GetHashCode();
|
||||
|
||||
private static int GetNextId()
|
||||
{
|
||||
return Interlocked.Increment(ref _globalVarIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,23 +27,21 @@ public interface ITarget
|
|||
/// Bind Quant Method And Quant Cosine With IR.
|
||||
/// </summary>
|
||||
/// <param name="calibrationDataset">calibration dataset.</param>
|
||||
/// <param name="target">target.</param>
|
||||
/// <param name="rangeOfs">rangeOf nodes.</param>
|
||||
/// <param name="childrenOfRangeOfs">rangeOf nodes children.</param>
|
||||
/// <param name="runPassOptions">options.</param>
|
||||
/// <param name="quantizeOptions">options.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions);
|
||||
Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions);
|
||||
|
||||
/// <summary>
|
||||
/// AdaRound Weights.
|
||||
/// </summary>
|
||||
/// <param name="calibrationDataset">calibration dataset.</param>
|
||||
/// <param name="target">target.</param>
|
||||
/// <param name="rangeOfs">rangeOf nodes.</param>
|
||||
/// <param name="childrenOfRangeOfs">rangeOf nodes children.</param>
|
||||
/// <param name="runPassOptions">options.</param>
|
||||
/// <param name="quantizeOptions">options.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions);
|
||||
Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions);
|
||||
|
||||
/// <summary>
|
||||
/// Parse Target Dependent Options.
|
||||
|
@ -56,21 +54,21 @@ public interface ITarget
|
|||
/// </summary>
|
||||
/// <param name="passManager">pass manager.</param>
|
||||
/// <param name="options">compile options.</param>
|
||||
void RegisterTargetDependentPass(PassManager passManager, CompileOptions options);
|
||||
void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options);
|
||||
|
||||
/// <summary>
|
||||
/// Register Quantize Pass.
|
||||
/// </summary>
|
||||
/// <param name="passManager">pass manager.</param>
|
||||
/// <param name="options">compile options.</param>
|
||||
void RegisterQuantizePass(PassManager passManager, CompileOptions options);
|
||||
void RegisterQuantizePass(IPassManager passManager, CompileOptions options);
|
||||
|
||||
/// <summary>
|
||||
/// Register Target Dependent After Quant Pass.
|
||||
/// </summary>
|
||||
/// <param name="passManager"></param>
|
||||
/// <param name="options">compile options.</param>
|
||||
void RegisterTargetDependentAfterQuantPass(PassManager passManager, CompileOptions options);
|
||||
void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options);
|
||||
|
||||
/// <summary>
|
||||
/// Create module builder.
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Autofac" />
|
||||
<PackageReference Include="DryIoc.dll" />
|
||||
<PackageReference Include="libtorch-cpu-linux-x64" />
|
||||
<PackageReference Include="libtorch-cpu-osx-x64" />
|
||||
<PackageReference Include="libtorch-cpu-win-x64" />
|
||||
|
@ -19,6 +19,7 @@
|
|||
<!-- <PackageReference Include="libtorch-cuda-11.3-linux-x64" /> -->
|
||||
<!-- <PackageReference Include="libtorch-cuda-11.3-win-x64" /> -->
|
||||
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
|
||||
<PackageReference Include="Microsoft.Extensions.Options" />
|
||||
<PackageReference Include="Microsoft.Toolkit.HighPerformance" />
|
||||
<PackageReference Include="NetFabric.Hyperlinq" />
|
||||
|
|
|
@ -49,7 +49,7 @@ public static partial class Utility
|
|||
/// <summary>
|
||||
/// match a call with op type T
|
||||
/// auto set first param
|
||||
/// it's always used for Fake to NoFake Rule with ReplaceCall.
|
||||
/// it's always used for Fake to NoFake Pass with ReplaceCall.
|
||||
/// </summary>ReplaceParams
|
||||
/// <param name="callName"></param>
|
||||
/// <param name="opName"></param>
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
[assembly: InternalsVisibleTo("Nncase.Tests")]
|
||||
[assembly: InternalsVisibleTo("Nncase.Tests.TestFixture")]
|
|
@ -80,4 +80,25 @@ public class QuantizeOptions
|
|||
/// Gets or sets a value indicating whether enable adaround to fine tune weights.
|
||||
/// </summary>
|
||||
public bool UseAdaRound { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets quant type.
|
||||
/// </summary>
|
||||
public DataType QuantType { get; set; } = DataTypes.UInt8;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets weights quant type.
|
||||
/// </summary>
|
||||
public DataType WQuantType { get; set; } = DataTypes.UInt8;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets model quant mode.
|
||||
/// </summary>
|
||||
public ModelQuantMode ModelQuantMode { get; set; } = ModelQuantMode.NoQuant;
|
||||
|
||||
/// <summary>
|
||||
/// Creates no quantization options.
|
||||
/// </summary>
|
||||
/// <returns>No quant options.</returns>
|
||||
public static QuantizeOptions CreateNoQuant() => new QuantizeOptions();
|
||||
}
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase.Targets;
|
||||
|
||||
/// <summary>
|
||||
/// Targets module.
|
||||
/// </summary>
|
||||
public class TargetsModule : Module
|
||||
internal class TargetsModule : IApplicationPart
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override void Load(ContainerBuilder builder)
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
builder.RegisterType<TargetProvider>().AsImplementedInterfaces().SingleInstance();
|
||||
registrator.Register<ITargetProvider, TargetProvider>(reuse: Reuse.Singleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -256,6 +256,18 @@ public abstract partial class Tensor : IStructuralComparable, IStructuralEquatab
|
|||
return new Tensor<T>(memory, dimensions);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create tensor from an array, Set the shape as [n].
|
||||
/// </summary>
|
||||
/// <typeparam name="T">CLR type.</typeparam>
|
||||
/// <param name="array">Array.</param>
|
||||
/// <returns>Created tensor.</returns>
|
||||
public static Tensor<T> From<T>(T[] array)
|
||||
where T : unmanaged, IEquatable<T>
|
||||
{
|
||||
return From(array.AsMemory());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create tensor from an array, Set the shape as provided.
|
||||
/// </summary>
|
||||
|
@ -266,7 +278,7 @@ public abstract partial class Tensor : IStructuralComparable, IStructuralEquatab
|
|||
public static Tensor<T> From<T>(T[] array, ReadOnlySpan<int> dimensions)
|
||||
where T : unmanaged, IEquatable<T>
|
||||
{
|
||||
return new Tensor<T>(array, dimensions);
|
||||
return From(array.AsMemory(), dimensions);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -6,103 +6,45 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using static TorchSharp.torch.nn;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
/// <summary>
|
||||
/// the basic pass.
|
||||
/// </summary>
|
||||
public abstract class BasePass
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="BasePass"/> class.
|
||||
/// the base pass ctor.
|
||||
/// </summary>
|
||||
/// <param name="name"></param>
|
||||
public BasePass(string name)
|
||||
{
|
||||
Name = name;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the pass name.
|
||||
/// </summary>
|
||||
public string Name { get; init; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Pass in Callable scope.
|
||||
/// </summary>
|
||||
public abstract class FunctionPass : BasePass
|
||||
public abstract class FunctionPass : Pass<BaseFunction>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="FunctionPass"/> class.
|
||||
/// </summary>
|
||||
/// <param name="name">Name.</param>
|
||||
public FunctionPass(string name)
|
||||
: base(name)
|
||||
public FunctionPass()
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Run current pass for specific function.
|
||||
/// </summary>
|
||||
/// <param name="callable">Target function.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
public async Task<BaseFunction> RunAsync(BaseFunction callable, RunPassOptions options)
|
||||
/// <inheritdoc/>
|
||||
protected override Task OnPassStartAsync(BaseFunction input, RunPassContext context)
|
||||
{
|
||||
var new_options = options.IndentDir(Name).IndentDir(callable.Name);
|
||||
OnPassStart(callable, new_options);
|
||||
var post = await RunCoreAsync(callable, new_options);
|
||||
OnPassEnd(post, new_options);
|
||||
return post;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Run pass implementation for derived class.
|
||||
/// </summary>
|
||||
/// <param name="callable">Target function.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
protected abstract Task<BaseFunction> RunCoreAsync(BaseFunction callable, RunPassOptions options);
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="callable"> func without run pass.</param>
|
||||
/// <param name="options"></param>
|
||||
protected virtual void OnPassStart(BaseFunction callable, RunPassOptions options)
|
||||
{
|
||||
switch (options.DumpLevel)
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
case >= 2:
|
||||
CompilerServices.DumpIR(callable, "Start", options.DumpDir);
|
||||
break;
|
||||
case >= 1:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
DumpScope.Current.DumpIR(input, "Start");
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="callable"> func with rewrited. </param>
|
||||
/// <param name="options"></param>
|
||||
protected virtual void OnPassEnd(BaseFunction callable, RunPassOptions options)
|
||||
/// <inheritdoc/>
|
||||
protected override Task OnPassEndAsync(BaseFunction post, RunPassContext context)
|
||||
{
|
||||
switch (options.DumpLevel)
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
case >= 2:
|
||||
CompilerServices.DumpIR(callable, "End", options.DumpDir);
|
||||
break;
|
||||
case >= 1:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
DumpScope.Current.DumpIR(post, "End");
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private protected override string? GetDumpRelativePass(BaseFunction input) => input.Name;
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
|
@ -14,8 +15,45 @@ namespace Nncase.Transform;
|
|||
/// </summary>
|
||||
public interface IEGraph
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets version.
|
||||
/// </summary>
|
||||
int Version { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets eclasses.
|
||||
/// </summary>
|
||||
IEnumerable<EClass> Classes { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets nodes.
|
||||
/// </summary>
|
||||
IEnumerable<ENode> Nodes { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Add expr, get the eclass id.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <returns>Eclass of this node.</returns>
|
||||
EClass Add(Expr expr);
|
||||
|
||||
/// <summary>
|
||||
/// Find eclass of enode.
|
||||
/// </summary>
|
||||
/// <param name="node">ENode.</param>
|
||||
/// <returns>EClass.</returns>
|
||||
EClass Find(ENode node);
|
||||
|
||||
/// <summary>
|
||||
/// Union two equal Eclass.
|
||||
/// </summary>
|
||||
/// <param name="classA">class a.</param>
|
||||
/// <param name="classB">class b.</param>
|
||||
/// <returns>If version changed.</returns>
|
||||
bool Union(EClass classA, EClass classB);
|
||||
|
||||
/// <summary>
|
||||
/// After merge, we use rebuild get new dep information.
|
||||
/// </summary>
|
||||
void Rebuild();
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ public interface IRewriteProvider
|
|||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options);
|
||||
Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -37,5 +37,14 @@ public interface IEGraphRewriteProvider
|
|||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>Rewrited expression.</returns>
|
||||
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options);
|
||||
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite egraph.
|
||||
/// </summary>
|
||||
/// <param name="eGraph">EGraph.</param>
|
||||
/// <param name="rules">Rewrite rules.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>Rewrited EGraph.</returns>
|
||||
IEGraph ERewrite(IEGraph eGraph, IEnumerable<IRewriteRule> rules, RunPassContext options);
|
||||
}
|
||||
|
|
|
@ -22,16 +22,7 @@ public interface IRewriteRule
|
|||
/// </summary>
|
||||
/// <param name="result">Match result.</param>
|
||||
/// <returns>Replace expression or null if nothing changed.</returns>
|
||||
Expr? GetReplace(IMatchResult result, RunPassOptions options);
|
||||
|
||||
/// <summary>
|
||||
/// check this pattern can be modify in multi branch.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
bool IsMultiBranchSafe()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
Expr? GetReplace(IMatchResult result, RunPassContext options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -6,6 +6,7 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
@ -13,72 +14,40 @@ namespace Nncase.Transform;
|
|||
/// <summary>
|
||||
/// Pass in Callable scope.
|
||||
/// </summary>
|
||||
public abstract class ModulePass : BasePass
|
||||
public abstract class ModulePass : Pass<IRModule>
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ModulePass"/> class.
|
||||
/// </summary>
|
||||
/// <param name="name">Name.</param>
|
||||
public ModulePass(string name)
|
||||
: base(name)
|
||||
public ModulePass()
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Run current pass for module.
|
||||
/// </summary>
|
||||
/// <param name="module">Target module.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
public async Task RunAsync(IRModule module, RunPassOptions options)
|
||||
/// <inheritdoc/>
|
||||
protected override Task OnPassStartAsync(IRModule input, RunPassContext context)
|
||||
{
|
||||
var new_options = options.IndentDir(Name);
|
||||
OnPassStart(module, new_options);
|
||||
await RunCoreAsync(module, new_options);
|
||||
OnPassEnd(module, new_options);
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
foreach (var func in input.Functions)
|
||||
{
|
||||
DumpScope.Current.DumpIR(func, func.Name, "Start");
|
||||
}
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Run pass implementation for derived class.
|
||||
/// </summary>
|
||||
/// <param name="module">Target module.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
protected abstract Task RunCoreAsync(IRModule module, RunPassOptions options);
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="module"> module without run pass.</param>
|
||||
/// <param name="options"></param>
|
||||
protected virtual void OnPassStart(IRModule module, RunPassOptions options)
|
||||
/// <inheritdoc/>
|
||||
protected override Task OnPassEndAsync(IRModule post, RunPassContext context)
|
||||
{
|
||||
if (options.DumpLevel < 3)
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
return;
|
||||
foreach (var func in post.Functions)
|
||||
{
|
||||
DumpScope.Current.DumpIR(func, func.Name, "End");
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var (func, i) in module.Functions.Select((func, i) => (func, i)))
|
||||
{
|
||||
CompilerServices.DumpIR(func, $"fn_{i}", Path.Combine(options.DumpDir, "Start"));
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="module"> module with rewrited. </param>
|
||||
/// <param name="options"></param>
|
||||
protected virtual void OnPassEnd(IRModule module, RunPassOptions options)
|
||||
{
|
||||
if (options.DumpLevel < 3)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
foreach (var (func, i) in module.Functions.Select((func, i) => (func, i)))
|
||||
{
|
||||
CompilerServices.DumpIR(func, $"fn_{i}", Path.Combine(options.DumpDir, "End"));
|
||||
}
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.TIR;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
/// <summary>
|
||||
/// IR or TIR transformer.
|
||||
/// </summary>
|
||||
public interface IPass
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets pass name.
|
||||
/// </summary>
|
||||
string Name { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// IR or TIR transformer.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Type to transform.</typeparam>
|
||||
public abstract class Pass<T> : IPass
|
||||
where T : class
|
||||
{
|
||||
private string? _name;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Pass{T}"/> class.
|
||||
/// </summary>
|
||||
internal Pass()
|
||||
{
|
||||
CompileSession = CompileSessionScope.GetCurrentThrowIfNull();
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public string Name
|
||||
{
|
||||
get => _name ??= GetType().Name;
|
||||
set => _name = value;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets compile session.
|
||||
/// </summary>
|
||||
protected CompileSession CompileSession { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Run pass.
|
||||
/// </summary>
|
||||
/// <param name="input">Input object.</param>
|
||||
/// <param name="context">Run pass context.</param>
|
||||
/// <returns>Output object.</returns>
|
||||
public async Task<T> RunAsync(T input, RunPassContext context)
|
||||
{
|
||||
using var sessionScope = new CompileSessionScope(CompileSession);
|
||||
using var dumpScope = new DumpScope(Path.Join($"{context.Index}_{Name}", GetDumpRelativePass(input)));
|
||||
|
||||
await OnPassStartAsync(input, context);
|
||||
var output = await RunCoreAsync(input, context);
|
||||
await OnPassEndAsync(output, context);
|
||||
return output;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Run pass implementation for derived class.
|
||||
/// </summary>
|
||||
/// <param name="input">Input object.</param>
|
||||
/// <param name="context">Run pass context.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
protected abstract Task<T> RunCoreAsync(T input, RunPassContext context);
|
||||
|
||||
/// <summary>
|
||||
/// The callback function you can custom process func with run pass context.
|
||||
/// </summary>
|
||||
/// <param name="input">Input object.</param>
|
||||
/// <param name="context">Run pass context.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
protected abstract Task OnPassStartAsync(T input, RunPassContext context);
|
||||
|
||||
/// <summary>
|
||||
/// The callback function you can custom process func with run pass context.
|
||||
/// </summary>
|
||||
/// <param name="post">Post object.</param>
|
||||
/// <param name="context">Run pass context.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
protected abstract Task OnPassEndAsync(T post, RunPassContext context);
|
||||
|
||||
private protected virtual string? GetDumpRelativePass(T input) => null;
|
||||
}
|
|
@ -8,218 +8,231 @@ using System.Collections.Immutable;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
|
||||
[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("Nncase.Tests")]
|
||||
using Nncase.TIR;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
/// <summary>
|
||||
/// Passes addable.
|
||||
/// </summary>
|
||||
public interface IPassesAddable
|
||||
{
|
||||
/// <summary>
|
||||
/// Add pass.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Pass type.</typeparam>
|
||||
/// <param name="parameters">Pass constructor parameters.</param>
|
||||
/// <returns>Add result.</returns>
|
||||
AddPassResult<T> Add<T>(params object[] parameters)
|
||||
where T : class, IPass;
|
||||
|
||||
/// <summary>
|
||||
/// Add pass with name.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Pass type.</typeparam>
|
||||
/// <param name="name">Pass name.</param>
|
||||
/// <param name="parameters">Pass constructor parameters.</param>
|
||||
/// <returns>Add result.</returns>
|
||||
AddPassResult<T> AddWithName<T>(string name, params object[] parameters)
|
||||
where T : class, IPass;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Pass manager.
|
||||
/// </summary>
|
||||
public class PassManager : IEnumerable<BasePass>
|
||||
public interface IPassManager : IPassesAddable
|
||||
{
|
||||
private readonly IRModule _module;
|
||||
private readonly RunPassOptions _options;
|
||||
private readonly List<BasePass> _passes = new List<BasePass>();
|
||||
|
||||
private readonly Dictionary<BaseFunction, BaseFunction> _functions_update_map = new(ReferenceEqualityComparer.Instance);
|
||||
private readonly Dictionary<int, BaseFunction> _functions_mask = new();
|
||||
/// <summary>
|
||||
/// Gets name.
|
||||
/// </summary>
|
||||
public string Name { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PassManager"/> class.
|
||||
/// Gets passes.
|
||||
/// </summary>
|
||||
/// <param name="module">Module.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
public PassManager(IRModule module, RunPassOptions options)
|
||||
{
|
||||
_module = module;
|
||||
_options = options;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// foreach fix when the call target function has been updated.
|
||||
/// </summary>
|
||||
public static void FuncUpdateDependence(IRModule module, Dictionary<BaseFunction, BaseFunction> update_map, RunPassOptions options, string name)
|
||||
{
|
||||
var mutator = new DependenceMutator(update_map);
|
||||
_ = mutator.Visit(module.Entry!);
|
||||
if (!mutator.IsMutated)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < module.Functions.Count; i++)
|
||||
{
|
||||
if (update_map.TryGetValue(module.Functions[i], out var updated_func))
|
||||
{
|
||||
module.Update(i, updated_func);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.DumpLevel > 2)
|
||||
{
|
||||
foreach (var item in module.Functions)
|
||||
{
|
||||
CompilerServices.DumpIR(item, string.Empty, Path.Combine(options.DumpDir, name, $"FuncUpdateDependence"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Add function pass.
|
||||
/// </summary>
|
||||
/// <param name="pass">Pass.</param>
|
||||
public void Add(BasePass pass)
|
||||
{
|
||||
_passes.Add(pass);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public IEnumerator<BasePass> GetEnumerator()
|
||||
{
|
||||
return ((IEnumerable<BasePass>)_passes).GetEnumerator();
|
||||
}
|
||||
public IReadOnlyList<IPass> Passes { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Run passes and update the module.
|
||||
/// </summary>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
public async Task RunAsync()
|
||||
{
|
||||
var passes = _passes.AsEnumerable();
|
||||
while (passes.Any())
|
||||
{
|
||||
var type = passes.First().GetType();
|
||||
Type base_type;
|
||||
if (type.IsSubclassOf(typeof(FunctionPass)))
|
||||
{
|
||||
base_type = typeof(FunctionPass);
|
||||
}
|
||||
else if (type.IsSubclassOf(typeof(ModulePass)))
|
||||
{
|
||||
base_type = typeof(ModulePass);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentOutOfRangeException();
|
||||
}
|
||||
|
||||
var candiate = passes.TakeWhile(item => item.GetType().IsSubclassOf(base_type));
|
||||
passes = passes.Skip(candiate.Count());
|
||||
|
||||
if (type.IsSubclassOf(typeof(FunctionPass)))
|
||||
{
|
||||
await RunFunctionAsync(candiate.OfType<FunctionPass>());
|
||||
}
|
||||
else if (type.IsSubclassOf(typeof(ModulePass)))
|
||||
{
|
||||
await RunModuleAsync(candiate.OfType<ModulePass>());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentOutOfRangeException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
IEnumerator IEnumerable.GetEnumerator()
|
||||
{
|
||||
return ((IEnumerable)_passes).GetEnumerator();
|
||||
}
|
||||
|
||||
private async Task RunFunctionAsync(IEnumerable<FunctionPass> passes)
|
||||
{
|
||||
int i = 0;
|
||||
string name = string.Empty;
|
||||
while (i < _module.Functions.Count)
|
||||
{
|
||||
foreach (var pass in passes)
|
||||
{
|
||||
var pre = _module.Functions[i];
|
||||
var post = await pass.RunAsync(pre, _options);
|
||||
if (!object.ReferenceEquals(pre, post))
|
||||
{
|
||||
FuncUpdateRecord(i, pre, post);
|
||||
_module.Update(i, post);
|
||||
}
|
||||
|
||||
name = pass.Name;
|
||||
}
|
||||
|
||||
i++;
|
||||
}
|
||||
|
||||
FuncUpdateDependence(_module, _functions_update_map, _options, name);
|
||||
CleanFuncUpdateRecord();
|
||||
}
|
||||
|
||||
private async Task RunModuleAsync(IEnumerable<ModulePass> passes)
|
||||
{
|
||||
foreach (var pass in passes)
|
||||
{
|
||||
await pass.RunAsync(_module, _options);
|
||||
}
|
||||
}
|
||||
|
||||
private void CleanFuncUpdateRecord()
|
||||
{
|
||||
_functions_update_map.Clear();
|
||||
_functions_mask.Clear();
|
||||
}
|
||||
|
||||
private void FuncUpdateRecord(int i, BaseFunction current, BaseFunction updated)
|
||||
{
|
||||
// if function[i] has not been update, record it to origin function.
|
||||
if (!_functions_mask.TryGetValue(i, out var origin))
|
||||
{
|
||||
origin = current;
|
||||
_functions_mask.Add(i, origin);
|
||||
}
|
||||
|
||||
_functions_update_map[origin] = updated;
|
||||
}
|
||||
/// <param name="module">Input module.</param>
|
||||
/// <returns>A <see cref="Task{IRModule}"/> representing the asynchronous operation.</returns>
|
||||
Task<IRModule> RunAsync(IRModule module);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Update the function call dependencer.
|
||||
/// Add pass result.
|
||||
/// </summary>
|
||||
internal sealed class DependenceMutator : DeepExprMutator
|
||||
/// <typeparam name="T">Pass type.</typeparam>
|
||||
public struct AddPassResult<T> : IPassesAddable
|
||||
where T : class, IPass
|
||||
{
|
||||
public Dictionary<BaseFunction, BaseFunction> FunctionsUpdated;
|
||||
private readonly IPassManager _passManager;
|
||||
private readonly CompileSession _compileSession;
|
||||
|
||||
public DependenceMutator(Dictionary<BaseFunction, BaseFunction> functions_updated)
|
||||
internal AddPassResult(IPassManager passManager, CompileSession compileSession, T pass)
|
||||
{
|
||||
FunctionsUpdated = functions_updated;
|
||||
_passManager = passManager;
|
||||
_compileSession = compileSession;
|
||||
Pass = pass;
|
||||
}
|
||||
|
||||
public override Expr DefaultMutateLeaf(Expr expr)
|
||||
/// <summary>
|
||||
/// Gets pass.
|
||||
/// </summary>
|
||||
public T Pass { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public AddPassResult<TPass> Add<TPass>(params object[] parameters)
|
||||
where TPass : class, IPass => _passManager.Add<TPass>(parameters);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public AddPassResult<TPass> AddWithName<TPass>(string name, params object[] parameters)
|
||||
where TPass : class, IPass => _passManager.AddWithName<TPass>(name, parameters);
|
||||
|
||||
/// <summary>
|
||||
/// Configure pass.
|
||||
/// </summary>
|
||||
/// <param name="configureRule">Configure pass action.</param>
|
||||
/// <returns>This add result.</returns>
|
||||
public AddPassResult<T> Configure(Action<T> configureRule)
|
||||
{
|
||||
if (expr is BaseFunction baseFunction && FunctionsUpdated.TryGetValue(baseFunction, out var updated_basefunc))
|
||||
{
|
||||
return updated_basefunc;
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
public override Expr Visit(BaseFunction baseFunction)
|
||||
{
|
||||
// first time enter function, mutate
|
||||
var nexpr = base.Visit(baseFunction);
|
||||
if (nexpr is BaseFunction updatedBasefunction && !object.ReferenceEquals(baseFunction, updatedBasefunction))
|
||||
{
|
||||
if (FunctionsUpdated.ContainsKey(baseFunction))
|
||||
{
|
||||
FunctionsUpdated[baseFunction] = updatedBasefunction;
|
||||
}
|
||||
else
|
||||
{
|
||||
FunctionsUpdated.Add(baseFunction, updatedBasefunction);
|
||||
}
|
||||
}
|
||||
|
||||
return nexpr;
|
||||
using var scope = new CompileSessionScope(_compileSession);
|
||||
configureRule(Pass);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed class PassManager : IPassManager
|
||||
{
|
||||
private readonly CompileSession _compileSession;
|
||||
private readonly IDumpper _dummper;
|
||||
private readonly List<IPass> _passes = new List<IPass>();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PassManager"/> class.
|
||||
/// </summary>
|
||||
/// <param name="name">Pass manager name.</param>
|
||||
/// <param name="compileSession">Compile session.</param>
|
||||
public PassManager(string name, CompileSession compileSession)
|
||||
{
|
||||
Name = name;
|
||||
_compileSession = compileSession;
|
||||
_dummper = DumpScope.GetCurrent(compileSession).CreateSubDummper(name);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets name.
|
||||
/// </summary>
|
||||
public string Name { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets passes.
|
||||
/// </summary>
|
||||
public IReadOnlyList<IPass> Passes => _passes;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public AddPassResult<T> Add<T>(params object[] parameters)
|
||||
where T : class, IPass
|
||||
{
|
||||
using var scope = new CompileSessionScope(_compileSession);
|
||||
using var dumpScope = new DumpScope(_dummper);
|
||||
var pass = ActivatorUtilities.CreateInstance<T>(_compileSession, parameters);
|
||||
_passes.Add(pass);
|
||||
return new(this, _compileSession, pass);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public AddPassResult<T> AddWithName<T>(string name, params object[] parameters)
|
||||
where T : class, IPass
|
||||
{
|
||||
var result = Add<T>(parameters);
|
||||
result.Configure(p => p.Name = name);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public async Task<IRModule> RunAsync(IRModule module)
|
||||
{
|
||||
using var dumpScope = new DumpScope(_dummper);
|
||||
for (int i = 0; i < _passes.Count; i++)
|
||||
{
|
||||
var task = _passes[i] switch
|
||||
{
|
||||
FunctionPass fp => RunAsync(module, i, fp),
|
||||
PrimFuncPass pfp => RunAsync(module, i, pfp),
|
||||
ModulePass mp => RunAsync(module, i, mp),
|
||||
_ => throw new NotSupportedException($"Unsupported pass type: {_passes[i].GetType().AssemblyQualifiedName}"),
|
||||
};
|
||||
|
||||
module = await task;
|
||||
}
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
private async Task<IRModule> RunAsync(IRModule module, int passIndex, FunctionPass pass)
|
||||
{
|
||||
for (int i = 0; i < module.Functions.Count; i++)
|
||||
{
|
||||
var pre = module.Functions[i];
|
||||
var context = new RunPassContext { Index = passIndex };
|
||||
var post = await pass.RunAsync(pre, context);
|
||||
if (!object.ReferenceEquals(pre, post))
|
||||
{
|
||||
if (_dummper.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
_dummper.DumpModule(module, $"Before_{passIndex}_{pass.Name}");
|
||||
}
|
||||
|
||||
module.Replace(i, post);
|
||||
|
||||
if (_dummper.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
_dummper.DumpModule(module, $"After_{passIndex}_{pass.Name}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
private async Task<IRModule> RunAsync(IRModule module, int passIndex, PrimFuncPass pass)
|
||||
{
|
||||
for (int i = 0; i < module.Functions.Count; i++)
|
||||
{
|
||||
var pre = module.Functions[i];
|
||||
if (pre is PrimFunction pf)
|
||||
{
|
||||
var context = new RunPassContext { Index = passIndex };
|
||||
var post = await pass.RunAsync(pf, context);
|
||||
if (!object.ReferenceEquals(pre, post))
|
||||
{
|
||||
if (_dummper.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
_dummper.DumpModule(module, $"Before_{passIndex}_{pass.Name}");
|
||||
}
|
||||
|
||||
module.Replace(i, post);
|
||||
|
||||
if (_dummper.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
_dummper.DumpModule(module, $"After_{passIndex}_{pass.Name}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
private async Task<IRModule> RunAsync(IRModule module, int passIndex, ModulePass pass)
|
||||
{
|
||||
var context = new RunPassContext { Index = passIndex };
|
||||
return await pass.RunAsync(module, context);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.TIR;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
/// <summary>
|
||||
/// TIR Mutator Pass.
|
||||
/// NOTE only apply on prim func
|
||||
/// Because of we will mutate the expression multiple times, so use MutatorCreator create the new mutator.
|
||||
/// </summary>
|
||||
public class PrimFuncPass : Pass<PrimFunction>
|
||||
{
|
||||
private readonly List<MutatorDescriptor> _mutatorDescriptors = new();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PrimFuncPass"/> class.
|
||||
/// </summary>
|
||||
public PrimFuncPass()
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Add mutator.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Mutator type.</typeparam>
|
||||
/// <param name="configureMutator">Configure mutator action.</param>
|
||||
/// <param name="arguments">Mutator's constructor arguments.</param>
|
||||
/// <returns>This primfunc pass.</returns>
|
||||
public PrimFuncPass Add<T>(Action<T>? configureMutator, params object[] arguments)
|
||||
where T : ExprMutator
|
||||
{
|
||||
_mutatorDescriptors.Add(new()
|
||||
{
|
||||
Factory = ActivatorUtilities.CreateFactory(typeof(T), arguments.Select(x => x.GetType()).ToArray()),
|
||||
Configure = configureMutator == null ? null : x => configureMutator?.Invoke((T)x),
|
||||
Arguments = arguments,
|
||||
});
|
||||
return this;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override Task<PrimFunction> RunCoreAsync(PrimFunction input, RunPassContext context)
|
||||
{
|
||||
var post = input;
|
||||
int count = 0;
|
||||
bool isMutated = false;
|
||||
|
||||
do
|
||||
{
|
||||
foreach (var descriptor in _mutatorDescriptors)
|
||||
{
|
||||
var mutator = descriptor.Activate(CompileSession);
|
||||
post = (PrimFunction)mutator.Visit(post);
|
||||
isMutated = mutator.IsMutated;
|
||||
|
||||
if (isMutated)
|
||||
{
|
||||
var typeInferSuccess = CompilerServices.InferenceType(post);
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
DumpScope.Current.DumpIR(post, $"{count++}_{mutator.GetType().Name}");
|
||||
}
|
||||
|
||||
Trace.Assert(typeInferSuccess);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!isMutated)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (true);
|
||||
|
||||
return Task.FromResult(post);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override Task OnPassStartAsync(PrimFunction input, RunPassContext context) => Task.CompletedTask;
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override Task OnPassEndAsync(PrimFunction post, RunPassContext context) => Task.CompletedTask;
|
||||
|
||||
private protected override string? GetDumpRelativePass(PrimFunction input) => input.Name;
|
||||
|
||||
private struct MutatorDescriptor
|
||||
{
|
||||
public ObjectFactory Factory;
|
||||
public Action<ExprMutator>? Configure;
|
||||
public object[] Arguments;
|
||||
|
||||
public ExprMutator Activate(CompileSession compileSession)
|
||||
{
|
||||
using var scope = new CompileSessionScope(compileSession);
|
||||
var mutator = (ExprMutator)Factory(compileSession, Arguments);
|
||||
Configure?.Invoke(mutator);
|
||||
return mutator;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,7 +20,7 @@ public abstract class QuantRule : RewriteRule<Pattern>
|
|||
/// <summary>
|
||||
/// NOTE the option will be set by SourceGenerator when the GetReplace called.
|
||||
/// </summary>
|
||||
public RunPassOptions Option = null!;
|
||||
public RunPassContext Option = null!;
|
||||
|
||||
/// <summary>
|
||||
/// the match result
|
||||
|
@ -36,22 +36,22 @@ public abstract class QuantRule : RewriteRule<Pattern>
|
|||
/// <summary>
|
||||
/// Gets get ModelQuantMode.
|
||||
/// </summary>
|
||||
public ModelQuantMode ModelQuantMode => Option.CompileOptions.ModelQuantMode;
|
||||
public ModelQuantMode ModelQuantMode => CompileSession.CompileOptions.QuantizeOptions.ModelQuantMode;
|
||||
|
||||
/// <summary>
|
||||
/// Gets get QuantType.
|
||||
/// </summary>
|
||||
public DataType QuantType => Option.CompileOptions.QuantType;
|
||||
public DataType QuantType => CompileSession.CompileOptions.QuantizeOptions.QuantType;
|
||||
|
||||
/// <summary>
|
||||
/// Gets get WQuantType.
|
||||
/// </summary>
|
||||
public DataType WQuantType => Option.CompileOptions.WQuantType;
|
||||
public DataType WQuantType => CompileSession.CompileOptions.QuantizeOptions.WQuantType;
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether get UseMixQuant flag.
|
||||
/// </summary>
|
||||
public bool UseMixQuant => Option.CompileOptions.QuantizeOptions.BindQuantMethod;
|
||||
public bool UseMixQuant => CompileSession.CompileOptions.QuantizeOptions.BindQuantMethod;
|
||||
|
||||
/// <summary>
|
||||
/// check the datatype is the quant type.
|
||||
|
|
|
@ -18,6 +18,14 @@ namespace Nncase.Transform;
|
|||
public abstract class RewriteRule<TPattern> : IRewriteRule
|
||||
where TPattern : Pattern
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RewriteRule{TPattern}"/> class.
|
||||
/// </summary>
|
||||
public RewriteRule()
|
||||
{
|
||||
CompileSession = CompileSessionScope.GetCurrentThrowIfNull();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets pattern.
|
||||
/// </summary>
|
||||
|
@ -25,12 +33,11 @@ public abstract class RewriteRule<TPattern> : IRewriteRule
|
|||
|
||||
IPattern IRewriteRule.Pattern => Pattern;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public bool IsMultiBranchSafe { get; init; }
|
||||
/// <summary>
|
||||
/// Gets compile session.
|
||||
/// </summary>
|
||||
protected CompileSession CompileSession { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
bool IRewriteRule.IsMultiBranchSafe() => IsMultiBranchSafe;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public abstract Expr? GetReplace(IMatchResult result, RunPassOptions options);
|
||||
public abstract Expr? GetReplace(IMatchResult result, RunPassContext options);
|
||||
}
|
||||
|
|
|
@ -7,92 +7,102 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Nncase.TIR;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
public abstract class RulesPass : FunctionPass, IEnumerable<IRewriteRule>
|
||||
/// <summary>
|
||||
/// Rules addable.
|
||||
/// </summary>
|
||||
public interface IRulesAddable
|
||||
{
|
||||
private readonly List<IRewriteRule> _rules = new();
|
||||
/// <summary>
|
||||
/// Add the rewrite rule.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Rule type.</typeparam>
|
||||
/// <param name="parameters">Rule's constructor parameters.</param>
|
||||
/// <returns>Add result.</returns>
|
||||
RulesPass.AddResult<T> Add<T>(params object[] parameters)
|
||||
where T : class, IRewriteRule;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RulesPass"/> class.
|
||||
/// Add the rewrite rule.
|
||||
/// </summary>
|
||||
/// <param name="name">Name.</param>
|
||||
public RulesPass(string name)
|
||||
: base(name)
|
||||
{
|
||||
}
|
||||
/// <param name="ruleType">Rule type.</param>
|
||||
/// <param name="parameters">Rule's constructor parameters.</param>
|
||||
/// <returns>Add result.</returns>
|
||||
RulesPass.AddResult<IRewriteRule> Add(Type ruleType, params object[] parameters);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Pass contains rewrite rules.
|
||||
/// </summary>
|
||||
public abstract class RulesPass : FunctionPass, IRulesAddable
|
||||
{
|
||||
private readonly List<IRewriteRule> _rules = new();
|
||||
|
||||
/// <summary>
|
||||
/// Gets rules.
|
||||
/// </summary>
|
||||
public IReadOnlyList<IRewriteRule> Rules => _rules;
|
||||
|
||||
/// <summary>
|
||||
/// add the pattern rule.
|
||||
/// </summary>
|
||||
/// <param name="rule">Rule.</param>
|
||||
public void Add(IRewriteRule rule) => _rules.Add(rule);
|
||||
|
||||
/// <summary>
|
||||
/// add the pattern rules.
|
||||
/// </summary>
|
||||
/// <param name="rules">Rules.</param>
|
||||
public void Add(params IRewriteRule[] rules) => _rules.AddRange(rules);
|
||||
|
||||
/// <summary>
|
||||
/// <see cref="Add(IRewriteRule[])"/>.
|
||||
/// </summary>
|
||||
/// <param name="rules">Rules.</param>
|
||||
public void Add(IEnumerable<IRewriteRule> rules) => _rules.AddRange(rules);
|
||||
/// <inheritdoc/>
|
||||
public AddResult<T> Add<T>(params object[] parameters)
|
||||
where T : class, IRewriteRule
|
||||
{
|
||||
using var scope = new CompileSessionScope(CompileSession);
|
||||
var rule = ActivatorUtilities.CreateInstance<T>(CompileSession, parameters);
|
||||
_rules.Add(rule);
|
||||
return new(this, rule);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public IEnumerator<IRewriteRule> GetEnumerator()
|
||||
public AddResult<IRewriteRule> Add(Type ruleType, params object[] parameters)
|
||||
{
|
||||
return _rules.GetEnumerator();
|
||||
}
|
||||
|
||||
IEnumerator IEnumerable.GetEnumerator()
|
||||
{
|
||||
return GetEnumerator();
|
||||
using var scope = new CompileSessionScope(CompileSession);
|
||||
var rule = (IRewriteRule)ActivatorUtilities.CreateInstance(CompileSession, ruleType, parameters);
|
||||
_rules.Add(rule);
|
||||
return new(this, rule);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// Add rule result.
|
||||
/// </summary>
|
||||
/// <param name="callable"> func without run pass.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
protected override void OnPassStart(BaseFunction callable, RunPassOptions options)
|
||||
/// <typeparam name="T">Pass type.</typeparam>
|
||||
public struct AddResult<T> : IRulesAddable
|
||||
where T : class, IRewriteRule
|
||||
{
|
||||
switch (options.DumpLevel)
|
||||
private readonly RulesPass _rulesPass;
|
||||
|
||||
internal AddResult(RulesPass rulesPass, T rule)
|
||||
{
|
||||
case >= 2:
|
||||
CompilerServices.DumpIR((Expr)callable, "Start", options.DumpDir);
|
||||
break;
|
||||
case >= 1:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
_rulesPass = rulesPass;
|
||||
Rule = rule;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="callable"> func with rewrited. </param>
|
||||
/// <param name="options">Options.</param>
|
||||
protected override void OnPassEnd(BaseFunction callable, RunPassOptions options)
|
||||
{
|
||||
switch (options.DumpLevel)
|
||||
/// <summary>
|
||||
/// Gets rule.
|
||||
/// </summary>
|
||||
public T Rule { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public AddResult<TRule> Add<TRule>(params object[] parameters)
|
||||
where TRule : class, IRewriteRule => _rulesPass.Add<TRule>(parameters);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public AddResult<IRewriteRule> Add(Type ruleType, params object[] parameters)
|
||||
=> _rulesPass.Add(ruleType, parameters);
|
||||
|
||||
/// <summary>
|
||||
/// Configure rule.
|
||||
/// </summary>
|
||||
/// <param name="configureRule">Configure rule action.</param>
|
||||
/// <returns>This add result.</returns>
|
||||
public AddResult<T> Configure(Action<T> configureRule)
|
||||
{
|
||||
case >= 2:
|
||||
CompilerServices.DumpIR((Expr)callable, "End", options.DumpDir);
|
||||
break;
|
||||
case >= 1:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
configureRule(Rule);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.PatternMatch;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
/// <summary>
|
||||
/// Options for running pass.
|
||||
/// </summary>
|
||||
public sealed record RunPassContext
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets pass index in a <see cref="PassManager"/>.
|
||||
/// </summary>
|
||||
public int Index { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a value indicating whether control rewrite once or not.
|
||||
/// Default is false.
|
||||
/// </summary>
|
||||
public bool RewriteOnce { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the match option.
|
||||
/// </summary>
|
||||
public MatchOptions MatchOptions { get; set; } = new MatchOptions();
|
||||
}
|
|
@ -1,169 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.PatternMatch;
|
||||
|
||||
namespace Nncase.Transform
|
||||
{
|
||||
/// <summary>
|
||||
/// Options for running pass.
|
||||
/// </summary>
|
||||
public class RunPassOptions
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
|
||||
/// </summary>
|
||||
/// <param name="compileOptions"></param>
|
||||
public RunPassOptions(CompileOptions compileOptions)
|
||||
{
|
||||
Target = CompilerServices.GetTarget(compileOptions.Target);
|
||||
DumpLevel = compileOptions.DumpLevel;
|
||||
DumpDir = compileOptions.DumpDir;
|
||||
CompileOptions = compileOptions;
|
||||
PassName = string.Empty;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
|
||||
/// parameterless ctor.
|
||||
/// </summary>
|
||||
public RunPassOptions(ITarget target)
|
||||
{
|
||||
Target = target;
|
||||
DumpLevel = CompilerServices.CompileOptions.DumpLevel;
|
||||
DumpDir = CompilerServices.CompileOptions.DumpDir;
|
||||
CompileOptions = CompilerServices.CompileOptions;
|
||||
PassName = string.Empty;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
|
||||
/// constructor.
|
||||
/// </summary>
|
||||
/// <param name="target"> target device. </param>
|
||||
/// <param name="dumpLevel"> int level. </param>
|
||||
/// <param name="dumpDir"> dir. </param>
|
||||
public RunPassOptions(ITarget target, int dumpLevel, string dumpDir)
|
||||
: this(target, dumpLevel, dumpDir, CompilerServices.CompileOptions)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
|
||||
/// create the run pass options.
|
||||
/// </summary>
|
||||
/// <param name="target"></param>
|
||||
/// <param name="dumpLevel"></param>
|
||||
/// <param name="dumpDir"></param>
|
||||
/// <param name="options"></param>
|
||||
public RunPassOptions(ITarget target, int dumpLevel, string dumpDir, CompileOptions options)
|
||||
{
|
||||
Target = target;
|
||||
DumpLevel = dumpLevel;
|
||||
DumpDir = dumpDir;
|
||||
PassName = string.Empty;
|
||||
RewriteOnce = false;
|
||||
CompileOptions = options;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
|
||||
/// copy construct.
|
||||
/// </summary>
|
||||
/// <param name="other"></param>
|
||||
public RunPassOptions(RunPassOptions other)
|
||||
{
|
||||
Target = other.Target;
|
||||
DumpLevel = other.DumpLevel;
|
||||
DumpDir = other.DumpDir;
|
||||
PassName = other.PassName;
|
||||
RewriteOnce = other.RewriteOnce;
|
||||
CompileOptions = other.CompileOptions;
|
||||
MatchOptions = other.MatchOptions;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the invalid pass.
|
||||
/// </summary>
|
||||
public static RunPassOptions Invalid => new RunPassOptions(null!, -1, string.Empty);
|
||||
|
||||
/// <summary>
|
||||
/// Gets target.
|
||||
/// </summary>
|
||||
public ITarget Target { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets dump level 0 = do nothing
|
||||
/// Dump level 1 = print to std output
|
||||
/// Dump level 2 = print dump to file.
|
||||
/// </summary>
|
||||
public int DumpLevel { get; private set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets dump dir.
|
||||
/// </summary>
|
||||
public string DumpDir { get; private set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets current pass name.
|
||||
/// </summary>
|
||||
public string PassName { get; private set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether control rewrite once or not.
|
||||
/// Default is false.
|
||||
/// </summary>
|
||||
public bool RewriteOnce { get; private set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets get the compile options.
|
||||
/// </summary>
|
||||
public CompileOptions CompileOptions { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the match option.
|
||||
/// </summary>
|
||||
public MatchOptions MatchOptions { get; set; } = new MatchOptions();
|
||||
|
||||
/// <summary>
|
||||
/// set the pass name.
|
||||
/// </summary>
|
||||
/// <param name="name"></param>
|
||||
/// <returns></returns>
|
||||
public RunPassOptions SetPassName(string name) => new(Target, DumpLevel, DumpDir, CompileOptions) { PassName = name, MatchOptions = MatchOptions };
|
||||
|
||||
/// <summary>
|
||||
/// set the dumpDir.
|
||||
/// </summary>
|
||||
/// <param name="path"></param>
|
||||
/// <returns></returns>
|
||||
public RunPassOptions SetDumpDir(string path) => new(Target, DumpLevel, path, CompileOptions) { PassName = PassName, MatchOptions = MatchOptions };
|
||||
|
||||
/// <summary>
|
||||
/// set the dump level.
|
||||
/// </summary>
|
||||
/// <param name="dumpLevel"></param>
|
||||
/// <returns></returns>
|
||||
public RunPassOptions SetDumpLevel(int dumpLevel) => new(Target, dumpLevel, DumpDir, CompileOptions) { PassName = PassName, MatchOptions = MatchOptions };
|
||||
|
||||
/// <summary>
|
||||
/// set the RewriteOnce.
|
||||
/// </summary>
|
||||
/// <param name="once"></param>
|
||||
/// <returns></returns>
|
||||
public RunPassOptions SetRewriteOnce(bool once) => new(Target, DumpLevel, DumpDir, CompileOptions) { PassName = PassName, RewriteOnce = once, MatchOptions = MatchOptions };
|
||||
|
||||
/// <summary>
|
||||
/// indent the dumpDir.
|
||||
/// </summary>
|
||||
/// <param name="path"></param>
|
||||
/// <returns></returns>
|
||||
public RunPassOptions IndentDir(string path) => new(Target, DumpLevel, Path.Combine(DumpDir, path), CompileOptions) { PassName = PassName, MatchOptions = MatchOptions };
|
||||
}
|
||||
}
|
|
@ -1,145 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using Nncase.IR;
|
||||
using Nncase.TIR;
|
||||
|
||||
namespace Nncase.Transform
|
||||
{
|
||||
/// <summary>
|
||||
/// TIR Mutator Pass.
|
||||
/// NOTE only apply on prim func
|
||||
/// Because of we will mutate the expression multiple times, so use MutatorCreator create the new mutator.
|
||||
/// </summary>
|
||||
public class PrimFuncPass : FunctionPass, IEnumerable<Func<ExprMutator>>
|
||||
{
|
||||
/// <summary>
|
||||
/// Save rules.
|
||||
/// </summary>
|
||||
public readonly List<Func<ExprMutator>> MutatorCreators = new();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PrimFuncPass"/> class.
|
||||
/// </summary>
|
||||
/// <param name="name">Name.</param>
|
||||
public PrimFuncPass(string name)
|
||||
: base(name)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public IEnumerator<Func<ExprMutator>> GetEnumerator()
|
||||
{
|
||||
return ((IEnumerable<Func<ExprMutator>>)MutatorCreators).GetEnumerator();
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
IEnumerator IEnumerable.GetEnumerator()
|
||||
{
|
||||
return ((IEnumerable)MutatorCreators).GetEnumerator();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// add the mutator.
|
||||
/// </summary>
|
||||
/// <param name="mutator"></param>
|
||||
public void Add(Func<ExprMutator> mutator)
|
||||
{
|
||||
MutatorCreators.Add(mutator);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override Task<BaseFunction> RunCoreAsync(BaseFunction callable, RunPassOptions options)
|
||||
{
|
||||
if (callable is not PrimFunction)
|
||||
{
|
||||
return Task.FromResult(callable);
|
||||
}
|
||||
|
||||
var post = callable;
|
||||
var last = post;
|
||||
int count = 0;
|
||||
var typeinfer_ret = true;
|
||||
do
|
||||
{
|
||||
bool isMutated = false;
|
||||
foreach (var creator in MutatorCreators)
|
||||
{
|
||||
var mutator = creator();
|
||||
string mutator_name = mutator.GetType().Name;
|
||||
last = post;
|
||||
post = (BaseFunction)mutator.Visit(last);
|
||||
isMutated = mutator.IsMutated;
|
||||
if (isMutated)
|
||||
{
|
||||
typeinfer_ret = CompilerServices.InferenceType(post);
|
||||
OnMutated(post, $"{count++}_{mutator.GetType().Name}", options);
|
||||
if (!typeinfer_ret)
|
||||
{
|
||||
throw new InvalidOperationException($"{Name}: After Run Mutator {count - 1}_{mutator_name} , The Type Inference Failed!");
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!isMutated)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (true);
|
||||
|
||||
return Task.FromResult(post);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="callable"> func without run pass.</param>
|
||||
/// <param name="options"></param>
|
||||
protected override void OnPassStart(BaseFunction callable, RunPassOptions options)
|
||||
{
|
||||
if (callable is not PrimFunction)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
base.OnPassStart(callable, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="callable"> func with rewrited. </param>
|
||||
/// <param name="options"></param>
|
||||
protected override void OnPassEnd(BaseFunction callable, RunPassOptions options)
|
||||
{
|
||||
if (callable is not PrimFunction)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
base.OnPassEnd(callable, options);
|
||||
}
|
||||
|
||||
private void OnMutated(BaseFunction callable, string prefix, RunPassOptions options)
|
||||
{
|
||||
switch (options.DumpLevel)
|
||||
{
|
||||
case >= 2:
|
||||
CompilerServices.DumpIR((Expr)callable, prefix, options.DumpDir, false);
|
||||
break;
|
||||
case >= 1:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -68,17 +68,6 @@ public static class ValueDumper
|
|||
DumpTensors(tensorValue.Select(x => x.AsTensor()).ToArray(), sr);
|
||||
}
|
||||
}
|
||||
|
||||
public static string GetMaybeDumpDir(string dir)
|
||||
{
|
||||
var root = Path.Join(CompilerServices.CompileOptions.DumpDir, dir);
|
||||
if (!Directory.Exists(root))
|
||||
{
|
||||
Directory.CreateDirectory(root);
|
||||
}
|
||||
|
||||
return root;
|
||||
}
|
||||
}
|
||||
|
||||
public static class DumpUtility
|
||||
|
@ -169,93 +158,3 @@ public static class DumpUtility
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class DumpManager
|
||||
{
|
||||
public static bool Append;
|
||||
|
||||
public static int Count = 1;
|
||||
|
||||
public static string Dir;
|
||||
|
||||
public static bool OpenDump { get; private set; }
|
||||
|
||||
public string CountStr => Count.ToString();
|
||||
|
||||
public static void RunWithDump(string dir, Action f)
|
||||
{
|
||||
RunWithDump<int>(dir, () =>
|
||||
{
|
||||
f();
|
||||
|
||||
// discard return value
|
||||
return -1;
|
||||
});
|
||||
}
|
||||
|
||||
public static T RunWithDump<T>(string dir, Func<T> f)
|
||||
{
|
||||
Dir = dir;
|
||||
Count = 1;
|
||||
OpenDump = true;
|
||||
Append = false;
|
||||
var result = f();
|
||||
OpenDump = false;
|
||||
return result;
|
||||
}
|
||||
|
||||
public string GetMaybeDumpDir()
|
||||
{
|
||||
return ValueDumper.GetMaybeDumpDir(Dir);
|
||||
}
|
||||
|
||||
protected void UpdateOrder(string root, string target, Shape shape)
|
||||
{
|
||||
using (var order = new StreamWriter(Path.Join(root, "!out_shape_list"), Append))
|
||||
{
|
||||
order.WriteLine($"{target}: {DumpUtility.SerializeShape(shape)}");
|
||||
}
|
||||
}
|
||||
|
||||
protected void DumpCallParam(string target, ParameterInfo info, Action<StreamWriter> f)
|
||||
{
|
||||
var path = Path.Join(GetMaybeDumpDir(), $"{CountStr}${target}${info.Name}");
|
||||
using (var sr = new StreamWriter(path))
|
||||
{
|
||||
f(sr);
|
||||
}
|
||||
}
|
||||
|
||||
protected void DumpCall(string target, Shape shape, Action<StreamWriter> f)
|
||||
{
|
||||
var path = Path.Join(GetMaybeDumpDir(), $"{CountStr}${target}");
|
||||
using (var sr = new StreamWriter(path))
|
||||
{
|
||||
f(sr);
|
||||
}
|
||||
|
||||
UpdateOrder(GetMaybeDumpDir(), target, shape);
|
||||
Append = true;
|
||||
++Count;
|
||||
}
|
||||
}
|
||||
|
||||
public class Counter
|
||||
{
|
||||
private int _count;
|
||||
|
||||
public Counter(int count = 0)
|
||||
{
|
||||
_count = count;
|
||||
}
|
||||
|
||||
public T Run<T>(Func<int, T> f)
|
||||
{
|
||||
return f(_count++);
|
||||
}
|
||||
|
||||
public void Run(Action<int> f)
|
||||
{
|
||||
f(_count++);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ public class ReplaceUtility
|
|||
/// usage:
|
||||
/// Call(FakeXXX, input, otherArg1, ...)
|
||||
/// newInput => Call(op, newInput, otherArg1, ...)
|
||||
/// it's always used for Fake to NoFake Rule with IsWildcardCall.
|
||||
/// it's always used for Fake to NoFake Pass with IsWildcardCall.
|
||||
/// </summary>
|
||||
/// <param name="call"></param>
|
||||
/// <param name="op"></param>
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using DryIoc;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase.Diagnostics;
|
||||
|
||||
/// <summary>
|
||||
/// Diagnostics module.
|
||||
/// </summary>
|
||||
internal class DiagnosticsModule : IApplicationPart
|
||||
{
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
registrator.Register<IDumpperFactory, DumpperFactory>(reuse: Reuse.Scoped);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.Diagnostics;
|
||||
|
||||
internal sealed class Dumpper : IDumpper
|
||||
{
|
||||
private readonly DumpFlags _dumpFlags;
|
||||
private readonly string _dumpDirectory;
|
||||
|
||||
public Dumpper(DumpFlags dumpFlags, string dumpDirectory)
|
||||
{
|
||||
_dumpFlags = dumpFlags;
|
||||
_dumpDirectory = dumpDirectory;
|
||||
}
|
||||
|
||||
public string Directory => _dumpDirectory;
|
||||
|
||||
public IDumpper CreateSubDummper(string subDirectory)
|
||||
{
|
||||
return new Dumpper(_dumpFlags, Path.Join(_dumpDirectory, subDirectory));
|
||||
}
|
||||
|
||||
public void DumpIR(Expr expr, string prefix, string? reletivePath = null)
|
||||
{
|
||||
var path = Path.Join(_dumpDirectory, reletivePath);
|
||||
CompilerServices.DumpIR(expr, prefix, EnsureWritable(path));
|
||||
}
|
||||
|
||||
public void DumpDotIR(Expr expr, string prefix, string? reletivePath = null)
|
||||
{
|
||||
var path = Path.Join(_dumpDirectory, reletivePath);
|
||||
CompilerServices.DumpDotIR(expr, prefix, EnsureWritable(path));
|
||||
}
|
||||
|
||||
public void DumpModule(IRModule module, string? reletivePath = null)
|
||||
{
|
||||
foreach (var func in module.Functions)
|
||||
{
|
||||
DumpIR(func, string.Empty, reletivePath);
|
||||
}
|
||||
}
|
||||
|
||||
public bool IsEnabled(DumpFlags dumpFlags)
|
||||
{
|
||||
return _dumpFlags.HasFlag(dumpFlags);
|
||||
}
|
||||
|
||||
public Stream OpenFile(string reletivePath, FileMode fileMode)
|
||||
{
|
||||
var path = Path.Join(_dumpDirectory, reletivePath);
|
||||
return File.Open(EnsureWritable(path), fileMode);
|
||||
}
|
||||
|
||||
private static string EnsureWritable(string path)
|
||||
{
|
||||
var directory = Path.GetDirectoryName(path) ?? throw new ArgumentException($"Invalid path: {path}");
|
||||
System.IO.Directory.CreateDirectory(directory);
|
||||
return path;
|
||||
}
|
||||
}
|
||||
|
||||
internal sealed class DumpperFactory : IDumpperFactory
|
||||
{
|
||||
private readonly CompileOptions _compileOptions;
|
||||
|
||||
public DumpperFactory(CompileOptions compileOptions)
|
||||
{
|
||||
_compileOptions = compileOptions;
|
||||
}
|
||||
|
||||
public IDumpper Root => new Dumpper(_compileOptions.DumpFlags, _compileOptions.DumpDir);
|
||||
|
||||
public IDumpper CreateDummper(string relativePath)
|
||||
{
|
||||
return new Dumpper(_compileOptions.DumpFlags, Path.Join(_compileOptions.DumpDir, relativePath));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using Nncase.Converters;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.Targets;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
/// <summary>
|
||||
/// Diagnostics application part extensions.
|
||||
/// </summary>
|
||||
public static class DiagnosticsApplicationPart
|
||||
{
|
||||
/// <summary>
|
||||
/// Add diagnostics assembly.
|
||||
/// </summary>
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
/// <returns>Configured service registrator.</returns>
|
||||
public static IRegistrator AddDiagnostics(this IRegistrator registrator)
|
||||
{
|
||||
return registrator.RegisterModule<DiagnosticsModule>();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net6.0</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<RootNamespace>Nncase</RootNamespace>
|
||||
<GenerateDocumentationFile>true</GenerateDocumentationFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Nncase.Core\Nncase.Core.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
|
@ -4,6 +4,7 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Drawing;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using GiGraph.Dot.Entities.Clusters;
|
||||
using GiGraph.Dot.Entities.Graphs;
|
||||
|
@ -23,12 +24,12 @@ namespace Nncase.Transform;
|
|||
|
||||
public partial class EGraphPrinter
|
||||
{
|
||||
internal static DotGraph DumpEgraphAsDot(EGraph eGraph, CostModel.EGraphCostModel costModel, EClass entry, string file)
|
||||
internal static DotGraph DumpEgraphAsDot(EGraph eGraph, CostModel.EGraphCostModel costModel, EClass entry, Stream file)
|
||||
{
|
||||
var printer = new EGraphPrinter(eGraph);
|
||||
printer.ConvertEGraphAsDot();
|
||||
printer.AttachEGraphCost(costModel, entry);
|
||||
return printer.SaveToFile(file);
|
||||
return printer.SaveToStream(file);
|
||||
}
|
||||
|
||||
private DotGraph AttachEGraphCost(CostModel.EGraphCostModel costModel, EClass entry)
|
||||
|
|
|
@ -7,6 +7,10 @@ using System.Linq;
|
|||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.PatternMatch;
|
||||
using Nncase.Transform;
|
||||
|
||||
namespace Nncase;
|
||||
|
||||
|
@ -18,11 +22,11 @@ public static class EGraphApplicationPart
|
|||
/// <summary>
|
||||
/// Add egraph assembly.
|
||||
/// </summary>
|
||||
/// <param name="assemblies">Assembly collection.</param>
|
||||
/// <returns>Updated assembly collection.</returns>
|
||||
public static IList<Assembly> AddEGraph(this IList<Assembly> assemblies)
|
||||
/// <param name="registrator">Service registrator.</param>
|
||||
/// <returns>Configured service registrator.</returns>
|
||||
public static IRegistrator AddEGraph(this IRegistrator registrator)
|
||||
{
|
||||
assemblies.Add(typeof(EGraphApplicationPart).Assembly);
|
||||
return assemblies;
|
||||
return registrator.RegisterModule<PatternMatchModule>()
|
||||
.RegisterModule<TransformModule>();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
using Nncase.Transform;
|
||||
|
||||
namespace Nncase.PatternMatch;
|
||||
|
||||
/// <summary>
|
||||
/// PatternMatch module.
|
||||
/// </summary>
|
||||
public class PatternMatchModule : Module
|
||||
internal class PatternMatchModule : IApplicationPart
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override void Load(ContainerBuilder builder)
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
builder.RegisterType<EGraphMatchProvider>().AsImplementedInterfaces().SingleInstance();
|
||||
registrator.RegisterManyInterface<EGraphMatchProvider>(reuse: Reuse.Singleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,24 +46,16 @@ public sealed partial class EGraph : IEGraph
|
|||
Add(expr);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets eclasses.
|
||||
/// </summary>
|
||||
/// <inheritdoc/>
|
||||
public IEnumerable<EClass> Classes => _classes;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public IEnumerable<ENode> Nodes => _nodes.Keys;
|
||||
|
||||
/// <summary>
|
||||
/// Gets version.
|
||||
/// </summary>
|
||||
/// <inheritdoc/>
|
||||
public int Version => _version;
|
||||
|
||||
/// <summary>
|
||||
/// Add expr, get the eclass id.
|
||||
/// </summary>
|
||||
/// <param name="expr">Expression.</param>
|
||||
/// <returns>Eclass of this node.</returns>
|
||||
/// <inheritdoc/>
|
||||
public EClass Add(Expr expr)
|
||||
{
|
||||
if (expr.CheckedType is null)
|
||||
|
@ -75,22 +67,13 @@ public sealed partial class EGraph : IEGraph
|
|||
return converter.Visit(expr);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Find eclass of enode.
|
||||
/// </summary>
|
||||
/// <param name="node">ENode.</param>
|
||||
/// <returns>EClass.</returns>
|
||||
/// <inheritdoc/>
|
||||
public EClass Find(ENode node)
|
||||
{
|
||||
return _nodes[node].Class;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Union two equal Eclass.
|
||||
/// </summary>
|
||||
/// <param name="classA">class a.</param>
|
||||
/// <param name="classB">class b.</param>
|
||||
/// <returns>If version changed.</returns>
|
||||
/// <inheritdoc/>
|
||||
public bool Union(EClass classA, EClass classB)
|
||||
{
|
||||
classA = classA.Find();
|
||||
|
@ -123,9 +106,7 @@ public sealed partial class EGraph : IEGraph
|
|||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// After merge, we use rebuild get new dep information.
|
||||
/// </summary>
|
||||
/// <inheritdoc/>
|
||||
public void Rebuild()
|
||||
{
|
||||
while (_worklist.Count > 0)
|
||||
|
@ -190,7 +171,7 @@ public sealed partial class EGraph : IEGraph
|
|||
_nodes.Remove(enode);
|
||||
originalClass.RemoveNode(enode);
|
||||
|
||||
// 2. Update node's children
|
||||
// 2. Replace node's children
|
||||
var newNode = enode.Canonicalize();
|
||||
|
||||
if (_nodes.TryGetValue(newNode, out var existingEntry))
|
||||
|
|
|
@ -6,7 +6,9 @@ using System.Collections.Generic;
|
|||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using LanguageExt.ClassInstances;
|
||||
using Nncase.CostModel;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.PatternMatch;
|
||||
using static Nncase.PatternMatch.F.Math;
|
||||
|
@ -25,9 +27,8 @@ public static class EGraphExtractExtensions
|
|||
/// <param name="eGraph">eGraph.</param>
|
||||
/// <param name="root">Root eclass.</param>
|
||||
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
/// <returns>Extracted root expression.</returns>
|
||||
public static Expr Extract(this EGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, RunPassOptions options)
|
||||
public static Expr Extract(this EGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator)
|
||||
{
|
||||
// 1. set the all expr checked shape
|
||||
foreach (var eclass in eGraph.Classes)
|
||||
|
@ -43,12 +44,13 @@ public static class EGraphExtractExtensions
|
|||
|
||||
// 2. start the cost evaluator
|
||||
var costModel = new EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator).Evaluate();
|
||||
if (options.DumpLevel > 2)
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.EGraphCost))
|
||||
{
|
||||
EGraphPrinter.DumpEgraphAsDot(eGraph, costModel, root.Find(), Path.Combine(options.DumpDir, "Costs", $"V{eGraph.Version}"));
|
||||
using var fs = DumpScope.Current.OpenFile(Path.Combine("Costs", $"V{eGraph.Version}.dot"));
|
||||
EGraphPrinter.DumpEgraphAsDot(eGraph, costModel, root.Find(), fs);
|
||||
}
|
||||
|
||||
return new EGraphExtractor(costModel, options).Extract(root.Find());
|
||||
return new EGraphExtractor(costModel).Extract(root.Find());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -93,19 +95,29 @@ public static class EGraphExtractExtensions
|
|||
internal class EGraphExtractor
|
||||
{
|
||||
private readonly EGraphCostModel _costModel;
|
||||
private readonly RunPassOptions _options;
|
||||
private readonly Dictionary<EClass, Expr> _eclassMemo = new();
|
||||
private readonly Dictionary<EClass, Expr> _markerEclassMemo = new();
|
||||
private StreamWriter? _dumpWriter;
|
||||
|
||||
public EGraphExtractor(EGraphCostModel costModel, RunPassOptions options)
|
||||
public EGraphExtractor(EGraphCostModel costModel)
|
||||
{
|
||||
_costModel = costModel;
|
||||
_options = options;
|
||||
}
|
||||
|
||||
public Expr Extract(EClass root)
|
||||
{
|
||||
Visit(root);
|
||||
_dumpWriter = DumpScope.Current.IsEnabled(DumpFlags.EGraphCost)
|
||||
? new StreamWriter(DumpScope.Current.OpenFile($"{nameof(EGraphExtractor)}_Class_{root.Id}.txt"))
|
||||
: null;
|
||||
try
|
||||
{
|
||||
Visit(root);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_dumpWriter?.Dispose();
|
||||
}
|
||||
|
||||
return _eclassMemo[root];
|
||||
}
|
||||
|
||||
|
@ -219,17 +231,17 @@ internal class EGraphExtractor
|
|||
var parameters = children.Skip(1);
|
||||
|
||||
// for mix quant debug.
|
||||
if (call.EnodeQuantConfigWithCosine != null && _options.DumpLevel > 3)
|
||||
if (call.EnodeQuantConfigWithCosine != null && _dumpWriter != null)
|
||||
{
|
||||
Console.WriteLine(call + " " + call.CheckedType);
|
||||
_dumpWriter.WriteLine(call + " " + call.CheckedType);
|
||||
for (int i = 0; i < call.EnodeQuantConfigWithCosine.Count; i++)
|
||||
{
|
||||
for (int j = 0; j < call.EnodeQuantConfigWithCosine[i].Item1.Count; j++)
|
||||
{
|
||||
Console.Write(call.EnodeQuantConfigWithCosine[i].Item1[j] + " ");
|
||||
_dumpWriter.Write(call.EnodeQuantConfigWithCosine[i].Item1[j] + " ");
|
||||
}
|
||||
|
||||
Console.WriteLine(call.EnodeQuantConfigWithCosine[i].Item3);
|
||||
_dumpWriter.WriteLine(call.EnodeQuantConfigWithCosine[i].Item3);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Drawing;
|
||||
using System.IO;
|
||||
using GiGraph.Dot.Entities.Clusters;
|
||||
using GiGraph.Dot.Entities.Graphs;
|
||||
using GiGraph.Dot.Entities.Nodes;
|
||||
|
@ -23,12 +24,12 @@ public partial class EGraphPrinter
|
|||
{
|
||||
private readonly Color[] _knownColors = new Color[] { Color.MediumAquamarine, Color.MediumBlue, Color.MediumOrchid, Color.MediumPurple, Color.MediumSeaGreen, Color.MediumSlateBlue, Color.MediumSpringGreen, Color.Maroon, Color.MediumTurquoise, Color.MidnightBlue, Color.MintCream, Color.MistyRose, Color.Moccasin, Color.NavajoWhite, Color.Navy, Color.OldLace, Color.MediumVioletRed, Color.Magenta, Color.Linen, Color.LimeGreen, Color.LavenderBlush, Color.LawnGreen, Color.LemonChiffon, Color.LightBlue, Color.LightCoral, Color.LightCyan, Color.LightGoldenrodYellow, Color.LightGray, Color.LightGreen, Color.LightPink, Color.LightSalmon, Color.LightSeaGreen, Color.LightSkyBlue, Color.LightSlateGray, Color.LightSteelBlue, Color.LightYellow, Color.Lime, Color.Olive, Color.OliveDrab, Color.Orange, Color.OrangeRed, Color.Silver, Color.SkyBlue, Color.SlateBlue, Color.SlateGray, Color.Snow, Color.SpringGreen, Color.SteelBlue, Color.Tan, Color.Teal, Color.Thistle, Color.Tomato, Color.Transparent, Color.Turquoise, Color.Violet, Color.Wheat, Color.White, Color.WhiteSmoke, Color.Sienna, Color.Lavender, Color.SeaShell, Color.SandyBrown, Color.Orchid, Color.PaleGoldenrod, Color.PaleGreen, Color.PaleTurquoise, Color.PaleVioletRed, Color.PapayaWhip, Color.PeachPuff, Color.Peru, Color.Pink, Color.Plum, Color.PowderBlue, Color.Purple, Color.Red, Color.RosyBrown, Color.RoyalBlue, Color.SaddleBrown, Color.Salmon, Color.SeaGreen, Color.Yellow, Color.Khaki, Color.Cyan, Color.DarkMagenta, Color.DarkKhaki, Color.DarkGreen, Color.DarkGray, Color.DarkGoldenrod, Color.DarkCyan, Color.DarkBlue, Color.Ivory, Color.Crimson, Color.Cornsilk, Color.CornflowerBlue, Color.Coral, Color.Chocolate, Color.DarkOliveGreen, Color.Chartreuse, Color.BurlyWood, Color.Brown, Color.BlueViolet, Color.Blue, Color.BlanchedAlmond, Color.Black, Color.Bisque, Color.Beige, Color.Azure, Color.Aquamarine, Color.Aqua, Color.AntiqueWhite, Color.AliceBlue, Color.CadetBlue, Color.DarkOrange, Color.YellowGreen, Color.DarkRed, Color.Indigo, Color.IndianRed, Color.DarkOrchid, Color.Honeydew, Color.GreenYellow, Color.Green, Color.Gray, Color.Goldenrod, Color.Gold, Color.GhostWhite, Color.Gainsboro, Color.Fuchsia, Color.ForestGreen, Color.HotPink, Color.Firebrick, Color.FloralWhite, Color.DodgerBlue, Color.DimGray, Color.DeepSkyBlue, Color.DeepPink, Color.DarkViolet, Color.DarkTurquoise, Color.DarkSlateGray, Color.DarkSlateBlue, Color.DarkSeaGreen, Color.DarkSalmon, };
|
||||
|
||||
public static DotGraph DumpEgraphAsDot(EGraph eGraph, IReadOnlyList<IMatchResult>? matches, string file)
|
||||
public static DotGraph DumpEgraphAsDot(IEGraph eGraph, IReadOnlyList<IMatchResult>? matches, Stream output)
|
||||
{
|
||||
var printer = new EGraphPrinter(eGraph);
|
||||
printer.ConvertEGraphAsDot();
|
||||
printer.AttachEGraphMatches(matches);
|
||||
return printer.SaveToFile(file);
|
||||
return printer.SaveToStream(output);
|
||||
}
|
||||
|
||||
public DotGraph AttachEGraphMatches(IReadOnlyList<IMatchResult>? matches)
|
||||
|
|
|
@ -7,6 +7,8 @@ using System.IO;
|
|||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.PatternMatch;
|
||||
|
||||
|
@ -17,90 +19,75 @@ namespace Nncase.Transform;
|
|||
/// </summary>
|
||||
public class EGraphPass : RulesPass
|
||||
{
|
||||
private readonly List<IRewriteRule> _rules = new();
|
||||
private readonly IEGraphRewriteProvider _rewriteProvider;
|
||||
private readonly Evaluator.IBaseFuncCostEvaluator? _baseFuncCostEvaluator;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="EGraphPass"/> class.
|
||||
/// </summary>
|
||||
/// <param name="name">Name.</param>
|
||||
public EGraphPass(string name)
|
||||
: base(name)
|
||||
{
|
||||
_baseFuncCostEvaluator = null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="EGraphPass"/> class.
|
||||
/// </summary>
|
||||
/// <param name="name">Pass Name.</param>
|
||||
/// <param name="baseFuncCostEvaluator">Extenal cost evaluator.</param>
|
||||
public EGraphPass(string name, Evaluator.IBaseFuncCostEvaluator baseFuncCostEvaluator)
|
||||
: base(name)
|
||||
public EGraphPass(Evaluator.IBaseFuncCostEvaluator? baseFuncCostEvaluator = null)
|
||||
{
|
||||
_rewriteProvider = CompileSession.GetRequiredService<IEGraphRewriteProvider>();
|
||||
_baseFuncCostEvaluator = baseFuncCostEvaluator;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override async Task<BaseFunction> RunCoreAsync(BaseFunction function, RunPassOptions options)
|
||||
protected override async Task<BaseFunction> RunCoreAsync(BaseFunction function, RunPassContext context)
|
||||
{
|
||||
var graph = new EGraph();
|
||||
var root = graph.Add(function);
|
||||
EGraphRewriter.Rewrite(graph, Rules, options);
|
||||
OnPostRewriteStart(graph, options);
|
||||
await OnPostRewrite(graph, options);
|
||||
OnPostRewriteEnd(graph, options);
|
||||
var post = graph.Extract(root, _baseFuncCostEvaluator, options);
|
||||
_rewriteProvider.ERewrite(graph, Rules, context);
|
||||
await OnPostRewriteStartAsync(graph, context);
|
||||
await OnPostRewriteAsync(graph, context);
|
||||
await OnPostRewriteEndAsync(graph, context);
|
||||
var post = graph.Extract(root, _baseFuncCostEvaluator);
|
||||
CompilerServices.InferenceType(post);
|
||||
return (BaseFunction)post;
|
||||
}
|
||||
|
||||
protected virtual Task OnPostRewrite(EGraph graph, RunPassOptions options)
|
||||
/// <summary>
|
||||
/// The callback after egraph rewrite.
|
||||
/// </summary>
|
||||
/// <param name="eGraph">EGraph after rewrite.</param>
|
||||
/// <param name="context">Run pass context.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
protected virtual Task OnPostRewriteAsync(EGraph eGraph, RunPassContext context)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// The callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="eGraph"> egraph without run pass.</param>
|
||||
/// <param name="options">Options.</param>
|
||||
protected virtual void OnPostRewriteStart(EGraph eGraph, RunPassOptions options)
|
||||
/// <param name="eGraph">EGraph after rewrite.</param>
|
||||
/// <param name="context">Run pass context.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
protected virtual Task OnPostRewriteStartAsync(EGraph eGraph, RunPassContext context)
|
||||
{
|
||||
switch (options.DumpLevel)
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
case >= 4:
|
||||
EGraphPrinter.DumpEgraphAsDot(
|
||||
eGraph,
|
||||
null,
|
||||
Path.Combine(options.DumpDir, options.PassName, "PostRewriteStart", $"V{eGraph.Version}"));
|
||||
break;
|
||||
case >= 1:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
using var fs = DumpScope.Current.OpenFile(Path.Combine("PostRewriteStart", $"V{eGraph.Version}.dot"));
|
||||
EGraphPrinter.DumpEgraphAsDot(eGraph, null, fs);
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the callback function you can custom process func with run pass options.
|
||||
/// The callback function you can custom process func with run pass options.
|
||||
/// </summary>
|
||||
/// <param name="eGraph"> egraph with rewrited. </param>
|
||||
/// <param name="options">Options.</param>
|
||||
protected virtual void OnPostRewriteEnd(EGraph eGraph, RunPassOptions options)
|
||||
/// <param name="eGraph">EGraph after post rewrite.</param>
|
||||
/// <param name="context">Run pass context.</param>
|
||||
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
|
||||
protected virtual Task OnPostRewriteEndAsync(EGraph eGraph, RunPassContext context)
|
||||
{
|
||||
switch (options.DumpLevel)
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
|
||||
{
|
||||
case >= 4:
|
||||
EGraphPrinter.DumpEgraphAsDot(
|
||||
eGraph,
|
||||
null,
|
||||
Path.Combine(options.DumpDir, options.PassName, "PostRewriteEnd", $"V{eGraph.Version}"));
|
||||
break;
|
||||
case >= 1:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
using var fs = DumpScope.Current.OpenFile(Path.Combine("PostRewriteEnd", $"V{eGraph.Version}.dot"));
|
||||
EGraphPrinter.DumpEgraphAsDot(eGraph, null, fs);
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ public partial class EGraphPrinter
|
|||
|
||||
private readonly Dictionary<EClass, string> _opMaps = new();
|
||||
|
||||
private readonly EGraph _eGraph;
|
||||
private readonly IEGraph _eGraph;
|
||||
|
||||
private readonly DotDumpVisitor _visitor = new DotDumpVisitor();
|
||||
|
||||
|
@ -55,7 +55,7 @@ public partial class EGraphPrinter
|
|||
/// ctor for egraph.
|
||||
/// </summary>
|
||||
/// <param name="egraph"></param>
|
||||
public EGraphPrinter(EGraph egraph)
|
||||
public EGraphPrinter(IEGraph egraph)
|
||||
{
|
||||
_idCounter = 0;
|
||||
DotGraph = new(directed: true);
|
||||
|
@ -74,13 +74,26 @@ public partial class EGraphPrinter
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// dump egraph as dot graph.
|
||||
/// </summary>
|
||||
/// <param name="eGraph">egraph.</param>
|
||||
/// <param name="output">Output stream.</param>
|
||||
/// <returns>Converted Graph.</returns>
|
||||
public static DotGraph DumpEgraphAsDot(IEGraph eGraph, Stream output)
|
||||
{
|
||||
var printer = new EGraphPrinter(eGraph);
|
||||
printer.ConvertEGraphAsDot();
|
||||
return printer.SaveToStream(output);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// dump egraph as dot graph.
|
||||
/// </summary>
|
||||
/// <param name="eGraph">egraph.</param>
|
||||
/// <param name="file">path.</param>
|
||||
/// <returns>Converted Graph.</returns>
|
||||
public static DotGraph DumpEgraphAsDot(EGraph eGraph, string file)
|
||||
public static DotGraph DumpEgraphAsDot(IEGraph eGraph, string file)
|
||||
{
|
||||
var printer = new EGraphPrinter(eGraph);
|
||||
printer.ConvertEGraphAsDot();
|
||||
|
@ -188,10 +201,21 @@ public partial class EGraphPrinter
|
|||
return DotGraph;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save the DotGraph into stream.
|
||||
/// </summary>
|
||||
/// <param name="output">Output stream.</param>
|
||||
/// <returns>this dot graph.</returns>
|
||||
public DotGraph SaveToStream(Stream output)
|
||||
{
|
||||
DotGraph.Build(new StreamWriter(output, leaveOpen: true));
|
||||
return DotGraph;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save the DotGraph into file.
|
||||
/// </summary>
|
||||
/// <param name="file">file path.</param>
|
||||
/// <param name="file">Output file.</param>
|
||||
/// <returns>this dot graph.</returns>
|
||||
public DotGraph SaveToFile(string file)
|
||||
{
|
||||
|
|
|
@ -1,90 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using Nncase.IR;
|
||||
using Nncase.PatternMatch;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
internal static class EGraphRewriter
|
||||
{
|
||||
/// <summary>
|
||||
/// Run egraph rewrite.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public static EGraph Rewrite(EGraph eGraph, IEnumerable<IRewriteRule> rules, RunPassOptions options)
|
||||
{
|
||||
var matches = new List<(IRewriteRule, IReadOnlyList<IMatchResult>)> { };
|
||||
var last_version = eGraph.Version;
|
||||
int count = 0;
|
||||
|
||||
while (true)
|
||||
{
|
||||
foreach (var rule in rules)
|
||||
{
|
||||
if (EGraphMatcher.TryMatchRoot(eGraph.Nodes, rule.Pattern, out var results))
|
||||
{
|
||||
matches.Add((rule, results));
|
||||
|
||||
if (options.DumpLevel > 1 && results.Count != 0)
|
||||
{
|
||||
EGraphPrinter.DumpEgraphAsDot(
|
||||
eGraph,
|
||||
results,
|
||||
Path.Combine(options.DumpDir, options.PassName, "Matches", $"V{eGraph.Version}_{count++}_{rule.GetType().Name}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var (rule, results) in matches)
|
||||
{
|
||||
var replacedExprs = (from result in results
|
||||
let expr = rule.GetReplace(result, options)
|
||||
where expr != null
|
||||
select (eGraph.Find((ENode)result.Root), expr)).ToList();
|
||||
|
||||
foreach (var (oldEClass, newExpr) in replacedExprs)
|
||||
{
|
||||
if (!CompilerServices.InferenceType(newExpr))
|
||||
{
|
||||
CompilerServices.DumpIR(newExpr, "Replaced_Expr", options.DumpDir);
|
||||
throw new InvalidOperationException("Can't Inference The Replace Expr Type!");
|
||||
}
|
||||
|
||||
var newEClass = eGraph.Add(newExpr);
|
||||
if (options.DumpLevel > 4)
|
||||
{
|
||||
Console.WriteLine($"Version {eGraph.Version} : Merge {{{oldEClass}}} to {{{newEClass}}}");
|
||||
}
|
||||
|
||||
eGraph.Union(newEClass, oldEClass);
|
||||
}
|
||||
}
|
||||
|
||||
matches.Clear();
|
||||
if (last_version == eGraph.Version)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
last_version = eGraph.Version;
|
||||
}
|
||||
|
||||
eGraph.Rebuild();
|
||||
if (options.DumpLevel > 1)
|
||||
{
|
||||
EGraphPrinter.DumpEgraphAsDot(
|
||||
eGraph,
|
||||
Path.Combine(options.DumpDir, options.PassName, "Rebuild", $"V{eGraph.Version}"));
|
||||
}
|
||||
}
|
||||
|
||||
return eGraph;
|
||||
}
|
||||
}
|
|
@ -3,16 +3,29 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.PatternMatch;
|
||||
using Tensorflow;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
internal class EGraphRewriteProvider : IEGraphRewriteProvider
|
||||
{
|
||||
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
|
||||
private readonly ILogger _logger;
|
||||
|
||||
public EGraphRewriteProvider(ILogger<EGraphRewriteProvider> logger)
|
||||
{
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
|
||||
{
|
||||
if (expr.CheckedType is null)
|
||||
{
|
||||
|
@ -21,8 +34,73 @@ internal class EGraphRewriteProvider : IEGraphRewriteProvider
|
|||
|
||||
var graph = new EGraph();
|
||||
var root = graph.Add(expr);
|
||||
EGraphRewriter.Rewrite(graph, rules, options);
|
||||
var post = graph.Extract(root, null, options);
|
||||
ERewrite(graph, rules, options);
|
||||
var post = graph.Extract(root, null);
|
||||
return post;
|
||||
}
|
||||
|
||||
public IEGraph ERewrite(IEGraph eGraph, IEnumerable<IRewriteRule> rules, RunPassContext context)
|
||||
{
|
||||
var matches = new List<(IRewriteRule, IReadOnlyList<IMatchResult>)> { };
|
||||
var last_version = eGraph.Version;
|
||||
int count = 0;
|
||||
|
||||
while (true)
|
||||
{
|
||||
foreach (var rule in rules)
|
||||
{
|
||||
if (EGraphMatcher.TryMatchRoot(eGraph.Nodes, rule.Pattern, out var results))
|
||||
{
|
||||
matches.Add((rule, results));
|
||||
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.Rewrite) && results.Count != 0)
|
||||
{
|
||||
using var fs = DumpScope.Current.OpenFile(Path.Combine("Matches", $"V{eGraph.Version}_{count++}_{rule.GetType().Name}.dot"));
|
||||
EGraphPrinter.DumpEgraphAsDot(eGraph, results, fs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var (rule, results) in matches)
|
||||
{
|
||||
var replacedExprs = (from result in results
|
||||
let expr = rule.GetReplace(result, context)
|
||||
where expr != null
|
||||
select (eGraph.Find((ENode)result.Root), expr)).ToList();
|
||||
|
||||
foreach (var (oldEClass, newExpr) in replacedExprs)
|
||||
{
|
||||
var typeInferSuccess = CompilerServices.InferenceType(newExpr);
|
||||
Trace.Assert(typeInferSuccess);
|
||||
|
||||
var newEClass = eGraph.Add(newExpr);
|
||||
if (_logger.IsEnabled(LogLevel.Trace))
|
||||
{
|
||||
_logger.LogTrace("Version {Version} : Merge {OldClass} to {NewClass}", eGraph.Version, oldEClass, newEClass);
|
||||
}
|
||||
|
||||
eGraph.Union(newEClass, oldEClass);
|
||||
}
|
||||
}
|
||||
|
||||
matches.Clear();
|
||||
if (last_version == eGraph.Version)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
last_version = eGraph.Version;
|
||||
}
|
||||
|
||||
eGraph.Rebuild();
|
||||
if (DumpScope.Current.IsEnabled(DumpFlags.Rewrite))
|
||||
{
|
||||
using var fs = DumpScope.Current.OpenFile(Path.Combine("Rebuild", $"V{eGraph.Version}.dot"));
|
||||
EGraphPrinter.DumpEgraphAsDot(eGraph, fs);
|
||||
}
|
||||
}
|
||||
|
||||
return eGraph;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase.Transform;
|
||||
|
||||
/// <summary>
|
||||
/// Transform module.
|
||||
/// </summary>
|
||||
public class TransformModule : Module
|
||||
internal class TransformModule : IApplicationPart
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override void Load(ContainerBuilder builder)
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
builder.RegisterType<EGraphRewriteProvider>().AsImplementedInterfaces().SingleInstance();
|
||||
registrator.Register<IEGraphRewriteProvider, EGraphRewriteProvider>(reuse: Reuse.ScopedOrSingleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.Buffer;
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.Buffer;
|
||||
|
||||
|
|
|
@ -1,23 +1,23 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using DryIoc;
|
||||
using Nncase.Evaluator.Tensors;
|
||||
using Nncase.Hosting;
|
||||
|
||||
namespace Nncase.Evaluator.Buffer;
|
||||
|
||||
/// <summary>
|
||||
/// Buffer module.
|
||||
/// </summary>
|
||||
public class BufferModule : Module
|
||||
internal class BufferModule : IApplicationPart
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override void Load(ContainerBuilder builder)
|
||||
public void ConfigureServices(IRegistrator registrator)
|
||||
{
|
||||
builder.RegisterType<DDrOfEvaluator>().AsImplementedInterfaces();
|
||||
builder.RegisterType<BaseMentOfEvaluator>().AsImplementedInterfaces();
|
||||
builder.RegisterType<StrideOfEvaluator>().AsImplementedInterfaces();
|
||||
builder.RegisterType<AllocateEvaluator>().AsImplementedInterfaces();
|
||||
builder.RegisterType<UninitializedEvaluator>().AsImplementedInterfaces();
|
||||
registrator.RegisterManyInterface<DDrOfEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<BaseMentOfEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<StrideOfEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<AllocateEvaluator>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<UninitializedEvaluator>(reuse: Reuse.Singleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Autofac;
|
||||
using Nncase.IR;
|
||||
using Nncase.IR.Buffer;
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue