Refactor usage of IoC (#767)

* Use DryIoC

* Fix all usage of ioc

* Fix merged tests

* Use MSBuildThisFileDirectory instead of SolutionDir

* Fix pytest

* Fix build

* Fix build

* Fix os.makedirs with exist_ok=True

* Fix merge

* Apply code-format changes

* Add dump test

* Fix dump test

* Resolve review & Add more dump tests

Co-authored-by: sunnycase <sunnycase@users.noreply.github.com>
pull/774/head
sunnycase 2023-01-11 19:24:38 +08:00 committed by GitHub
parent 8f5fe4d477
commit 4b93ce843a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
218 changed files with 4792 additions and 4684 deletions

View File

@ -533,7 +533,7 @@ dotnet_diagnostic.MA0097.severity = warning # MA0097: A class that implements
dotnet_diagnostic.MA0098.severity = suggestion # MA0098: Use indexer instead of extension methods
dotnet_diagnostic.MA0099.severity = warning # MA0099: Use Explicit enum value instead of 0
MA0053.public_class_should_be_sealed = false
end_of_line = lf
end_of_line = crlf
tab_width = 4
dotnet_style_allow_multiple_blank_lines_experimental = true:silent
dotnet_style_allow_statement_immediately_after_block_experimental = true:silent

113
.gitattributes vendored
View File

@ -1,63 +1,62 @@
###############################################################################
# Set default behavior to automatically normalize line endings.
###############################################################################
* text=auto
###############################################################################
# Set default behavior for command prompt diff.
#
# This is need for earlier builds of msysgit that does not have it on by
# default for csharp files.
# Note: This is only used by command line
###############################################################################
#*.cs diff=csharp
*.doc diff=astextplain
*.DOC diff=astextplain
*.docx diff=astextplain
*.DOCX diff=astextplain
*.dot diff=astextplain
*.DOT diff=astextplain
*.pdf diff=astextplain
*.PDF diff=astextplain
*.rtf diff=astextplain
*.RTF diff=astextplain
###############################################################################
# Set the merge driver for project and solution files
#
# Merging from the command prompt will add diff markers to the files if there
# are conflicts (Merging from VS is not affected by the settings below, in VS
# the diff markers are never inserted). Diff markers may cause the following
# file extensions to fail to load in VS. An alternative would be to treat
# these files as binary and thus will always conflict and require user
# intervention with every merge. To do so, just uncomment the entries below
###############################################################################
#*.sln merge=binary
#*.csproj merge=binary
#*.vbproj merge=binary
#*.vcxproj merge=binary
#*.vcproj merge=binary
#*.dbproj merge=binary
#*.fsproj merge=binary
#*.lsproj merge=binary
#*.wixproj merge=binary
#*.modelproj merge=binary
#*.sqlproj merge=binary
#*.wwaproj merge=binary
*.jpg binary
*.png binary
*.gif binary
###############################################################################
# behavior for image files
#
# image files are treated as binary by default.
###############################################################################
#*.jpg binary
#*.png binary
#*.gif binary
# Force bash scripts to always use lf line endings so that if a repo is accessed
# in Unix via a file share from Windows, the scripts will work.
*.in text eol=lf
*.sh text eol=lf
###############################################################################
# diff behavior for common document formats
#
# Convert binary document formats to text before diffing them. This feature
# is only available from the command line. Turn it on by uncommenting the
# entries below.
###############################################################################
#*.doc diff=astextplain
#*.DOC diff=astextplain
#*.docx diff=astextplain
#*.DOCX diff=astextplain
#*.dot diff=astextplain
#*.DOT diff=astextplain
#*.pdf diff=astextplain
#*.PDF diff=astextplain
#*.rtf diff=astextplain
#*.RTF diff=astextplain
# Likewise, force cmd and batch scripts to always use crlf
*.cmd text eol=crlf
*.bat text eol=crlf
*.cs text=auto diff=csharp
*.vb text=auto
*.resx text=auto
*.c text=auto
*.cpp text=auto
*.cxx text=auto
*.h text=auto
*.hxx text=auto
*.py text=auto
*.rb text=auto
*.java text=auto
*.html text=auto
*.htm text=auto
*.css text=auto
*.scss text=auto
*.sass text=auto
*.less text=auto
*.js text=auto
*.lisp text=auto
*.clj text=auto
*.sql text=auto
*.php text=auto
*.lua text=auto
*.m text=auto
*.asm text=auto
*.erl text=auto
*.fs text=auto
*.fsx text=auto
*.hs text=auto
*.csproj text=auto
*.vbproj text=auto
*.fsproj text=auto
*.dbproj text=auto
*.sln text=auto eol=crlf

View File

@ -1,16 +1,19 @@
<Project>
<PropertyGroup>
<ManagePackageVersionsCentrally>true</ManagePackageVersionsCentrally>
<CodeAnalysisRuleSet>$(SolutionDir)/tools/StyleCopAnalyzers.ruleset</CodeAnalysisRuleSet>
<CodeAnalysisRuleSet>$(MSBuildThisFileDirectory)/tools/StyleCopAnalyzers.ruleset</CodeAnalysisRuleSet>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="AnyTensorFlow.NET" Version="0.70.1" />
<PackageVersion Include="Autofac" Version="6.3.0" />
<PackageVersion Include="Autofac.Extensions.DependencyInjection" Version="7.2.0" />
<PackageVersion Include="BitFields" Version="0.1.0" />
<PackageVersion Include="coverlet.collector" Version="3.0.2" />
<PackageVersion Include="DryIoc.dll" Version="5.3.1" />
<PackageVersion Include="DryIoc.Microsoft.DependencyInjection" Version="6.1.0" />
<PackageVersion Include="Extension.Mathematics" Version="1.2.12" />
<PackageVersion Include="Fody">
<Version>6.6.4</Version>
</PackageVersion>
<PackageVersion Include="GiGraph.Dot" Version="2.0.0" />
<PackageVersion Include="Google.Protobuf" Version="3.19.1" />
<PackageVersion Include="Grpc.Tools" Version="2.42.0" />
@ -25,10 +28,17 @@
<!-- <PackageVersion Include="libtorch-cuda-11.3-win-x64" Version="1.11.0.1" /> -->
<PackageVersion Include="MagicalTensorflowLib" Version="0.0.2" />
<PackageVersion Include="MagicalTensorflowLibOSX-ARM64" Version="0.0.4" />
<PackageVersion Include="MethodBoundaryAspect.Fody">
<Version>2.0.148</Version>
</PackageVersion>
<PackageVersion Include="MethodDecorator.Fody">
<Version>1.1.1</Version>
</PackageVersion>
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" />
<PackageVersion Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.3" />
<PackageVersion Include="Microsoft.Extensions.Hosting" Version="6.0.0" />
<PackageVersion Include="Microsoft.Extensions.Hosting.Abstractions" Version="6.0.0" />
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="6.0.0" />
<PackageVersion Include="Microsoft.Extensions.Options" Version="6.0.0" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="16.9.4" />
<PackageVersion Include="Microsoft.Toolkit.HighPerformance" Version="7.1.1" />
@ -44,6 +54,9 @@
<PackageVersion Include="TorchSharp" Version="0.99.0" />
<PackageVersion Include="xunit" Version="2.4.1" />
<PackageVersion Include="xunit.assert" Version="2.4.1" />
<PackageVersion Include="Xunit.Combinatorial">
<Version>1.5.25</Version>
</PackageVersion>
<PackageVersion Include="Xunit.DependencyInjection" Version="8.3.0" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.4.3" />
</ItemGroup>
@ -54,6 +67,6 @@
</PackageReference>
</ItemGroup>
<ItemGroup>
<AdditionalFiles Include="$(SolutionDir)/tools/stylecop.json" />
<AdditionalFiles Include="$(MSBuildThisFileDirectory)/tools/stylecop.json" />
</ItemGroup>
</Project>

View File

@ -7,6 +7,8 @@ using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using Nncase.Hosting;
namespace Nncase;
@ -18,11 +20,10 @@ public static class K210ApplicationPart
/// <summary>
/// Add k210 assembly.
/// </summary>
/// <param name="assemblies">Assembly collection.</param>
/// <returns>Updated assembly collection.</returns>
public static IList<Assembly> AddK210(this IList<Assembly> assemblies)
/// <param name="registrator">Service registrator.</param>
/// <returns>Configured service registrator.</returns>
public static IRegistrator AddK210(this IRegistrator registrator)
{
assemblies.Add(typeof(K210ApplicationPart).Assembly);
return assemblies;
return registrator.RegisterModule<K210Module>();
}
}

View File

@ -1,8 +1,9 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using DryIoc;
using Nncase.Evaluator.K210;
using Nncase.Hosting;
using Nncase.Targets;
namespace Nncase;
@ -10,14 +11,13 @@ namespace Nncase;
/// <summary>
/// K210 module.
/// </summary>
public class K210Module : Module
internal sealed class K210Module : IApplicationPart
{
/// <inheritdoc/>
protected override void Load(ContainerBuilder builder)
public void ConfigureServices(IRegistrator registrator)
{
builder.RegisterType<K210Target>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<FakeKPUConv2DEvaluator>().AsImplementedInterfaces();
builder.RegisterType<FakeKPUDownloadEvaluator>().AsImplementedInterfaces();
builder.RegisterType<FakeKPUUploadEvaluator>().AsImplementedInterfaces();
registrator.Register<ITarget, K210Target>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<FakeKPUConv2DEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<FakeKPUDownloadEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<FakeKPUUploadEvaluator>(reuse: Reuse.Singleton);
}
}

View File

@ -18,7 +18,7 @@
</ItemGroup>
<ItemGroup>
<ProjectReference Include="$(SolutionDir)/tools/Nncase.SourceGenerator/Nncase.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="../../tools/Nncase.SourceGenerator/Nncase.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>
</Project>

View File

@ -32,36 +32,37 @@ public class K210Target : ITarget
}
/// <inheritdoc/>
public void RegisterTargetDependentPass(PassManager passManager, CompileOptions options)
public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options)
{
if (options.ModelQuantMode == ModelQuantMode.UsePTQ)
if (options.QuantizeOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
{
passManager.Add(new EGraphPassWithQuantize("lowering_kpu", options.QuantizeOptions!)
passManager.Add<EGraphPassWithQuantize>().Configure(p =>
{
new LowerConv2D(),
p.Name = "lowering_kpu";
p.Add<LowerConv2D>();
});
}
}
/// <inheritdoc/>
public Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions)
public Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions)
{
return null;
return Task.FromResult(new Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>());
}
/// <inheritdoc/>
public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions)
public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions)
{
return null;
return Task.CompletedTask;
}
/// <inheritdoc/>
public void RegisterQuantizePass(PassManager passManager, CompileOptions options)
public void RegisterQuantizePass(IPassManager passManager, CompileOptions options)
{
}
/// <inheritdoc/>
public void RegisterTargetDependentAfterQuantPass(PassManager passManager, CompileOptions options)
public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options)
{
}

View File

@ -14,7 +14,7 @@
</ItemGroup>
<ItemGroup>
<ProjectReference Include="$(SolutionDir)/tools/Nncase.SourceGenerator/Nncase.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="../../tools/Nncase.SourceGenerator/Nncase.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>
</Project>

View File

@ -7,6 +7,8 @@ using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using Nncase.Hosting;
namespace Nncase;
@ -18,11 +20,10 @@ public static class StackVMApplicationPart
/// <summary>
/// Add stackVM assembly.
/// </summary>
/// <param name="assemblies">Assembly collection.</param>
/// <returns>Updated assembly collection.</returns>
public static IList<Assembly> AddStackVM(this IList<Assembly> assemblies)
/// <param name="registrator">Service registrator.</param>
/// <returns>Configured service registrator.</returns>
public static IRegistrator AddStackVM(this IRegistrator registrator)
{
assemblies.Add(typeof(StackVMApplicationPart).Assembly);
return assemblies;
return registrator.RegisterModule<StackVMModule>();
}
}

View File

@ -1,7 +1,8 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using DryIoc;
using Nncase.Hosting;
using Nncase.Targets;
namespace Nncase;
@ -9,11 +10,10 @@ namespace Nncase;
/// <summary>
/// StackVM module.
/// </summary>
public class StackVMModule : Module
internal class StackVMModule : IApplicationPart
{
/// <inheritdoc/>
protected override void Load(ContainerBuilder builder)
public void ConfigureServices(IRegistrator registrator)
{
builder.RegisterType<CPUTarget>().AsImplementedInterfaces().SingleInstance();
registrator.Register<ITarget, CPUTarget>(reuse: Reuse.Singleton);
}
}

View File

@ -21,8 +21,12 @@ namespace Nncase.Targets;
/// </summary>
public class CPUTarget : ITarget
{
/// <inheritdoc/>
public string Kind => "cpu";
/// <summary>
/// Gets kind.
/// </summary>
public static readonly string Kind = "cpu";
string ITarget.Kind => Kind;
/// <inheritdoc/>
public void ParseTargetDependentOptions(IConfigurationSection configure)
@ -30,30 +34,30 @@ public class CPUTarget : ITarget
}
/// <inheritdoc/>
public void RegisterTargetDependentPass(PassManager passManager, CompileOptions options)
public void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options)
{
}
/// <inheritdoc/>
public Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions)
public Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions)
{
var enodeQuantCosineDict = new Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>();
return Task.FromResult(enodeQuantCosineDict);
}
/// <inheritdoc/>
public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions)
public Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions)
{
return null;
return Task.CompletedTask;
}
/// <inheritdoc/>
public void RegisterQuantizePass(PassManager passManager, CompileOptions options)
public void RegisterQuantizePass(IPassManager passManager, CompileOptions options)
{
}
/// <inheritdoc/>
public void RegisterTargetDependentAfterQuantPass(PassManager passManager, CompileOptions options)
public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options)
{
}

View File

@ -67,14 +67,17 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Modules.K210", "modu
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Quantization", "src\Nncase.Quantization\Nncase.Quantization.csproj", "{317C0D8F-75B3-4248-83E8-17AADDCF247A}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.TestFixture", "src\Nncase.TestFixture\Nncase.TestFixture.csproj", "{FF67E6C1-205B-49D6-BA83-2DC536F0D414}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{EF1F8779-2B98-4E6F-A3DC-CA3FD2CADAD8}"
ProjectSection(SolutionItems) = preProject
.editorconfig = .editorconfig
Directory.Packages.props = Directory.Packages.props
NuGet.Config = NuGet.Config
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Diagnostics", "src\Nncase.Diagnostics\Nncase.Diagnostics.csproj", "{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Tests.TestFixture", "src\Nncase.Tests.TestFixture\Nncase.Tests.TestFixture.csproj", "{98A03405-CA53-4EC4-9B18-94D1C8DF9453}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -165,10 +168,14 @@ Global
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{317C0D8F-75B3-4248-83E8-17AADDCF247A}.Release|Any CPU.Build.0 = Release|Any CPU
{FF67E6C1-205B-49D6-BA83-2DC536F0D414}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{FF67E6C1-205B-49D6-BA83-2DC536F0D414}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FF67E6C1-205B-49D6-BA83-2DC536F0D414}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FF67E6C1-205B-49D6-BA83-2DC536F0D414}.Release|Any CPU.Build.0 = Release|Any CPU
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4}.Release|Any CPU.Build.0 = Release|Any CPU
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Debug|Any CPU.Build.0 = Debug|Any CPU
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Release|Any CPU.ActiveCfg = Release|Any CPU
{98A03405-CA53-4EC4-9B18-94D1C8DF9453}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -198,9 +205,10 @@ Global
{70D3FA34-B0B6-488F-812D-7E076CC0DFB3} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66}
{7390DF41-E804-4F12-B441-6EB5119C81BF} = {9859F5E8-5504-4AFE-B955-9497A0A0CD66}
{317C0D8F-75B3-4248-83E8-17AADDCF247A} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{FF67E6C1-205B-49D6-BA83-2DC536F0D414} = {E5A4516C-4080-4346-991D-57A7AA76ADA6}
{E365B1B1-4D13-4839-9763-A7A7C5F32FD4} = {BCA74168-F015-4B5B-B4CD-C83AE06B9822}
{98A03405-CA53-4EC4-9B18-94D1C8DF9453} = {E5A4516C-4080-4346-991D-57A7AA76ADA6}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {B251C48E-7D05-4662-8DFF-20843CB2DB99}
SolutionGuid = {9492E141-292E-4D60-9C6E-3738AB234DB2}
EndGlobalSection
EndGlobal

View File

@ -97,7 +97,7 @@ class GraphEvaluator:
return self._outputs[index]
def run(self):
self._outputs = self._func.body.evaluate(self._inputs, self._params).to_runtime_tensors()
self._outputs = self._func.body.evaluate(self._params, self._inputs).to_runtime_tensors()
@ property
def outputs_size(self) -> int:
@ -117,18 +117,19 @@ class IRModule():
class Compiler:
_target: _nncase.Target
_session: _nncase.CompileSession
_compiler: _nncase.Compiler
_compile_options: _nncase.CompileOptions
_quantize_options: _nncase.QuantizeOptions
_module: IRModule
def __init__(self) -> None:
def __init__(self, compile_options: CompileOptions) -> None:
self._compile_options = _nncase.CompileOptions()
self._compiler = _nncase.Compiler(self._compile_options)
self._quantize_options = None
def set_compile_options(self, compile_options: CompileOptions):
self.__process_compile_options(compile_options)
self._session = _nncase.CompileSession(self._target, self._compile_options)
self._compiler = self._session.compiler
self._quantize_options = None
def compile(self) -> None:
self._compiler.compile()
@ -168,14 +169,14 @@ class Compiler:
self._quantize_options = _nncase.QuantizeOptions()
self._compile_options.quantize_options = self._quantize_options
self._quantize_options.calibration_dataset = provider
self._compile_options.model_quant_mode = _nncase.ModelQuantMode.UsePTQ
self._quantize_options.model_quant_mode = _nncase.ModelQuantMode.UsePTQ
def dump_range_options(self) -> DumpRangeTensorOptions:
raise NotImplementedError("dump_range_options")
def __process_compile_options(self, compile_options: CompileOptions) -> ClCompileOptions:
self._compile_options.target = compile_options.target
self._compile_options.dump_level = 3 if compile_options.dump_ir == True else 0
self._target = _nncase.Target(compile_options.target)
self._compile_options.dump_flags = _nncase.DumpFlags.Nothing
self._compile_options.dump_dir = compile_options.dump_dir
def _import_module(self, model_content: bytes | io.RawIOBase) -> None:
@ -188,7 +189,7 @@ def check_target(target: str):
return target in ["cpu", "k510", "k230"]
def target_exists(target: str):
return _nncase.target_exists(target)
return _nncase.Target.exists(target)
return test_target(target) and target_exists(target)

View File

@ -48,11 +48,7 @@ PYBIND11_MODULE(_nncase, m) {
true, std::memory_order_release);
}));
m.def("initialize", nncase_clr_initialize);
m.def("launch_debugger", nncase_clr_launch_debugger);
m.def("target_exists", [](std::string_view target_name) {
return nncase_clr_target_exists(target_name.data(),
target_name.length());
});
m.def("launch_debugger", []() { nncase_clr_api()->luanch_debugger(); });
#include "runtime_tensor.inl"
@ -61,28 +57,29 @@ PYBIND11_MODULE(_nncase, m) {
.value("UsePTQ", nncase_mqm_use_ptq)
.value("UseQAT", nncase_mqm_use_qat);
py::enum_<nncase_dump_flags_t>(m, "DumpFlags")
.value("Nothing", nncase_dump_flags_none);
py::class_<compile_options>(m, "CompileOptions")
.def(py::init())
.def_property(
"input_format", py::overload_cast<>(&compile_options::input_format),
py::overload_cast<std::string_view>(&compile_options::input_format))
.def_property(
"target", py::overload_cast<>(&compile_options::target),
py::overload_cast<std::string_view>(&compile_options::target))
.def_property("dump_level",
py::overload_cast<>(&compile_options::dump_level),
py::overload_cast<int32_t>(&compile_options::dump_level))
.def_property(
"dump_dir", py::overload_cast<>(&compile_options::dump_dir),
py::overload_cast<std::string_view>(&compile_options::dump_dir))
.def_property("dump_flags",
py::overload_cast<>(&compile_options::dump_flags),
py::overload_cast<nncase_dump_flags_t>(
&compile_options::dump_flags))
.def_property("quantize_options",
py::overload_cast<>(&compile_options::quantize_options),
py::overload_cast<const quantize_options &>(
&compile_options::quantize_options))
.def_property("model_quant_mode",
py::overload_cast<>(&compile_options::model_quant_mode),
py::overload_cast<nncase_model_quant_mode_t>(
&compile_options::model_quant_mode));
&compile_options::quantize_options));
py::class_<target>(m, "Target")
.def(py::init<std::string_view>())
.def_static("exists", &target::exists);
py::class_<quantize_options>(m, "QuantizeOptions")
.def(py::init())
@ -90,7 +87,11 @@ PYBIND11_MODULE(_nncase, m) {
"calibration_dataset",
py::overload_cast<>(&quantize_options::calibration_dataset),
py::overload_cast<const calibration_dataset_provider &>(
&quantize_options::calibration_dataset));
&quantize_options::calibration_dataset))
.def_property("model_quant_mode",
py::overload_cast<>(&quantize_options::model_quant_mode),
py::overload_cast<nncase_model_quant_mode_t>(
&quantize_options::model_quant_mode));
py::class_<calibration_dataset_provider>(m, "CalibrationDatasetProvider")
.def(py::init([](py::list dataset, size_t samples_count,
@ -138,21 +139,21 @@ PYBIND11_MODULE(_nncase, m) {
}
});
py::class_<expr>(m, "Expr").def("evaluate", [](expr &expr, py::list inputs,
py::list params) {
std::vector<clr_object_handle_t> input_handles(inputs.size());
py::class_<expr>(m, "Expr").def("evaluate", [](expr &expr, py::list params,
py::list inputs) {
std::vector<clr_object_handle_t> param_handles(params.size());
for (size_t i = 0; i < input_handles.size(); i++) {
input_handles[i] = inputs[i].cast<rtvalue &>().get();
}
std::vector<clr_object_handle_t> input_handles(inputs.size());
for (size_t i = 0; i < param_handles.size(); i++) {
param_handles[i] = params[i].cast<var &>().get();
}
for (size_t i = 0; i < input_handles.size(); i++) {
input_handles[i] = inputs[i].cast<rtvalue &>().get();
}
array params_arr(nncase_array_var, param_handles.data(), inputs.size());
array inputs_arr(nncase_array_rtvalue, input_handles.data(),
inputs.size());
array params_arr(nncase_array_var, param_handles.data(), inputs.size());
return expr.evaluate(inputs_arr, params_arr);
return expr.evaluate(params_arr, inputs_arr);
});
py::class_<var, expr>(m, "Var");
@ -167,11 +168,14 @@ PYBIND11_MODULE(_nncase, m) {
.def_property_readonly("entry", &ir_module::entry);
py::class_<compiler>(m, "Compiler")
.def(py::init<const compile_options &>())
.def("import_module", &compiler::import_module)
.def("compile", &compiler::compile)
.def("gencode", &compiler::gencode);
py::class_<compile_session>(m, "CompileSession")
.def(py::init<const target &, const compile_options &>())
.def_property_readonly("compiler", &compile_session::compiler);
py::class_<interpreter>(m, "Simulator")
.def(py::init())
.def("load_model",

View File

@ -40,7 +40,8 @@ struct nncase_buffer_slice {
uint32_t size_bytes;
};
NNCASE_API int nncase_object_free(nncase::object_node *node);
NNCASE_API int nncase_object_add_ref(nncase::object_node *node);
NNCASE_API int nncase_object_release(nncase::object_node *node);
NNCASE_API int nncase_interp_create(nncase::runtime::interpreter **interp);
NNCASE_API int nncase_interp_free(nncase::runtime::interpreter *interp);

View File

@ -43,6 +43,8 @@ typedef enum {
nncase_calib_kld = 1
} nncase_calib_method_t;
typedef enum { nncase_dump_flags_none = 0 } nncase_dump_flags_t;
typedef struct {
void (*add_ref)(nncase_stream_handle_t handle);
void (*release)(nncase_stream_handle_t handle);
@ -60,101 +62,77 @@ typedef struct {
size_t length);
} nncase_stream_mt_t;
typedef struct {
clr_object_handle_t (*array_create)(nncase_array_element_kind_t kind,
const clr_object_handle_t *elements,
size_t count);
clr_object_handle_t (*array_get_item)(clr_object_handle_t array,
size_t index);
size_t (*array_get_length)(clr_object_handle_t array);
clr_object_handle_t (*calibration_dataset_provider_create)(
clr_object_handle_t dataset, size_t samplesCount,
clr_object_handle_t fn_params);
void (*handle_free)(clr_object_handle_t handle);
clr_object_handle_t (*compile_options_create)();
void (*compile_options_set_input_file)(clr_object_handle_t compile_options,
const char *input_file,
size_t input_file_length);
void (*compile_options_set_input_format)(
clr_object_handle_t compile_options, const char *input_format,
size_t input_format_length);
void (*compile_options_set_dump_dir)(clr_object_handle_t compile_options,
const char *dump_dir,
size_t dump_dir_length);
void (*compile_options_set_dump_flags)(clr_object_handle_t compile_options,
nncase_dump_flags_t dump_flags);
void (*compile_options_set_quantize_options)(
clr_object_handle_t compile_options,
clr_object_handle_t quantize_options);
clr_object_handle_t (*compile_session_create)(
clr_object_handle_t target, clr_object_handle_t compile_options);
clr_object_handle_t (*compile_session_get_compiler)(
clr_object_handle_t compile_session);
void (*compiler_initialize)();
clr_object_handle_t (*compiler_import_module)(clr_object_handle_t compiler,
clr_object_handle_t stream);
void (*compiler_compile)(clr_object_handle_t compiler);
void (*compiler_gencode)(clr_object_handle_t compiler,
clr_object_handle_t stream);
clr_object_handle_t (*datatype_from_typecode)(nncase::typecode_t typecode);
clr_object_handle_t (*expr_evaluate)(clr_object_handle_t expr,
clr_object_handle_t parameters,
clr_object_handle_t inputs);
clr_object_handle_t (*function_get_body)(clr_object_handle_t function);
clr_object_handle_t (*function_get_parameters)(
clr_object_handle_t function);
clr_object_handle_t (*ir_module_get_entry)(clr_object_handle_t module);
void (*luanch_debugger)();
clr_object_handle_t (*quantize_options_create)();
void (*quantize_options_set_calibration_dataset)(
clr_object_handle_t quantize_options, clr_object_handle_t dataset);
void (*quantize_options_set_calibration_method)(
clr_object_handle_t quantize_options, nncase_calib_method_t method);
void (*quantize_options_set_model_quant_mode)(
clr_object_handle_t quantize_options,
nncase_model_quant_mode_t model_quant_mode);
void (*quantize_options_set_quant_type)(
clr_object_handle_t quantize_options, clr_object_handle_t quant_type);
clr_object_handle_t (*rtvalue_from_handle)(nncase::value_node *value);
nncase::value_node *(*rtvalue_get_handle)(clr_object_handle_t rtvalue);
clr_object_handle_t (*stream_create)(const nncase_stream_mt_t *mt,
void *handle);
clr_object_handle_t (*target_create)(const char *target_name,
size_t target_name_length);
bool (*target_exists)(const char *target_name, size_t target_name_length);
} nncase_api_mt_t;
NNCASE_API nncase_api_mt_t *nncase_clr_api();
NNCASE_API int nncase_clr_initialize(const char *root_assembly_path);
NNCASE_API int nncase_clr_uninitialize();
NNCASE_API int nncase_clr_array_create(nncase_array_element_kind_t kind,
const clr_object_handle_t *elements,
size_t count,
clr_object_handle_t *array);
NNCASE_API int nncase_clr_array_get_item(clr_object_handle_t array,
size_t index,
clr_object_handle_t *item);
NNCASE_API int nncase_clr_array_get_length(clr_object_handle_t array,
size_t *length);
NNCASE_API int nncase_clr_calibration_dataset_provider_create(
clr_object_handle_t dataset, size_t samplesCount,
clr_object_handle_t fn_params, clr_object_handle_t *provider);
NNCASE_API int nncase_clr_handle_free(clr_object_handle_t handle);
NNCASE_API int
nncase_clr_compile_options_create(clr_object_handle_t *compile_options);
NNCASE_API int
nncase_clr_compile_options_set_input_file(clr_object_handle_t compile_options,
const char *input_file,
size_t input_file_length);
NNCASE_API int
nncase_clr_compile_options_set_input_format(clr_object_handle_t compile_options,
const char *input_format,
size_t input_format_length);
NNCASE_API int
nncase_clr_compile_options_set_target(clr_object_handle_t compile_options,
const char *target, size_t target_length);
NNCASE_API int
nncase_clr_compile_options_set_dump_level(clr_object_handle_t compile_options,
int32_t dump_level);
NNCASE_API int
nncase_clr_compile_options_set_dump_dir(clr_object_handle_t compile_options,
const char *dump_dir,
size_t dump_dir_length);
NNCASE_API int nncase_clr_compile_options_set_quantize_options(
clr_object_handle_t compile_options, clr_object_handle_t quantize_options);
NNCASE_API int
nncase_clr_compile_options_set_quant_type(clr_object_handle_t compile_options,
clr_object_handle_t quant_type);
NNCASE_API int nncase_clr_compile_options_set_model_quant_mode(
clr_object_handle_t compile_options,
nncase_model_quant_mode_t model_quant_mode);
NNCASE_API int nncase_clr_compiler_create(clr_object_handle_t compile_options,
clr_object_handle_t *compiler);
NNCASE_API int nncase_clr_compiler_import_module(clr_object_handle_t compiler,
clr_object_handle_t stream,
clr_object_handle_t *module);
NNCASE_API int nncase_clr_compiler_compile(clr_object_handle_t compiler);
NNCASE_API int nncase_clr_compiler_gencode(clr_object_handle_t compiler,
clr_object_handle_t stream);
NNCASE_API int nncase_clr_datatype_from_typecode(nncase::typecode_t typecode,
clr_object_handle_t *datatype);
NNCASE_API int nncase_clr_expr_evaluate(clr_object_handle_t expr,
clr_object_handle_t parameters,
clr_object_handle_t inputs,
clr_object_handle_t *result);
NNCASE_API int nncase_clr_function_get_body(clr_object_handle_t function,
clr_object_handle_t *body);
NNCASE_API int
nncase_clr_function_get_parameters(clr_object_handle_t function,
clr_object_handle_t *parameters);
NNCASE_API int nncase_clr_ir_module_get_entry(clr_object_handle_t module,
clr_object_handle_t *entry);
NNCASE_API int nncase_clr_launch_debugger();
NNCASE_API int
nncase_clr_quantize_options_create(clr_object_handle_t *quantize_options);
NNCASE_API int nncase_clr_quantize_options_set_calibration_dataset(
clr_object_handle_t quantize_options, clr_object_handle_t dataset);
NNCASE_API int nncase_clr_quantize_options_set_calibration_method(
clr_object_handle_t quantize_options, nncase_calib_method_t method);
NNCASE_API int nncase_clr_rtvalue_from_handle(nncase::value_node *value,
clr_object_handle_t *rtvalue);
NNCASE_API int nncase_clr_rtvalue_get_handle(clr_object_handle_t rtvalue,
nncase::value_node **value);
NNCASE_API int nncase_clr_stream_create(const nncase_stream_mt_t *mt,
void *handle,
clr_object_handle_t *stream);
NNCASE_API bool nncase_clr_target_exists(const char *target_name,
size_t target_name_length);
}
DEFINE_ENUM_BITMASK_OPERATORS(nncase_dump_flags_t)
namespace nncase::clr {
class clr_object_ptr {
public:
@ -196,7 +174,7 @@ class clr_object_ptr {
void release() {
if (auto handle = handle_) {
handle_ = nullptr;
nncase_clr_handle_free(handle);
nncase_clr_api()->handle_free(handle);
}
}
@ -204,16 +182,16 @@ class clr_object_ptr {
clr_object_handle_t handle_;
};
#define CHECK_CLR(x) \
if (x) { \
throw std::runtime_error(#x); \
}
#define CHECK_CLR(x) x
class clr_object_base {
public:
constexpr clr_object_base(std::nullptr_t = nullptr) noexcept
: obj_(nullptr) {}
clr_object_base(std::in_place_t, clr_object_ptr ptr) noexcept
: obj_(std::move(ptr)) {}
clr_object_base(clr_object_base &&) = default;
clr_object_base &operator=(clr_object_base &&) = default;
@ -237,22 +215,15 @@ class array : public clr_object_base {
array(nncase_array_element_kind_t kind, const clr_object_handle_t *elements,
size_t length) {
CHECK_CLR(nncase_clr_array_create(kind, elements, length,
obj_.release_and_addressof()));
obj_ = nncase_clr_api()->array_create(kind, elements, length);
}
template <class T = clr_object_base> T at(size_t index) {
T value(nullptr);
CHECK_CLR(nncase_clr_array_get_item(obj_.get(), index,
value.release_and_addressof()));
return value;
return {std::in_place,
nncase_clr_api()->array_get_item(obj_.get(), index)};
}
size_t length() {
size_t length;
CHECK_CLR(nncase_clr_array_get_length(obj_.get(), &length));
return length;
}
size_t length() { return nncase_clr_api()->array_get_length(obj_.get()); }
template <class T = clr_object_base> std::vector<T> to_vector() {
std::vector<T> vector(length());
@ -269,9 +240,8 @@ class calibration_dataset_provider : public clr_object_base {
calibration_dataset_provider(array dataset, size_t samples_count,
array fn_params) {
CHECK_CLR(nncase_clr_calibration_dataset_provider_create(
dataset.get(), samples_count, fn_params.get(),
obj_.release_and_addressof()));
obj_ = nncase_clr_api()->calibration_dataset_provider_create(
dataset.get(), samples_count, fn_params.get());
}
};
@ -279,15 +249,18 @@ class quantize_options : public clr_object_base {
public:
using clr_object_base::clr_object_base;
quantize_options() {
CHECK_CLR(
nncase_clr_quantize_options_create(obj_.release_and_addressof()));
}
quantize_options() { obj_ = nncase_clr_api()->quantize_options_create(); }
calibration_dataset_provider calibration_dataset() { return nullptr; }
void calibration_dataset(const calibration_dataset_provider &value) {
CHECK_CLR(nncase_clr_quantize_options_set_calibration_dataset(
obj_.get(), value.get()));
nncase_clr_api()->quantize_options_set_calibration_dataset(obj_.get(),
value.get());
}
nncase_model_quant_mode_t model_quant_mode() { return nncase_mqm_no_quant; }
void model_quant_mode(nncase_model_quant_mode_t value) {
nncase_clr_api()->quantize_options_set_model_quant_mode(obj_.get(),
value);
}
};
@ -296,8 +269,7 @@ class cstream : public clr_object_base {
using clr_object_base::clr_object_base;
cstream(const nncase_stream_mt_t *mt, void *handle) {
CHECK_CLR(
nncase_clr_stream_create(mt, handle, obj_.release_and_addressof()));
obj_ = nncase_clr_api()->stream_create(mt, handle);
}
};
@ -305,44 +277,42 @@ class compile_options : public clr_object_base {
public:
using clr_object_base::clr_object_base;
compile_options() {
CHECK_CLR(
nncase_clr_compile_options_create(obj_.release_and_addressof()));
}
compile_options() { obj_ = nncase_clr_api()->compile_options_create(); }
std::string input_format() { return "cpu"; }
void input_format(std::string_view value) {
CHECK_CLR(nncase_clr_compile_options_set_input_format(
obj_.get(), value.data(), value.length()));
}
std::string target() { return "cpu"; }
void target(std::string_view value) {
CHECK_CLR(nncase_clr_compile_options_set_target(
obj_.get(), value.data(), value.length()));
}
int32_t dump_level() { return 0; }
void dump_level(int32_t value) {
CHECK_CLR(nncase_clr_compile_options_set_dump_level(obj_.get(), value));
nncase_clr_api()->compile_options_set_input_format(
obj_.get(), value.data(), value.length());
}
std::string dump_dir() { return "cpu"; }
void dump_dir(std::string_view value) {
CHECK_CLR(nncase_clr_compile_options_set_dump_dir(
obj_.get(), value.data(), value.length()));
nncase_clr_api()->compile_options_set_dump_dir(obj_.get(), value.data(),
value.length());
}
nncase_dump_flags_t dump_flags() { return nncase_dump_flags_none; }
void dump_flags(nncase_dump_flags_t value) {
nncase_clr_api()->compile_options_set_dump_flags(obj_.get(), value);
}
clr::quantize_options quantize_options() { return nullptr; }
void quantize_options(const clr::quantize_options &value) {
CHECK_CLR(nncase_clr_compile_options_set_quantize_options(obj_.get(),
value.get()));
nncase_clr_api()->compile_options_set_quantize_options(obj_.get(),
value.get());
}
};
class target : public clr_object_base {
public:
using clr_object_base::clr_object_base;
static bool exists(std::string_view name) {
return nncase_clr_api()->target_exists(name.data(), name.length());
}
nncase_model_quant_mode_t model_quant_mode() { return nncase_mqm_no_quant; }
void model_quant_mode(nncase_model_quant_mode_t value) {
CHECK_CLR(
nncase_clr_compile_options_set_model_quant_mode(obj_.get(), value));
target(std::string_view name) {
obj_ = nncase_clr_api()->target_create(name.data(), name.length());
}
};
@ -351,15 +321,12 @@ class rtvalue : public clr_object_base {
using clr_object_base::clr_object_base;
rtvalue(nncase::value_t value) {
CHECK_CLR(nncase_clr_rtvalue_from_handle(value.detach(),
obj_.release_and_addressof()));
obj_ = nncase_clr_api()->rtvalue_from_handle(value.get());
}
value_t to_value() const {
value_t value(nullptr);
CHECK_CLR(nncase_clr_rtvalue_get_handle(obj_.get(),
value.release_and_addressof()));
return value;
auto ptr = nncase_clr_api()->rtvalue_get_handle(obj_.get());
return ptr;
}
};
@ -368,10 +335,8 @@ class expr : public clr_object_base {
using clr_object_base::clr_object_base;
rtvalue evaluate(const array &params, const array &inputs) {
rtvalue value(nullptr);
CHECK_CLR(nncase_clr_expr_evaluate(get(), params.get(), inputs.get(),
value.release_and_addressof()));
return value;
return {std::in_place, nncase_clr_api()->expr_evaluate(
get(), params.get(), inputs.get())};
}
};
@ -390,17 +355,12 @@ class function : public base_function {
using base_function::base_function;
expr body() {
expr body(nullptr);
CHECK_CLR(
nncase_clr_function_get_body(get(), body.release_and_addressof()));
return body;
return {std::in_place, nncase_clr_api()->function_get_body(get())};
}
array parameters() {
array params(nullptr);
CHECK_CLR(nncase_clr_function_get_parameters(
get(), params.release_and_addressof()));
return params;
return {std::in_place,
nncase_clr_api()->function_get_parameters(get())};
}
};
@ -409,10 +369,7 @@ class ir_module : public clr_object_base {
using clr_object_base::clr_object_base;
function entry() {
function entry(nullptr);
CHECK_CLR(nncase_clr_ir_module_get_entry(
get(), entry.release_and_addressof()));
return entry;
return {std::in_place, nncase_clr_api()->ir_module_get_entry(get())};
}
};
@ -420,21 +377,29 @@ class compiler : public clr_object_base {
public:
using clr_object_base::clr_object_base;
compiler(const compile_options &options) {
CHECK_CLR(nncase_clr_compiler_create(options.get(),
obj_.release_and_addressof()));
}
ir_module import_module(cstream &stream) {
ir_module module(nullptr);
CHECK_CLR(nncase_clr_compiler_import_module(
get(), stream.get(), module.release_and_addressof()));
return module;
return {std::in_place,
nncase_clr_api()->compiler_import_module(get(), stream.get())};
}
void compile() { CHECK_CLR(nncase_clr_compiler_compile(obj_.get())); }
void compile() { nncase_clr_api()->compiler_compile(obj_.get()); }
void gencode(cstream &stream) {
CHECK_CLR(nncase_clr_compiler_gencode(obj_.get(), stream.get()));
nncase_clr_api()->compiler_gencode(obj_.get(), stream.get());
}
};
class compile_session : public clr_object_base {
public:
using clr_object_base::clr_object_base;
compile_session(const target &target, const compile_options &options) {
obj_ = nncase_clr_api()->compile_session_create(target.get(),
options.get());
}
clr::compiler compiler() {
return {std::in_place,
nncase_clr_api()->compile_session_get_compiler(obj_.get())};
}
};
} // namespace nncase::clr

View File

@ -59,7 +59,8 @@ class NNCASE_API object_node {
uint32_t release() const noexcept;
template <class T> friend class object_t;
friend int ::nncase_object_free(nncase::object_node *node);
friend int ::nncase_object_add_ref(nncase::object_node *node);
friend int ::nncase_object_release(nncase::object_node *node);
private:
mutable std::atomic<uint32_t> ref_count_;

View File

@ -68,7 +68,13 @@ result<strides_t> to_strides(const uint32_t *strides, uint32_t length) {
} // namespace
extern "C" {
int nncase_object_free(nncase::object_node *node) {
int nncase_object_add_ref(nncase::object_node *node) {
if (node)
node->add_ref();
return 0;
}
int nncase_object_release(nncase::object_node *node) {
if (node)
node->release();
return 0;
@ -178,10 +184,6 @@ int nncase_buffer_as_host(nncase::runtime::buffer_node *buffer,
return -EINVAL;
}
int nncase_buffer_free(nncase::runtime::buffer_node *buffer) {
return nncase_object_free(buffer);
}
int nncase_host_buffer_map(nncase::runtime::host_buffer_node *host_buffer,
nncase::runtime::map_access_t access, void **data,
uint32_t *bytes) {
@ -205,10 +207,6 @@ int nncase_host_buffer_unmap(nncase::runtime::host_buffer_node *host_buffer) {
return -EINVAL;
}
int nncase_host_buffer_free(nncase::runtime::host_buffer_node *host_buffer) {
return nncase_object_free(host_buffer);
}
int nncase_dtype_create_prime(nncase::typecode_t typecode,
nncase::datatype_node **dtype) {
if (dtype) {
@ -223,10 +221,6 @@ int nncase_dtype_get_typecode(nncase::datatype_node *dtype) {
return dtype->typecode();
}
int nncase_dtype_free(nncase::datatype_node *dtype) {
return nncase_object_free(dtype);
}
int nncase_value_is_tensor(nncase::value_node *value, bool *is_tensor) {
if (value && is_tensor) {
*is_tensor = value_t(value).is_a<tensor>();

View File

@ -38,67 +38,6 @@
#define UNMANAGEDCALLERSONLY_METHOD ((const char_t *)-1)
namespace {
struct c_api_mt {
clr_object_handle_t (*array_create)(nncase_array_element_kind_t kind,
const clr_object_handle_t *elements,
size_t count);
clr_object_handle_t (*array_get_item)(clr_object_handle_t array,
size_t index);
size_t (*array_get_length)(clr_object_handle_t array);
clr_object_handle_t (*calibration_dataset_provider_create)(
clr_object_handle_t dataset, size_t samplesCount,
clr_object_handle_t fn_params);
void (*handle_free)(clr_object_handle_t handle);
clr_object_handle_t (*compile_options_create)();
void (*compile_options_set_input_file)(clr_object_handle_t compile_options,
const char *input_file,
size_t input_file_length);
void (*compile_options_set_input_format)(
clr_object_handle_t compile_options, const char *input_format,
size_t input_format_length);
void (*compile_options_set_target)(clr_object_handle_t compile_options,
const char *target,
size_t target_length);
void (*compile_options_set_dump_level)(clr_object_handle_t compile_options,
int32_t dump_level);
void (*compile_options_set_dump_dir)(clr_object_handle_t compile_options,
const char *dump_dir,
size_t dump_dir_length);
void (*compile_options_set_quantize_options)(
clr_object_handle_t compile_options,
clr_object_handle_t quantize_options);
void (*compile_options_set_quant_type)(clr_object_handle_t compile_options,
clr_object_handle_t quant_type);
void (*compile_options_set_model_quant_mode)(
clr_object_handle_t compile_options,
nncase_model_quant_mode_t model_quant_mode);
void (*compiler_initialize)();
clr_object_handle_t (*compiler_create)(clr_object_handle_t compile_options);
clr_object_handle_t (*compiler_import_module)(clr_object_handle_t compiler,
clr_object_handle_t stream);
void (*compiler_compile)(clr_object_handle_t compiler);
void (*compiler_gencode)(clr_object_handle_t compiler,
clr_object_handle_t stream);
clr_object_handle_t (*datatype_from_typecode)(nncase::typecode_t typecode);
clr_object_handle_t (*expr_evaluate)(clr_object_handle_t expr,
clr_object_handle_t parameters,
clr_object_handle_t inputs);
clr_object_handle_t (*function_get_body)(clr_object_handle_t function);
clr_object_handle_t (*function_get_parameters)(
clr_object_handle_t function);
clr_object_handle_t (*ir_module_get_entry)(clr_object_handle_t module);
void (*luanch_debugger)();
clr_object_handle_t (*quantize_options_create)();
void (*quantize_options_set_calibration_dataset)(
clr_object_handle_t quantize_options, clr_object_handle_t dataset);
void (*quantize_options_set_calibration_method)(
clr_object_handle_t quantize_options, nncase_calib_method_t method);
clr_object_handle_t (*rtvalue_from_handle)(nncase::value_node *value);
nncase::value_node *(*rtvalue_get_handle)(clr_object_handle_t rtvalue);
clr_object_handle_t (*stream_create)(const nncase_stream_mt_t *mt,
void *handle);
bool (*target_exists)(const char *target_name, size_t target_name_length);
};
typedef int (*get_function_pointer_fn)(const char_t *type_name,
const char_t *method_name,
@ -106,7 +45,7 @@ typedef int (*get_function_pointer_fn)(const char_t *type_name,
void *load_context, void *reserved,
/*out*/ void **delegate);
typedef void (*c_api_initialize_fn)(c_api_mt *mt);
typedef void (*c_api_initialize_fn)(nncase_api_mt_t *mt);
#ifdef WIN32
#define THROW_WIN32_IF_NOT(x) \
@ -295,14 +234,16 @@ load_compiler_c_api_initializer(const char *root_assembly_path) {
return c_api_initialize;
}
c_api_mt g_c_api_mt;
nncase_api_mt_t g_nncase_api_mt;
} // namespace
nncase_api_mt_t *nncase_clr_api() { return &g_nncase_api_mt; }
int nncase_clr_initialize(const char *root_assembly_path) {
if (!g_c_api_mt.handle_free) {
if (!g_nncase_api_mt.handle_free) {
auto init = load_compiler_c_api_initializer(root_assembly_path);
init(&g_c_api_mt);
g_c_api_mt.compiler_initialize();
init(&g_nncase_api_mt);
g_nncase_api_mt.compiler_initialize();
// SetUnhandledExceptionFilter(MyUnhandledExceptionFilter);
}
@ -310,200 +251,6 @@ int nncase_clr_initialize(const char *root_assembly_path) {
}
int nncase_clr_uninitialize() {
g_c_api_mt = {};
g_nncase_api_mt = {};
return 0;
}
int nncase_clr_array_create(nncase_array_element_kind_t kind,
const clr_object_handle_t *elements, size_t count,
clr_object_handle_t *array) {
*array = g_c_api_mt.array_create(kind, elements, count);
return 0;
}
int nncase_clr_array_get_item(clr_object_handle_t array, size_t index,
clr_object_handle_t *item) {
*item = g_c_api_mt.array_get_item(array, index);
return 0;
}
int nncase_clr_array_get_length(clr_object_handle_t array, size_t *length) {
*length = g_c_api_mt.array_get_length(array);
return 0;
}
int nncase_clr_calibration_dataset_provider_create(
clr_object_handle_t dataset, size_t samplesCount,
clr_object_handle_t fn_params, clr_object_handle_t *provider) {
*provider = g_c_api_mt.calibration_dataset_provider_create(
dataset, samplesCount, fn_params);
return 0;
}
int nncase_clr_handle_free([[maybe_unused]] clr_object_handle_t handle) {
if (g_c_api_mt.handle_free)
g_c_api_mt.handle_free(handle);
return 0;
}
int nncase_clr_compile_options_create(clr_object_handle_t *compile_options) {
*compile_options = g_c_api_mt.compile_options_create();
return 0;
}
int nncase_clr_compile_options_set_inputfile(
clr_object_handle_t compile_options, const char *input_file,
size_t input_file_length) {
g_c_api_mt.compile_options_set_input_file(compile_options, input_file,
input_file_length);
return 0;
}
int nncase_clr_compile_options_set_input_format(
clr_object_handle_t compile_options, const char *input_format,
size_t input_format_length) {
g_c_api_mt.compile_options_set_input_format(compile_options, input_format,
input_format_length);
return 0;
}
int nncase_clr_compile_options_set_target(clr_object_handle_t compile_options,
const char *target,
size_t target_length) {
g_c_api_mt.compile_options_set_target(compile_options, target,
target_length);
return 0;
}
int nncase_clr_compile_options_set_dump_level(
clr_object_handle_t compile_options, int32_t dump_level) {
g_c_api_mt.compile_options_set_dump_level(compile_options, dump_level);
return 0;
}
int nncase_clr_compile_options_set_dump_dir(clr_object_handle_t compile_options,
const char *dump_dir,
size_t dump_dir_length) {
g_c_api_mt.compile_options_set_dump_dir(compile_options, dump_dir,
dump_dir_length);
return 0;
}
int nncase_clr_compile_options_set_quantize_options(
clr_object_handle_t compile_options, clr_object_handle_t quantize_options) {
g_c_api_mt.compile_options_set_quantize_options(compile_options,
quantize_options);
return 0;
}
int nncase_clr_compile_options_set_model_quant_mode(
clr_object_handle_t compile_options,
nncase_model_quant_mode_t model_quant_mode) {
g_c_api_mt.compile_options_set_model_quant_mode(compile_options,
model_quant_mode);
return 0;
}
int nncase_clr_compiler_create(clr_object_handle_t compile_options,
clr_object_handle_t *compiler) {
*compiler = g_c_api_mt.compiler_create(compile_options);
return 0;
}
int nncase_clr_compiler_import_module(clr_object_handle_t compiler,
clr_object_handle_t stream,
clr_object_handle_t *module) {
*module = g_c_api_mt.compiler_import_module(compiler, stream);
return 0;
}
int nncase_clr_compiler_compile(clr_object_handle_t compiler) {
g_c_api_mt.compiler_compile(compiler);
return 0;
}
int nncase_clr_compiler_gencode(clr_object_handle_t compiler,
clr_object_handle_t stream) {
g_c_api_mt.compiler_gencode(compiler, stream);
return 0;
}
int nncase_clr_datatype_from_typecode(nncase::typecode_t typecode,
clr_object_handle_t *datatype) {
*datatype = g_c_api_mt.datatype_from_typecode(typecode);
return 0;
}
int nncase_clr_expr_evaluate(clr_object_handle_t expr,
clr_object_handle_t inputs,
clr_object_handle_t parameters,
clr_object_handle_t *result) {
*result = g_c_api_mt.expr_evaluate(expr, parameters, inputs);
return 0;
}
int nncase_clr_function_get_body(clr_object_handle_t function,
clr_object_handle_t *body) {
*body = g_c_api_mt.function_get_body(function);
return 0;
}
int nncase_clr_function_get_parameters(clr_object_handle_t function,
clr_object_handle_t *parameters) {
*parameters = g_c_api_mt.function_get_parameters(function);
return 0;
}
int nncase_clr_ir_module_get_entry(clr_object_handle_t module,
clr_object_handle_t *entry) {
*entry = g_c_api_mt.ir_module_get_entry(module);
return 0;
}
int nncase_clr_launch_debugger() {
g_c_api_mt.luanch_debugger();
return 0;
}
int nncase_clr_quantize_options_create(clr_object_handle_t *quantize_options) {
*quantize_options = g_c_api_mt.quantize_options_create();
return 0;
}
int nncase_clr_quantize_options_set_calibration_dataset(
clr_object_handle_t quantize_options, clr_object_handle_t dataset) {
g_c_api_mt.quantize_options_set_calibration_dataset(quantize_options,
dataset);
return 0;
}
int nncase_clr_quantize_options_set_calibration_method(
clr_object_handle_t quantize_options, nncase_calib_method_t method) {
g_c_api_mt.quantize_options_set_calibration_method(quantize_options,
method);
return 0;
}
int nncase_clr_rtvalue_from_handle(nncase::value_node *value,
clr_object_handle_t *rtvalue) {
*rtvalue = g_c_api_mt.rtvalue_from_handle(nncase::value_t(value).detach());
return 0;
}
int nncase_clr_rtvalue_get_handle(clr_object_handle_t rtvalue,
nncase::value_node **value) {
nncase::value_t v(g_c_api_mt.rtvalue_get_handle(rtvalue));
*value = v.detach();
return 0;
}
int nncase_clr_stream_create(const nncase_stream_mt_t *mt, void *handle,
clr_object_handle_t *stream) {
*stream = g_c_api_mt.stream_create(mt, handle);
return 0;
}
bool nncase_clr_target_exists(const char *target_name,
size_t target_name_length) {
return g_c_api_mt.target_exists(target_name, target_name_length);
}

View File

@ -17,10 +17,12 @@
#include <nncase/api.h>
#include <nncase/compiler.h>
#include <nncase/io_utils.h>
#include <string_view>
using namespace nncase;
using namespace nncase::clr;
using namespace nncase::runtime;
using namespace std::string_view_literals;
#define TRY(x) \
if (x) \
@ -29,11 +31,13 @@ using namespace nncase::runtime;
int main() {
nncase_clr_initialize(
R"(E:\Work\Repos\nncase\src\Nncase.Compiler\bin\Debug\net6.0\Nncase.Compiler.dll)");
clr_object_ptr compiler, compile_options;
TRY(nncase_clr_compile_options_create(
compile_options.release_and_addressof()));
TRY(nncase_clr_compiler_create(compile_options.get(),
compiler.release_and_addressof()));
auto target_name = "cpu"sv;
auto nncapi = nncase_clr_api();
clr_object_ptr target, compile_session, compiler, compile_options;
compile_options = nncapi->compile_options_create();
target = nncapi->target_create(target_name.data(), target_name.length());
nncapi->compile_session_create(target.get(), compile_options.get());
compiler = nncapi->compile_session_get_compiler(compile_session.get());
auto kmodel = read_file(
R"(E:\Work\Repos\nncase\src\Nncase.Tests\bin\Debug\net6.0\TestCallFunction.kmodel)");
@ -62,7 +66,7 @@ int main() {
nullptr));
memcpy(x_buf_data, x, sizeof(x));
TRY(nncase_host_buffer_unmap(x_host_buf));
TRY(nncase_object_free((object_node *)x_host_buf));
TRY(nncase_object_release((object_node *)x_host_buf));
}
tensor_node *x_tensor;
@ -95,14 +99,14 @@ int main() {
std::cout << *ret_float_data << std::endl;
TRY(nncase_host_buffer_unmap(ret_host_buf));
TRY(nncase_object_free((object_node *)ret_host_buf));
TRY(nncase_object_release((object_node *)ret_host_buf));
}
TRY(nncase_object_free((object_node *)out_buffer_slice.buffer));
TRY(nncase_object_free((object_node *)ret));
TRY(nncase_object_free((object_node *)x_buf));
TRY(nncase_object_free((object_node *)x_tensor));
TRY(nncase_object_free((object_node *)dtype_float32));
TRY(nncase_object_release((object_node *)out_buffer_slice.buffer));
TRY(nncase_object_release((object_node *)ret));
TRY(nncase_object_release((object_node *)x_buf));
TRY(nncase_object_release((object_node *)x_tensor));
TRY(nncase_object_release((object_node *)dtype_float32));
TRY(nncase_interp_free(interp));
return 0;
}

View File

@ -7,12 +7,14 @@ using System.CommandLine;
using System.CommandLine.Invocation;
using System.IO;
using System.Linq;
using Autofac;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Nncase.CodeGen;
using Nncase.Compiler;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.Quantization;
using Nncase.Transform;
namespace Nncase.Cli.Commands;
@ -69,66 +71,83 @@ public class Compile : Command
description: "model quant options, default is Random",
getDefaultValue: () => Quantization.CalibMethod.Random));
Handler = CommandHandler.Create<CliCompileOptions, IHost>(Run);
Handler = CommandHandler.Create<CliCompileOptions, IHost>(RunAsync);
}
private void Run(CliCompileOptions cliOptions, IHost host)
private static DumpFlags DumpLevelToFlags(int dumpLevel)
{
var provider = host.Services.GetRequiredService<ICompilerServicesProvider>();
CompilerServices.Configure(provider);
return dumpLevel switch
{
0 => DumpFlags.None,
1 => DumpLevelToFlags(0) | DumpFlags.Compile,
2 => DumpLevelToFlags(1) | DumpFlags.PassIR,
3 => DumpLevelToFlags(2) | DumpFlags.Rewrite,
4 => DumpLevelToFlags(3) | DumpFlags.EGraphCost,
5 => DumpLevelToFlags(4) | DumpFlags.Evaluator,
6 => DumpLevelToFlags(5) | DumpFlags.Calibration,
7 => DumpLevelToFlags(6) | DumpFlags.Tiling,
8 => DumpLevelToFlags(7) | DumpFlags.Schedule,
>= 9 => DumpLevelToFlags(8) | DumpFlags.CodeGen,
_ => throw new ArgumentOutOfRangeException(nameof(dumpLevel)),
};
}
private async Task RunAsync(CliCompileOptions cliOptions, IHost host)
{
CompilerServices.Configure(host.Services);
// 1. setup the options
Quantization.QuantizeOptions quant_options = new() { CalibrationMethod = cliOptions.CalibMethod };
var compileOptions = new CompileOptions()
var compileOptions = new CompileOptions
{
InputFile = cliOptions.InputFile,
OutputFile = cliOptions.OutputFile,
Target = cliOptions.Target,
InputFormat = cliOptions.InputFormat,
DumpLevel = cliOptions.DumpLevel,
DumpFlags = DumpLevelToFlags(cliOptions.DumpLevel),
DumpDir = cliOptions.DumpDir,
QuantType = cliOptions.QuantType switch
QuantizeOptions = new()
{
QuantType.UInt8 => DataTypes.UInt8,
QuantType.Int8 => DataTypes.Int8,
QuantType.Int16 => DataTypes.Int16,
_ => throw new ArgumentOutOfRangeException(),
CalibrationMethod = cliOptions.CalibMethod,
QuantType = cliOptions.QuantType switch
{
QuantType.UInt8 => DataTypes.UInt8,
QuantType.Int8 => DataTypes.Int8,
QuantType.Int16 => DataTypes.Int16,
_ => throw new ArgumentException("Invalid quant type"),
},
WQuantType = cliOptions.WQuantType switch
{
QuantType.UInt8 => DataTypes.UInt8,
QuantType.Int8 => DataTypes.Int8,
QuantType.Int16 => DataTypes.Int16,
_ => throw new ArgumentException("Invalid weights quant type"),
},
ModelQuantMode = cliOptions.ModelQuantMode,
},
WQuantType = cliOptions.WQuantType switch
{
QuantType.UInt8 => DataTypes.UInt8,
QuantType.Int8 => DataTypes.Int8,
QuantType.Int16 => DataTypes.Int16,
_ => throw new ArgumentOutOfRangeException(),
},
ModelQuantMode = cliOptions.ModelQuantMode,
// todo add the quant options parser
QuantizeOptions = quant_options,
};
// 2. import the model
var compiler = new Compiler.Compiler(compileOptions);
var target = CompilerServices.GetTarget(cliOptions.Target);
using var compileSession = CompileSession.Create(target, compileOptions);
var compiler = compileSession.Compiler;
IRModule module;
using (var model_stream = File.OpenRead(compileOptions.InputFile))
{
module = compiler.ImportModule(model_stream);
module = await compiler.ImportModuleAsync(model_stream);
}
// 3. create the calib dataset
if (compileOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
{
if (quant_options.CalibrationMethod == Quantization.CalibMethod.Random)
if (compileOptions.QuantizeOptions.CalibrationMethod == Quantization.CalibMethod.Random)
{
quant_options.CalibrationDataset = new RandCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray());
compileOptions.QuantizeOptions.CalibrationDataset = new RandomCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), 5);
}
}
// 4. compile
compiler.Compile();
await compiler.CompileAsync();
// 5. code gen
using (var os = File.OpenWrite(compileOptions.OutputFile))
using (var os = File.OpenWrite(cliOptions.OutputFile))
{
compiler.Gencode(os);
}
@ -167,31 +186,3 @@ internal sealed class CliCompileOptions
/// <inheritdoc/>
public Quantization.CalibMethod CalibMethod { get; set; }
}
internal sealed class RandCalibrationDatasetProvider : Quantization.ICalibrationDatasetProvider
{
private const int CountValue = 5;
private readonly IReadOnlyDictionary<Var, IValue>[] _samples;
public RandCalibrationDatasetProvider(IEnumerable<Var> vars)
{
_samples = Enumerable.Range(0, CountValue).Select(i =>
{
var values = new Dictionary<Var, IValue>();
foreach (var var in vars)
{
CompilerServices.InferenceType(var);
var shape = var.CheckedShape.Select(d => d.IsUnknown ? 1 : d.FixedValue).ToArray();
var value = IR.F.Random.Normal(var.CheckedDataType, 0, 1, 0, shape).Evaluate();
values.Add(var, value);
}
return values;
}).ToArray();
}
public int? Count => CountValue;
public IAsyncEnumerable<IReadOnlyDictionary<Var, IValue>> Samples => _samples.ToAsyncEnumerable();
}

View File

@ -6,7 +6,6 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Autofac.Extensions.DependencyInjection" />
<PackageReference Include="Microsoft.Extensions.Hosting" />
<PackageReference Include="System.CommandLine.Hosting" />
</ItemGroup>

View File

@ -1,35 +1,38 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.CommandLine.Builder;
using System.CommandLine.Hosting;
using System.CommandLine.Parsing;
using System.IO;
using System.Threading.Tasks;
using Autofac;
using Autofac.Extensions.DependencyInjection;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Nncase.Hosting;
namespace Nncase.Cli
namespace Nncase.Cli;
internal partial class Program
{
internal partial class Program
public static async Task<int> Main(string[] args)
{
public static async Task<int> Main(string[] args)
{
return await BuildCommandLine()
.UseHost(
_ => CompilerHost.CreateHostBuilder(args),
host =>
{
host.UseConsoleLifetime();
})
.UseDefaults()
.Build().InvokeAsync(args);
}
return await BuildCommandLine()
.UseHost(ConfigureHost)
.UseDefaults()
.Build().InvokeAsync(args);
}
private static void ConfigureHost(IHostBuilder hostBuilder)
{
hostBuilder.ConfigureAppConfiguration(ConfigureAppConfiguration)
.UseConsoleLifetime()
.ConfigureCompiler();
}
private static void ConfigureAppConfiguration(HostBuilderContext context, IConfigurationBuilder builder)
{
var baseDirectory = Path.GetDirectoryName(typeof(Program).Assembly.Location);
builder.SetBasePath(baseDirectory)
.AddJsonFile("config.json", true, false);
}
}

View File

@ -4,9 +4,5 @@
"Default": "Information",
"Microsoft.Hosting.Lifetime": "Warning"
}
},
"Testing": {
"LogDir": "tests_output",
"LogLevel": 4
}
}

View File

@ -0,0 +1,20 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using DryIoc;
using Nncase.CodeGen;
using Nncase.Diagnostics;
using Nncase.Hosting;
namespace Nncase.Diagnostics;
/// <summary>
/// CodeGen module.
/// </summary>
internal class CodeGenModule : IApplicationPart
{
public void ConfigureServices(IRegistrator registrator)
{
registrator.Register<IModelBuilder, ModelBuilder>(reuse: Reuse.Scoped);
}
}

View File

@ -11,7 +11,7 @@ using Extension.Mathematics;
namespace Nncase.CodeGen;
public sealed class LinkedModel
internal sealed class LinkedModel : ILinkedModel
{
private const int _minAlignmnet = 8;

View File

@ -13,7 +13,7 @@ namespace Nncase.CodeGen;
/// <summary>
/// The Kmodel Builder.
/// </summary>
public sealed class ModelBuilder
public sealed class ModelBuilder : IModelBuilder
{
/// <summary>
/// Initializes a new instance of the <see cref="ModelBuilder"/> class.
@ -27,16 +27,6 @@ public sealed class ModelBuilder
CompileOptions = compileOptions;
}
/// <summary>
/// Initializes a new instance of the <see cref="ModelBuilder"/> class.
/// ctor from the global compile options.
/// </summary>
/// <param name="target"></param>
public ModelBuilder(ITarget target)
: this(target, CompilerServices.CompileOptions)
{
}
/// <summary>
/// Gets get the Target.
/// </summary>
@ -47,7 +37,7 @@ public sealed class ModelBuilder
/// </summary>
public CompileOptions CompileOptions { get; }
public LinkedModel Build(IRModule module)
public ILinkedModel Build(IRModule module)
{
var functionsByKind = module.Functions.GroupBy(x => x.ModuleKind).ToList();
var functionIds = MakeFunctionsIds(functionsByKind);

View File

@ -0,0 +1,32 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using Nncase.CodeGen;
using Nncase.Converters;
using Nncase.Diagnostics;
using Nncase.Hosting;
namespace Nncase;
/// <summary>
/// CodeGen application part extensions.
/// </summary>
public static class CodeGenApplicationPart
{
/// <summary>
/// Add diagnostics assembly.
/// </summary>
/// <param name="registrator">Service registrator.</param>
/// <returns>Configured service registrator.</returns>
public static IRegistrator AddCodeGen(this IRegistrator registrator)
{
return registrator.RegisterModule<CodeGenModule>();
}
}

View File

@ -1,13 +1,12 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using Autofac.Extensions.DependencyInjection;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Nncase.CodeGen;
using Nncase.Diagnostics;
using Nncase.Evaluator;
using Nncase.Hosting;
using Nncase.IR;
@ -16,152 +15,131 @@ using Nncase.Transform;
using Nncase.Transform.Passes;
using Nncase.Transform.Rules.Lower;
using Nncase.Utilities;
using OrtKISharp;
namespace Nncase.Compiler;
public class Compiler
internal class Compiler : ICompiler
{
private readonly CompileOptions _compileOptions;
private readonly CompileSession _compileSession;
private readonly IModelBuilder _modelBuilder;
private readonly IDumpper _dumpper;
private IRModule? _module;
public Compiler(CompileOptions compileOptions)
public Compiler(CompileSession compileSession, IModelBuilder modelBuilder, IDumpperFactory dumpperFactory)
{
_compileOptions = compileOptions;
_compileSession = compileSession;
_modelBuilder = modelBuilder;
_dumpper = dumpperFactory.Root;
}
public static void Initialize()
{
var iHost = CompilerHost.CreateHostBuilder().Build();
var provider = iHost.Services.GetRequiredService<ICompilerServicesProvider>();
CompilerServices.Configure(provider);
}
public IRModule Module => _module ?? throw new InvalidOperationException("Module has not been imported");
public IRModule ImportModule(Stream content)
public async Task<IRModule> ImportModuleAsync(Stream content)
{
CompilerServices.CompileOptions = _compileOptions;
// Console.WriteLine($"Target: {options.Target}");
var module = ImportModel(content);
DumpModule(module, "ir_import");
if (_dumpper.IsEnabled(DumpFlags.Compile))
{
_dumpper.DumpModule(module, "IRImport");
}
// Console.WriteLine("Infer Shape...");
if (_compileOptions.DumpLevel > 4)
{
DumpManager.RunWithDump("EvaluatorInShapeInfer", () => RunPass(pmg => pmg.Add(new ShapeInferPass()), "ShapeInferAfterImport"));
}
else
{
RunPass(pmg => pmg.Add(new ShapeInferPass()), "ShapeInferAfterImport");
}
await RunPassAsync(pmg => pmg.Add<ShapeInferPass>(), "ShapeInferAfterImport");
var inferSucc = CompilerServices.InferenceType(module.Entry!);
DumpModule(module, "ir_infertype");
if (!inferSucc)
{
throw new InvalidOperationException("InferShape Failed For This Model!");
}
// Console.WriteLine("ImportModule successful!");
return module;
}
public void TargetIndependentPass(PassManager passManager)
public void TargetIndependentPass(IPassManager passManager)
{
if (_compileOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
var quantMode = _compileSession.CompileOptions.QuantizeOptions.ModelQuantMode;
if (quantMode == ModelQuantMode.UsePTQ)
{
passManager.Add(new EGraphPass("1_NeutralOptimize")
passManager.AddWithName<EGraphPass>("NeutralOptimize").Configure(p =>
{
new Transform.Rules.Neutral.FoldConstCall(),
new Transform.Rules.Neutral.FoldNopTranspose(),
new Transform.Rules.Neutral.FoldTwoTransposes(),
new Transform.Rules.Neutral.CombineTransposeUnary(),
new Transform.Rules.Neutral.CombineTransposePad(),
new Transform.Rules.Neutral.CombinePadTranspose(),
new Transform.Rules.Neutral.CombineTransposeBinary(),
new Transform.Rules.Neutral.CombineTransposeConstBinary(),
new Transform.Rules.Neutral.CombineTransposeReduce(),
new Transform.Rules.Neutral.CombineTransposeActivations(),
new Transform.Rules.Neutral.CombineActivationsTranspose(),
new Transform.Rules.Neutral.FoldNopPad(),
new Transform.Rules.Neutral.FoldConv2DPads(),
new Transform.Rules.Neutral.FoldReduceWindow2DPads(),
p.Add<Transform.Rules.Neutral.FoldConstCall>();
p.Add<Transform.Rules.Neutral.FoldNopTranspose>();
p.Add<Transform.Rules.Neutral.FoldTwoTransposes>();
p.Add<Transform.Rules.Neutral.CombineTransposeUnary>();
p.Add<Transform.Rules.Neutral.CombineTransposePad>();
p.Add<Transform.Rules.Neutral.CombinePadTranspose>();
p.Add<Transform.Rules.Neutral.CombineTransposeBinary>();
p.Add<Transform.Rules.Neutral.CombineTransposeConstBinary>();
p.Add<Transform.Rules.Neutral.CombineTransposeReduce>();
p.Add<Transform.Rules.Neutral.CombineTransposeActivations>();
p.Add<Transform.Rules.Neutral.CombineActivationsTranspose>();
p.Add<Transform.Rules.Neutral.FoldNopPad>();
p.Add<Transform.Rules.Neutral.FoldConv2DPads>();
p.Add<Transform.Rules.Neutral.FoldReduceWindow2DPads>();
});
}
if (_compileOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
if (quantMode == ModelQuantMode.UsePTQ)
{
passManager.Add(new DataflowPass("2_AddRangeOfMarker")
passManager.AddWithName<DataflowPass>("AddRangeOfMarker").Configure(p =>
{
new Transform.Rules.Neutral.AddRangeOfAndMarkerToConv2D(),
new Transform.Rules.Neutral.AddRangeOfAndMarkerToMatMul(),
new Transform.Rules.Neutral.AddRangeOfAndMarkerToReduceWindow2D(),
new Transform.Rules.Neutral.AddRangeOfAndMarkerToConv2DTranspose(),
new Transform.Rules.Neutral.AddRangeOfAndMarkerToBinary(),
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToConv2D>();
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToMatMul>();
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToReduceWindow2D>();
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToConv2DTranspose>();
p.Add<Transform.Rules.Neutral.AddRangeOfAndMarkerToBinary>();
});
passManager.Add(new Quantization.EGraphPassWithQuantize("3_AssignRanges", _compileOptions.QuantizeOptions!));
passManager.AddWithName<EGraphPassWithQuantize>("AssignRanges");
}
}
public void Compile()
public async Task CompileAsync()
{
var t = CompilerServices.GetTarget(_compileOptions.Target);
if (_compileOptions.DumpLevel > 4)
{
DumpManager.RunWithDump("TargetIndependentEval", () => RunPass(p => TargetIndependentPass(p), "TargetIndependentPass"));
}
else
{
RunPass(p => TargetIndependentPass(p), "TargetIndependentPass");
}
var target = _compileSession.Target;
await RunPassAsync(p => TargetIndependentPass(p), "TargetIndependentPass");
await RunPassAsync(p => target.RegisterTargetDependentPass(p, _compileSession.CompileOptions), "TargetIndependentPass");
RunPass(p => t.RegisterTargetDependentPass(p, _compileOptions), "TargetDependentPass");
// RunPass(p => p.Add(new Quantization.EGraphPassWithBindQuantizeConfig("2.5_BindQuantizeConfig", options.QuantizeOptions!)));
if (_compileOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
if (_compileSession.CompileOptions.QuantizeOptions.ModelQuantMode == ModelQuantMode.UsePTQ)
{
RunPass(p => t.RegisterQuantizePass(p, _compileOptions), "QuantizePass");
RunPass(p => t.RegisterTargetDependentAfterQuantPass(p, _compileOptions), "TargetDependentAfterQuantPass");
RunPass(t => t.Add(new DataflowPass("ClearMarker") { new RemoveMarker() }), "RemoveMarker");
await RunPassAsync(p => target.RegisterQuantizePass(p, _compileSession.CompileOptions), "QuantizePass");
await RunPassAsync(p => target.RegisterTargetDependentAfterQuantPass(p, _compileSession.CompileOptions), "TargetDependentAfterQuantPass");
await RunPassAsync(
pmgr => pmgr.Add<DataflowPass>().Configure(p =>
{
p.Name = "ClearMarker";
p.Add<RemoveMarker>();
}),
"RemoveMarker");
}
// fold constant
RunPass(p => p.Add(new Transform.Passes.ShapeInferPass()), "ShapeInferAfterCompile");
// Console.WriteLine("Compile successful");
await RunPassAsync(p => p.Add<ShapeInferPass>(), "ShapeInferAfterCompile");
}
public void Gencode(Stream output)
{
var target = CompilerServices.GetTarget(_compileOptions.Target);
var moduleBuilder = new ModelBuilder(target, _compileOptions);
var linkedModel = moduleBuilder.Build(_module);
var linkedModel = _modelBuilder.Build(Module);
linkedModel.Serialize(output);
// Console.WriteLine("Gencode successful");
}
private IRModule ImportModel(Stream content)
{
_module = _compileOptions.InputFormat switch
_module = _compileSession.CompileOptions.InputFormat switch
{
"tflite" => Importers.ImportTFLite(content, _compileOptions),
"onnx" => Importers.ImportOnnx(content, _compileOptions),
_ => throw new NotImplementedException($"Not Implement {_compileOptions.InputFormat} Impoter!"),
"tflite" => Importers.ImportTFLite(content, _compileSession),
"onnx" => Importers.ImportOnnx(content, _compileSession),
var inputFormat => throw new NotImplementedException($"Not Implement {inputFormat} Importer!"),
};
return _module;
}
private void DumpModule(IRModule module, string prefix)
private async Task RunPassAsync(Action<IPassManager> register, string name)
{
var dumpPath = Path.Combine(_compileOptions.DumpDir, "dump", prefix);
CompilerServices.DumpIR(module.Entry!, prefix, dumpPath);
}
private void RunPass(Action<PassManager> register, string dirName)
{
var pmgr = new PassManager(_module, new RunPassOptions(CompilerServices.GetTarget(_compileOptions.Target), _compileOptions.DumpLevel, Path.Join(_compileOptions.DumpDir, dirName), _compileOptions));
var pmgr = _compileSession.CreatePassManager(name);
register(pmgr);
pmgr.RunAsync().Wait();
_module = await pmgr.RunAsync(Module).ConfigureAwait(false);
if (_dumpper.IsEnabled(DumpFlags.Compile))
{
_dumpper.DumpModule(_module, name);
}
}
}

View File

@ -0,0 +1,40 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
namespace Nncase.Compiler.Hosting;
/// <summary>
/// Compiler builder.
/// </summary>
public interface ICompilerBuilder
{
/// <summary>
/// Configure modules.
/// </summary>
/// <param name="configureModules">Configure modules action.</param>
/// <returns>Compiler builder.</returns>
ICompilerBuilder ConfigureModules(Action<IRegistrator> configureModules);
}
internal sealed class CompilerBuilder : ICompilerBuilder
{
private readonly IContainer _container;
public CompilerBuilder(IContainer container)
{
_container = container;
}
public ICompilerBuilder ConfigureModules(Action<IRegistrator> configureModules)
{
configureModules(_container);
return this;
}
}

View File

@ -1,70 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Autofac;
using Autofac.Extensions.DependencyInjection;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Nncase.Hosting;
namespace Nncase.Hosting;
/// <summary>
/// Compiler host helper.
/// </summary>
public static class CompilerHost
{
/// <summary>
/// Create compiler host builder.
/// </summary>
/// <param name="args">Commandline arguments.</param>
/// <returns>Created host builder.</returns>
public static IHostBuilder CreateHostBuilder(string[]? args = null)
{
var host = Host.CreateDefaultBuilder(args);
host.ConfigureAppConfiguration(ConfigureAppConfiguration)
.UseServiceProviderFactory(new AutofacServiceProviderFactory())
.ConfigureContainer<ContainerBuilder>(ConfigureContainer)
.ConfigureServices(ConfigureServices)
.ConfigureLogging(ConfigureLogging);
return host;
}
private static void ConfigureContainer(ContainerBuilder builder)
{
var assemblies = ApplicationParts.LoadApplicationParts(c =>
{
c.AddCore()
.AddEvaluator()
.AddGraph()
.AddEGraph()
.AddStackVM();
});
builder.RegisterAssemblyModules(assemblies);
}
private static void ConfigureServices(HostBuilderContext context, IServiceCollection services)
{
services.AddLogging();
}
private static void ConfigureAppConfiguration(HostBuilderContext context, IConfigurationBuilder builder)
{
builder.SetBasePath(Directory.GetCurrentDirectory())
.AddJsonFile("config.json", true, false);
}
private static void ConfigureLogging(ILoggingBuilder loggingBuilder)
{
loggingBuilder.ClearProviders();
loggingBuilder.AddConsole();
}
}

View File

@ -0,0 +1,82 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using DryIoc.Microsoft.DependencyInjection;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Nncase;
using Nncase.Compiler;
using Nncase.Compiler.Hosting;
using Nncase.Hosting;
namespace Microsoft.Extensions.Hosting;
/// <summary>
/// Compiler host builder extentensions.
/// </summary>
public static class CompilerHostBuilderExtensions
{
/// <summary>
/// Configure compiler host builder.
/// </summary>
/// <param name="hostBuilder">Host builder.</param>
/// <param name="configureCompiler">Configure compiler builder.</param>
/// <returns>Created host builder.</returns>
public static IHostBuilder ConfigureCompiler(this IHostBuilder hostBuilder, Action<ICompilerBuilder>? configureCompiler = null)
{
hostBuilder.UseServiceProviderFactory(new DryIocServiceProviderFactory())
.ConfigureContainer<Container>(ConfigureBuiltinModules)
.ConfigureServices(ConfigureServices)
.ConfigureLogging(ConfigureLogging);
hostBuilder.ConfigureContainer<Container>(x => configureCompiler?.Invoke(new CompilerBuilder(x)));
hostBuilder.ConfigureContainer<Container>(ConfigurePlugins);
return hostBuilder;
}
private static void ConfigureBuiltinModules(Container builder)
{
builder.AddCore()
.AddDiagnostics()
.AddEvaluator()
.AddGraph()
.AddEGraph()
.AddCodeGen()
.AddStackVM()
.AddK210();
}
private static void ConfigureServices(HostBuilderContext context, IServiceCollection services)
{
services.AddLogging();
services.AddSingleton<PluginLoader>();
services.AddScoped<ICompiler, Compiler>();
}
private static void ConfigureLogging(ILoggingBuilder loggingBuilder)
{
loggingBuilder.ClearProviders();
loggingBuilder.AddConsole();
}
private static void ConfigurePlugins(Container builder)
{
var pluginLoader = builder.Resolve<PluginLoader>();
var plugins = pluginLoader.LoadPlugins();
foreach (var plugin in plugins)
{
plugin.ConfigureServices(builder);
}
}
}

View File

@ -0,0 +1,126 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Reflection.Metadata;
using System.Reflection.PortableExecutable;
using System.Runtime.CompilerServices;
using System.Runtime.Loader;
using LanguageExt;
using Microsoft.Extensions.Logging;
using Nncase.Compiler;
using Nncase.IR;
namespace Nncase.Hosting;
/// <summary>
/// Plugin loader helper.
/// </summary>
public sealed class PluginLoader
{
private const string _modulesDllPattern = "Nncase.Modules.*.dll";
private const string _pluginPathEnvName = "NNCASE_PLUGIN_PATH";
private static readonly string[] _builtinModules = new[]
{
"Nncase.Modules.StackVM.dll",
"Nncase.Modules.K210.dll",
};
private readonly ILogger<PluginLoader> _logger;
/// <summary>
/// Initializes a new instance of the <see cref="PluginLoader"/> class.
/// </summary>
/// <param name="logger">Logger.</param>
public PluginLoader(ILogger<PluginLoader> logger)
{
_logger = logger;
}
/// <summary>
/// Load plugins.
/// </summary>
/// <returns>Plugins.</returns>
public IReadOnlyList<IPlugin> LoadPlugins()
{
var pluginAsms = GetPluginsSearchDirectories().Select(GetPluginAssemblies).SelectMany(x => x)
.DistinctBy(Path.GetFileName).Select(Assembly.LoadFrom).Distinct().ToList();
var plugins = (from asm in pluginAsms
from t in asm.ExportedTypes
where t.IsClass
&& t.IsAssignableTo(typeof(IPlugin))
let ctor = t.GetConstructor(Type.EmptyTypes)
where ctor != null
select (IPlugin)ctor.Invoke(null)).ToList();
return plugins;
}
private static bool IsLoadableAssembly(string filePath)
{
using var fs = File.OpenRead(filePath);
using var peReader = new PEReader(fs);
if (!peReader.HasMetadata)
{
return false;
}
var metaReader = peReader.GetMetadataReader();
if (!metaReader.IsAssembly)
{
return false;
}
// Is reference assembly
if ((from cah in metaReader.CustomAttributes
let ca = metaReader.GetCustomAttribute(cah)
let ctor = metaReader.GetMemberReference((MemberReferenceHandle)ca.Constructor)
let attrType = metaReader.GetTypeReference((TypeReferenceHandle)ctor.Parent)
where metaReader.GetString(attrType.Namespace) == nameof(System.Runtime.CompilerServices)
&& metaReader.GetString(attrType.Name) == nameof(ReferenceAssemblyAttribute)
select cah).Any())
{
return false;
}
return true;
}
private IEnumerable<string> GetPluginAssemblies(string basePath)
{
return (from filePath in Directory.GetFiles(basePath, _modulesDllPattern, SearchOption.AllDirectories)
where !_builtinModules.Contains(Path.GetFileName(filePath))
&& IsLoadableAssembly(filePath)
select filePath).Distinct();
}
private IEnumerable<string> GetPluginsSearchDirectories()
{
var directories = new List<string>();
// 1. Environment variable
var targetPathEnv = Environment.GetEnvironmentVariable(_pluginPathEnvName);
if (string.IsNullOrWhiteSpace(targetPathEnv))
{
_logger.LogWarning($"{_pluginPathEnvName} is not set.");
}
else
{
var targetPaths = from path in targetPathEnv.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
select Environment.ExpandEnvironmentVariables(path);
directories.AddRange(targetPaths);
}
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}.");
}
return directories.Distinct();
}
}

View File

@ -10,11 +10,14 @@ using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Nncase.Diagnostics;
using Nncase.Hosting;
using Nncase.IR;
using Nncase.Quantization;
using Nncase.Runtime;
using Nncase.Runtime.Interop;
using static Nncase.Compiler.PythonHelper;
namespace Nncase.Compiler.Interop;
@ -41,14 +44,12 @@ public unsafe struct CApiMT
public delegate* unmanaged<IntPtr> CompileOptionsCreatePtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetInputFilePtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetInputFormatPtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetTargetPtr;
public delegate* unmanaged<IntPtr, int, void> CompileOptionsSetDumpLevelPtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetDumpDirPtr;
public delegate* unmanaged<IntPtr, DumpFlags, void> CompileOptionsSetDumpFlagsPtr;
public delegate* unmanaged<IntPtr, IntPtr, void> CompileOptionsSetQuantizeOptionsPtr;
public delegate* unmanaged<IntPtr, IntPtr, void> CompileOptionsSetQuantTypePtr;
public delegate* unmanaged<IntPtr, ModelQuantMode, void> CompileOptionsSetModelQuantModePtr;
public delegate* unmanaged<IntPtr, IntPtr, IntPtr> CompileSessionCreatePtr;
public delegate* unmanaged<IntPtr, IntPtr> CompileSessionGetCompilerPtr;
public delegate* unmanaged<void> CompilerInitializePtr;
public delegate* unmanaged<IntPtr, IntPtr> CompilerCreatePtr;
public delegate* unmanaged<IntPtr, IntPtr, IntPtr> CompilerImportModulePtr;
public delegate* unmanaged<IntPtr, void> CompilerCompilePtr;
public delegate* unmanaged<IntPtr, IntPtr, void> CompilerGencodePtr;
@ -61,9 +62,12 @@ public unsafe struct CApiMT
public delegate* unmanaged<IntPtr> QuantizeOptionsCreatePtr;
public delegate* unmanaged<IntPtr, IntPtr, void> QuantizeOptionsSetCalibrationDatasetPtr;
public delegate* unmanaged<IntPtr, CalibMethod, void> QuantizeOptionsSetCalibrationMethodPtr;
public delegate* unmanaged<IntPtr, ModelQuantMode, void> QuantizeOptionsSetModelQuantModePtr;
public delegate* unmanaged<IntPtr, IntPtr, void> QuantizeOptionsSetQuantTypePtr;
public delegate* unmanaged<IntPtr, IntPtr> RTValueFromHandlePtr;
public delegate* unmanaged<IntPtr, IntPtr> RTValueGetHandlePtr;
public delegate* unmanaged<CStreamMT*, IntPtr, IntPtr> StreamCreatePtr;
public delegate* unmanaged<byte*, nuint, IntPtr> TargetCreatePtr;
public delegate* unmanaged<byte*, nuint, byte> TargetExistsPtr;
}
@ -83,14 +87,12 @@ public static unsafe class CApi
mt->CompileOptionsCreatePtr = &CompileOptionsCreate;
mt->CompileOptionsSetInputFilePtr = &CompileOptionsSetInputFile;
mt->CompileOptionsSetInputFormatPtr = &CompileOptionsSetInputFormat;
mt->CompileOptionsSetTargetPtr = &CompileOptionsSetTarget;
mt->CompileOptionsSetDumpLevelPtr = &CompileOptionsSetDumpLevel;
mt->CompileOptionsSetDumpDirPtr = &CompileOptionsSetDumpDir;
mt->CompileOptionsSetDumpFlagsPtr = &CompileOptionsSetDumpFlags;
mt->CompileOptionsSetQuantizeOptionsPtr = &CompileOptionsSetQuantizeOptions;
mt->CompileOptionsSetQuantTypePtr = &CompileOptionsSetQuantType;
mt->CompileOptionsSetModelQuantModePtr = &CompileOptionsSetModelQuantMode;
mt->CompileSessionCreatePtr = &CompileSessionCreate;
mt->CompileSessionGetCompilerPtr = &CompileSessionGetCompiler;
mt->CompilerInitializePtr = &CompilerInitialize;
mt->CompilerCreatePtr = &CompilerCreate;
mt->CompilerImportModulePtr = &CompilerImportModule;
mt->CompilerCompilePtr = &CompilerCompile;
mt->CompilerGencodePtr = &CompilerGencode;
@ -103,9 +105,12 @@ public static unsafe class CApi
mt->QuantizeOptionsCreatePtr = &QuantizeOptionsCreate;
mt->QuantizeOptionsSetCalibrationDatasetPtr = &QuantizeOptionsSetCalibrationDataset;
mt->QuantizeOptionsSetCalibrationMethodPtr = &QuantizeOptionsSetCalibrationMethod;
mt->QuantizeOptionsSetModelQuantModePtr = &QuantizeOptionsSetModelQuantMode;
mt->QuantizeOptionsSetQuantTypePtr = &QuantizeOptionsSetQuantType;
mt->RTValueFromHandlePtr = &RTValueFromHandle;
mt->RTValueGetHandlePtr = &RTValueGetHandle;
mt->StreamCreatePtr = &StreamCreate;
mt->TargetCreatePtr = &TargetCreate;
mt->TargetExistsPtr = &TargetExists;
}
@ -191,15 +196,9 @@ public static unsafe class CApi
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetTarget(IntPtr compileOptionsHandle, byte* targetPtr, nuint targetLength)
private static void CompileOptionsSetDumpFlags(IntPtr compileOptionsHandle, DumpFlags dumpFlags)
{
Get<CompileOptions>(compileOptionsHandle).Target = ToString(targetPtr, targetLength);
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetDumpLevel(IntPtr compileOptionsHandle, int dumpLevel)
{
Get<CompileOptions>(compileOptionsHandle).DumpLevel = dumpLevel;
Get<CompileOptions>(compileOptionsHandle).DumpFlags = dumpFlags;
}
[UnmanagedCallersOnly]
@ -215,27 +214,27 @@ public static unsafe class CApi
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetQuantType(IntPtr compileOptionsHandle, IntPtr quantTypeHandle)
private static IntPtr CompileSessionCreate(IntPtr targetHandle, IntPtr compileOptionsHandle)
{
Get<CompileOptions>(compileOptionsHandle).QuantType = Get<DataType>(quantTypeHandle);
var target = Get<ITarget>(targetHandle);
var compileOptions = Get<CompileOptions>(compileOptionsHandle);
return GCHandle.ToIntPtr(GCHandle.Alloc(CompileSession.Create(target, compileOptions)));
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetModelQuantMode(IntPtr compileOptionsHandle, ModelQuantMode quantMode)
private static IntPtr CompileSessionGetCompiler(IntPtr compileSessionHandle)
{
Get<CompileOptions>(compileOptionsHandle).ModelQuantMode = quantMode;
var compileSession = Get<CompileSession>(compileSessionHandle);
return GCHandle.ToIntPtr(GCHandle.Alloc(compileSession.Compiler));
}
[UnmanagedCallersOnly]
private static void CompilerInitialize()
{
Compiler.Initialize();
}
[UnmanagedCallersOnly]
private static IntPtr CompilerCreate(IntPtr compileOptionsHandle)
{
return GCHandle.ToIntPtr(GCHandle.Alloc(new Compiler(Get<CompileOptions>(compileOptionsHandle))));
var host = Host.CreateDefaultBuilder()
.ConfigureCompiler()
.Build();
CompilerServices.Configure(host.Services);
}
[UnmanagedCallersOnly]
@ -243,7 +242,7 @@ public static unsafe class CApi
{
var compiler = Get<Compiler>(compilerHandle);
var stream = Get<CStream>(streamHandle);
var module = compiler.ImportModule(stream);
var module = compiler.ImportModuleAsync(stream).Result;
return GCHandle.ToIntPtr(GCHandle.Alloc(module));
}
@ -251,7 +250,7 @@ public static unsafe class CApi
private static void CompilerCompile(IntPtr compilerHandle)
{
var compiler = Get<Compiler>(compilerHandle);
compiler.Compile();
compiler.CompileAsync().Wait();
}
[UnmanagedCallersOnly]
@ -330,10 +329,22 @@ public static unsafe class CApi
Get<QuantizeOptions>(quantizeOptionsHandle).CalibrationMethod = calibMethod;
}
[UnmanagedCallersOnly]
private static void QuantizeOptionsSetModelQuantMode(IntPtr quantizeOptionsHandle, ModelQuantMode quantMode)
{
Get<QuantizeOptions>(quantizeOptionsHandle).ModelQuantMode = quantMode;
}
[UnmanagedCallersOnly]
private static void QuantizeOptionsSetQuantType(IntPtr quantizeOptionsHandle, IntPtr quantTypeHandle)
{
Get<QuantizeOptions>(quantizeOptionsHandle).QuantType = Get<DataType>(quantTypeHandle);
}
[UnmanagedCallersOnly]
private static IntPtr RTValueFromHandle(IntPtr handle)
{
var rtValue = RTValue.FromHandle(handle);
var rtValue = RTValue.FromHandle(handle, true);
return GCHandle.ToIntPtr(GCHandle.Alloc(rtValue));
}
@ -350,6 +361,13 @@ public static unsafe class CApi
return GCHandle.ToIntPtr(GCHandle.Alloc(new CStream(mt, handle)));
}
[UnmanagedCallersOnly]
private static IntPtr TargetCreate(byte* targetNamePtr, nuint targetNameLength)
{
var targetName = ToString(targetNamePtr, targetNameLength);
return GCHandle.ToIntPtr(GCHandle.Alloc(CompilerServices.GetTarget(targetName)));
}
[UnmanagedCallersOnly]
private static byte TargetExists(byte* targetNamePtr, nuint targetNameLength)
{

View File

@ -9,13 +9,16 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Autofac.Extensions.DependencyInjection" />
<PackageReference Include="DryIoc.dll" />
<PackageReference Include="DryIoc.Microsoft.DependencyInjection" />
<PackageReference Include="Microsoft.Extensions.Hosting" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\modules\Nncase.Modules.K210\Nncase.Modules.K210.csproj" />
<ProjectReference Include="..\Nncase.CodeGen\Nncase.CodeGen.csproj" />
<ProjectReference Include="..\Nncase.Core\Nncase.Core.csproj" />
<ProjectReference Include="..\Nncase.Diagnostics\Nncase.Diagnostics.csproj" />
<ProjectReference Include="..\Nncase.Graph\Nncase.Graph.csproj" />
<ProjectReference Include="..\Nncase.EGraph\Nncase.EGraph.csproj" />
<ProjectReference Include="..\Nncase.Evaluator\Nncase.Evaluator.csproj" />

View File

@ -1,130 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System.Diagnostics;
using System.Runtime.InteropServices;
using NetFabric.Hyperlinq;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.IR.Math;
using Nncase.Quantization;
using Nncase.Runtime.Interop;
using Nncase.Utilities;
namespace Nncase.Compiler;
public static class PythonHelper
{
public static void LaunchDebugger()
{
Console.WriteLine(System.Environment.Version.ToString());
Debugger.Launch();
}
public static IValue TensorValueFromBytes(DataType type, byte[] span, int[] dimensions)
{
return Value.FromTensor(Tensor.FromBytes(type, span, dimensions));
}
public static Tensor TensorFromBytes(DataType type, byte[] span, int[] dimensions)
{
return Tensor.FromBytes(type, span, dimensions);
}
public static byte[] BytesBufferFromTensor(Tensor value)
{
return value.BytesBuffer.ToArray();
}
public static Memory<byte> ToMemory(byte[] bytes) => new(bytes);
public static byte[] GetRTTensorBytes(RTTensor tensor)
{
var buffer = tensor.Buffer.Buffer.AsHost()!;
using (var mmOwner = buffer.Map(RTMapAccess.Read))
{
return mmOwner.Memory.Span.ToArray();
}
}
public static uint[] GetRTTensorDims(RTTensor tensor)
{
return tensor.Dimensions.ToArray();
}
public static IValue Evaluate(Expr expr, IReadOnlyDictionary<Var, IValue>? varsValues = null)
{
if (CompilerServices.CompileOptions.DumpLevel > 4)
{
return DumpManager.RunWithDump("Evaluator", () => CompilerServices.Evaluate(expr, varsValues));
}
else
{
return CompilerServices.Evaluate(expr, varsValues);
}
}
public static RTTensor[] RunSimulator(RTInterpreter interp, RTValue[] input)
{
interp.SetDumpRoot(CompilerServices.CompileOptions.DumpDir);
var entry = interp.Entry;
var result = entry.Invoke(input);
if (result is RTTensor tensor)
{
return new[] { tensor };
}
else if (result is RTTuple tuple)
{
// todo: field maybe a tuple, but not process in this
return tuple.Fields.Select(x => (RTTensor)x).ToArray();
}
throw new NotImplementedException();
}
public static bool TargetExist(string target)
{
try
{
CompilerServices.GetTarget(target);
return true;
}
catch (Exception e)
{
Console.WriteLine(e);
return false;
}
}
// Tensor[sample_count * input_count] dataSet
public static PytestCalibrationDatasetProvider MakeDatasetProvider(Tensor[] dataSet, int sampleCount, Var[] fnParams)
{
var inputCount = dataSet[0].Length / sampleCount;
var samples = dataSet.Chunk(inputCount).Select(inputs => inputs.Zip(fnParams).ToDictionary(
item => item.Item2,
item => (IValue)Value.FromTensor(item.Item1))).ToArray().ToAsyncEnumerable();
return new PytestCalibrationDatasetProvider(samples, sampleCount);
}
public static QuantizeOptions MakeQuantizeOptions(ICalibrationDatasetProvider datasetProvider)
{
return new QuantizeOptions
{ BindQuantMethod = false, CalibrationDataset = datasetProvider, CalibrationMethod = CalibMethod.NoClip };
}
public class PytestCalibrationDatasetProvider : ICalibrationDatasetProvider
{
private readonly int _sampleCount;
public PytestCalibrationDatasetProvider(IAsyncEnumerable<IReadOnlyDictionary<Var, IValue>> samples, int sampleCount)
{
Samples = samples;
_sampleCount = sampleCount;
}
public int? Count => _sampleCount;
public IAsyncEnumerable<IReadOnlyDictionary<Var, IValue>> Samples { get; }
}
}

View File

@ -0,0 +1,46 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase.CodeGen;
/// <summary>
/// Linked model.
/// </summary>
public interface ILinkedModel
{
/// <summary>
/// Gets entry function id.
/// </summary>
FunctionId? Entry { get; }
/// <summary>
/// Gets linked modules.
/// </summary>
IReadOnlyList<ILinkedModule> Modules { get; }
/// <summary>
/// Serialize model to stream.
/// </summary>
/// <param name="output">Stream to be written.</param>
void Serialize(Stream output);
}
/// <summary>
/// Model builder.
/// </summary>
public interface IModelBuilder
{
/// <summary>
/// Build linked model.
/// </summary>
/// <param name="module">Module.</param>
/// <returns>Linked model.</returns>
ILinkedModel Build(IRModule module);
}

View File

@ -2,146 +2,38 @@
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using Nncase.Diagnostics;
using Nncase.Quantization;
namespace Nncase;
/// <summary>
/// CompileOptions.
/// Compile options.
/// </summary>
public sealed class CompileOptions
public sealed record CompileOptions
{
/// <summary>
/// Initializes a new instance of the <see cref="CompileOptions"/> class.
/// copy ctor.
/// Gets or sets input file.
/// </summary>
public CompileOptions(CompileOptions other)
{
InputFile = other.InputFile;
InputFormat = other.InputFormat;
Target = other.Target;
DumpLevel = other.DumpLevel;
DumpDir = other.DumpDir;
ModelQuantMode = other.ModelQuantMode;
QuantType = other.QuantType;
WQuantType = other.WQuantType;
OutputFile = other.OutputFile;
}
public string InputFile { get; set; } = "<stream>";
/// <summary>
/// Initializes a new instance of the <see cref="CompileOptions"/> class.
/// CompileOptions.
/// Gets or sets the import model format.
/// </summary>
public CompileOptions()
{
InputFile = string.Empty;
InputFormat = string.Empty;
Target = string.Empty;
DumpLevel = -1;
DumpDir = string.Empty;
ModelQuantMode = ModelQuantMode.NoQuant;
QuantType = DataTypes.UInt8;
WQuantType = DataTypes.UInt8;
OutputFile = string.Empty;
}
public string InputFormat { get; set; } = "onnx";
/// <summary>
/// Initializes a new instance of the <see cref="CompileOptions"/> class.
/// init.
/// Gets or sets the dump flags.
/// </summary>
/// <param name="modelQuantMode"></param>
public CompileOptions(ModelQuantMode modelQuantMode)
{
InputFile = string.Empty;
InputFormat = string.Empty;
Target = string.Empty;
DumpLevel = -1;
DumpDir = string.Empty;
ModelQuantMode = modelQuantMode;
QuantType = DataTypes.UInt8;
WQuantType = DataTypes.UInt8;
OutputFile = string.Empty;
QuantizeOptions = QuantizeOptions;
}
public DumpFlags DumpFlags { get; set; } = DumpFlags.None;
/// <inheritdoc/>
public string InputFile { get; set; }
/// <summary>
/// Gets or sets the dump directory.
/// </summary>
public string DumpDir { get; set; } = string.Empty;
/// <inheritdoc/>
public string InputFormat { get; set; }
/// <inheritdoc/>
public string Target { get; set; }
/// <inheritdoc/>
public int DumpLevel { get; set; }
/// <inheritdoc/>
public string DumpDir { get; set; }
/// <inheritdoc/>
public DataType QuantType { get; set; }
/// <inheritdoc/>
public DataType WQuantType { get; set; }
/// <inheritdoc/>
public string OutputFile { get; set; }
/// <inheritdoc/>
public ModelQuantMode ModelQuantMode { get; set; }
/// <inheritdoc/>
public QuantizeOptions? QuantizeOptions { get; set; }
/// <summary>
/// Gets or sets quant options.
/// </summary>
public QuantizeOptions QuantizeOptions { get; set; } = QuantizeOptions.CreateNoQuant();
}
// /// <summary>
// /// Options of compile command.
// /// </summary>
// public interface CompileOptions
// {
// /// <summary>
// /// Gets or sets input file.
// /// </summary>
// public string InputFile { get; set; }
// /// <summary>
// /// Gets or sets output file.
// /// </summary>
// public string OutputFile { get; set; }
// /// <summary>
// /// Gets or sets the import model format.
// /// </summary>
// public string InputFormat { get; set; }
// /// <summary>
// /// Gets or sets target.
// /// </summary>
// public string Target { get; set; }
// /// <summary>
// /// Gets or sets the dump level.
// /// </summary>
// public int DumpLevel { get; set; }
// /// <summary>
// /// Gets or sets the dump directory.
// /// </summary>
// public string DumpDir { get; set; }
// /// <summary>
// /// weather use ptq
// /// </summary>
// public bool UsePTQ { get; set; }
// /// <summary>
// /// Gets or sets quant type
// /// </summary>
// public DataType QuantType { get; set; }
// /// <summary>
// /// Gets or sets quant mode
// /// </summary>
// public QuantMode QuantMode { get; set; }
// }

View File

@ -0,0 +1,104 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using Microsoft.Extensions.DependencyInjection;
using Nncase.CodeGen;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.Transform;
namespace Nncase;
/// <summary>
/// Compile session.
/// </summary>
public sealed class CompileSession : IServiceProvider, IDisposable
{
private readonly IResolverContext _serviceProvider;
private bool _disposedValue;
private ICompiler? _compiler;
/// <summary>
/// Initializes a new instance of the <see cref="CompileSession"/> class.
/// </summary>
/// <param name="serviceProvider">Service provider.</param>
/// <param name="target">Target.</param>
/// <param name="compileOptions">Compile options.</param>
internal CompileSession(IResolverContext serviceProvider, ITarget target, CompileOptions compileOptions)
{
_serviceProvider = serviceProvider;
Target = target;
CompileOptions = compileOptions;
}
/// <summary>
/// Gets target.
/// </summary>
public ITarget Target { get; }
/// <summary>
/// Gets compile options.
/// </summary>
public CompileOptions CompileOptions { get; }
/// <summary>
/// Gets compiler.
/// </summary>
public ICompiler Compiler => _compiler ??= this.GetRequiredService<ICompiler>();
/// <summary>
/// Create new compile session.
/// </summary>
/// <param name="target">Compile target.</param>
/// <param name="compileOptions">Compile options.</param>
/// <returns>Created compile session.</returns>
public static CompileSession Create(ITarget target, CompileOptions compileOptions)
{
var childContainer = CompilerServices.CreateScope();
childContainer.RegisterInstance(target);
childContainer.RegisterInstance(compileOptions);
var session = new CompileSession(childContainer, target, compileOptions);
childContainer.RegisterInstance(session);
return session;
}
/// <inheritdoc/>
public object? GetService(Type serviceType) => _serviceProvider.GetService(serviceType);
/// <summary>
/// Create new pass manager.
/// </summary>
/// <param name="name">Name.</param>
/// <returns>Created pass manager.</returns>
public IPassManager CreatePassManager(string name)
{
return new PassManager(name, this);
}
/// <inheritdoc/>
public void Dispose()
{
Dispose(disposing: true);
}
private void Dispose(bool disposing)
{
if (!_disposedValue)
{
if (disposing)
{
_serviceProvider.Dispose();
}
_disposedValue = true;
}
}
}

View File

@ -0,0 +1,37 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase;
internal struct CompileSessionScope : IDisposable
{
private static readonly AsyncLocal<CompileSession?> _compileSession = new AsyncLocal<CompileSession?>();
private readonly bool _initialized;
private readonly CompileSession? _originalCompileSession;
public CompileSessionScope(CompileSession compileSession)
{
_initialized = true;
_originalCompileSession = _compileSession.Value;
_compileSession.Value = compileSession;
}
public static CompileSession? Current => _compileSession.Value;
public static CompileSession GetCurrentThrowIfNull() => Current ?? throw new InvalidOperationException($"Current {nameof(CompileSession)} is not set");
public void Dispose()
{
if (_initialized)
{
_compileSession.Value = _originalCompileSession;
}
}
}

View File

@ -3,12 +3,14 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Nncase.CostModel;
using Nncase.Evaluator;
using Nncase.IR;
@ -23,12 +25,6 @@ namespace Nncase;
/// </summary>
public interface ICompilerServicesProvider
{
/// <summary>
/// Gets or sets get CompileOptions.
/// </summary>
/// <returns>CompileOptions.</returns>
CompileOptions CompileOptions { get; set; }
/// <summary>
/// Inference type of the expression tree.
/// </summary>
@ -151,7 +147,7 @@ public interface ICompilerServicesProvider
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options);
Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
/// <summary>
/// Match enodes as root.
@ -185,7 +181,7 @@ public interface ICompilerServicesProvider
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options);
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
}
internal interface ICompilerServicesProviderInternal
@ -198,21 +194,13 @@ internal interface ICompilerServicesProviderInternal
/// </summary>
public static class CompilerServices
{
private static IServiceProvider? _serviceProvider;
private static ICompilerServicesProvider? _provider;
/// <summary>
/// Gets or sets get the compile options.
/// Gets root services.
/// </summary>
/// <returns></returns>
public static CompileOptions CompileOptions
{
get { return Provider.CompileOptions; }
set { Provider.CompileOptions = value; }
}
public static string CompileTarget => CompileOptions.Target;
public static ITarget GetCompileTarget => GetTarget(CompileTarget);
internal static IServiceProvider ServiceProvider => _serviceProvider ?? throw new InvalidOperationException("Compiler services provider must be set.");
internal static IDataTypeServiceProvider DataTypeService => ((ICompilerServicesProviderInternal)Provider).DataTypeService;
@ -221,10 +209,11 @@ public static class CompilerServices
/// <summary>
/// Configure compiler services.
/// </summary>
/// <param name="provider">Service provider.</param>
public static void Configure(ICompilerServicesProvider provider)
/// <param name="serviceProvider">Root service provider.</param>
public static void Configure(IServiceProvider serviceProvider)
{
_provider = provider;
_serviceProvider = serviceProvider;
_provider = serviceProvider.GetRequiredService<ICompilerServicesProvider>();
}
/// <summary>
@ -353,7 +342,7 @@ public static class CompilerServices
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
public static Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
public static Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
return Provider.Rewrite(expr, rules, options);
}
@ -365,7 +354,7 @@ public static class CompilerServices
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
public static Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
return Provider.ERewrite(expr, rules, options);
}
@ -444,6 +433,22 @@ public static class CompilerServices
/// <param name="name">Target name.</param>
/// <returns>Target.</returns>
public static ITarget GetTarget(string name) => Provider.GetTarget(name);
internal static DryIoc.IContainer CreateScope()
{
var container = (DryIoc.IContainer)_serviceProvider!;
var childDefaultServiceKey = new object();
var rules = container.Rules
.WithDefaultRegistrationServiceKey(childDefaultServiceKey)
.WithFactorySelector(Rules.SelectKeyedOverDefaultFactory(childDefaultServiceKey));
return container!.With(
container.Parent,
rules,
container.ScopeContext,
RegistrySharing.CloneButKeepCache,
container.SingletonScope.Clone(false),
Scope.Of(container.OwnCurrentScope));
}
}
internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerServicesProviderInternal
@ -457,11 +462,8 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
private readonly IEGraphMatchProvider _eGraphMatchProvider;
private readonly IEGraphRewriteProvider _eGraphrewriteProvider;
private readonly ITargetProvider _targetProvider;
private CompileOptions _compileOptions;
public CompilerServicesProvider(
// IOptions<CompileOptions> compileOptions,
IEvaluateProvider evaluateProvider,
ITypeInferenceProvider typeInferenceProvider,
IIRPrinterProvider irprinterProvider,
@ -488,12 +490,6 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
public IDataTypeServiceProvider DataTypeService { get; }
public CompileOptions CompileOptions
{
get => _compileOptions;
set => _compileOptions = value;
}
/// <inheritdoc/>
public IValue Evaluate(Expr expr, IReadOnlyDictionary<Var, IValue>? varsValues = null, Dictionary<Type, IEvaluator>? evaluator_cache = null)
{
@ -551,19 +547,19 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
}
/// <inheritdoc/>
public Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
public Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
return _rewriteProvider.Rewrite(expr, rules, options);
}
/// <inheritdoc/>
public Cost EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null)
public Cost? EvaluateCost(Expr expr, IReadOnlyDictionary<Var, Cost>? varsValues = null)
{
return _costEvaluateProvider.EvaluateCost(expr, varsValues);
}
/// <inheritdoc/>
public Cost EvaluateOpCost(Op op, ICostEvaluateContext context)
public Cost? EvaluateOpCost(Op op, ICostEvaluateContext context)
{
return _costEvaluateProvider.EvaluateOpCost(op, context);
}
@ -583,7 +579,7 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
return _targetProvider.GetTarget(name);
}
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
return _eGraphrewriteProvider.ERewrite(expr, rules, options);
}

View File

@ -1,31 +1,32 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using DryIoc;
using Nncase.Hosting;
namespace Nncase.Converters;
/// <summary>
/// Converters module.
/// </summary>
public class ConvertersModule : Module
internal class ConvertersModule : IApplicationPart
{
/// <inheritdoc/>
protected override void Load(ContainerBuilder builder)
public void ConfigureServices(IRegistrator registrator)
{
builder.RegisterType<BFloat16Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<BooleanConverters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<DoubleConverters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<HalfConverters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<Int16Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<Int32Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<Int64Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<Int8Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<SingleConverters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<UInt16Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<UInt32Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<UInt64Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<UInt8Converters>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<PointerConverters>().AsImplementedInterfaces().SingleInstance();
registrator.RegisterManyInterface<BFloat16Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<BooleanConverters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<DoubleConverters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<HalfConverters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<Int16Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<Int32Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<Int64Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<Int8Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<SingleConverters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<UInt16Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<UInt32Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<UInt64Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<UInt8Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PointerConverters>(reuse: Reuse.Singleton);
}
}

View File

@ -7,6 +7,10 @@ using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using Nncase.Converters;
using Nncase.Hosting;
using Nncase.Targets;
namespace Nncase;
@ -18,11 +22,12 @@ public static class CoreApplicationPart
/// <summary>
/// Add core assembly.
/// </summary>
/// <param name="assemblies">Assembly collection.</param>
/// <returns>Updated assembly collection.</returns>
public static IList<Assembly> AddCore(this IList<Assembly> assemblies)
/// <param name="registrator">Service registrator.</param>
/// <returns>Configured service registrator.</returns>
public static IRegistrator AddCore(this IRegistrator registrator)
{
assemblies.Add(typeof(CoreApplicationPart).Assembly);
return assemblies;
return registrator.RegisterModule<CoreModule>()
.RegisterModule<ConvertersModule>()
.RegisterModule<TargetsModule>();
}
}

View File

@ -1,38 +1,40 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using DryIoc;
using Nncase.Hosting;
using Nncase.IR;
namespace Nncase;
/// <summary>
/// Core module.
/// </summary>
public class CoreModule : Module
internal class CoreModule : IApplicationPart
{
/// <inheritdoc/>
protected override void Load(ContainerBuilder builder)
public void ConfigureServices(IRegistrator registrator)
{
builder.RegisterType<CompilerServicesProvider>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<DataTypeServiceProvider>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<IR.IRPrinterProvider>().AsImplementedInterfaces().SingleInstance();
builder.RegisterType<CompileOptions>().AsImplementedInterfaces().SingleInstance();
registrator.RegisterManyInterface<CompilerServicesProvider>(reuse: Reuse.Singleton);
registrator.Register<IDataTypeServiceProvider, DataTypeServiceProvider>(reuse: Reuse.Singleton);
registrator.Register<IIRPrinterProvider, IRPrinterProvider>(reuse: Reuse.Singleton);
// Prim types
builder.RegisterType<BooleanType>().As<PrimType>().SingleInstance();
builder.RegisterType<Utf8CharType>().As<PrimType>().SingleInstance();
builder.RegisterType<Int8Type>().As<PrimType>().SingleInstance();
builder.RegisterType<Int16Type>().As<PrimType>().SingleInstance();
builder.RegisterType<Int32Type>().As<PrimType>().SingleInstance();
builder.RegisterType<Int64Type>().As<PrimType>().SingleInstance();
builder.RegisterType<UInt8Type>().As<PrimType>().SingleInstance();
builder.RegisterType<UInt16Type>().As<PrimType>().SingleInstance();
builder.RegisterType<UInt32Type>().As<PrimType>().SingleInstance();
builder.RegisterType<UInt64Type>().As<PrimType>().SingleInstance();
builder.RegisterType<Float16Type>().As<PrimType>().SingleInstance();
builder.RegisterType<Float32Type>().As<PrimType>().SingleInstance();
builder.RegisterType<Float64Type>().As<PrimType>().SingleInstance();
builder.RegisterType<BFloat16Type>().As<PrimType>().SingleInstance();
builder.RegisterType<QuantParamType>().As<ValueType>().SingleInstance();
registrator.Register<PrimType, BooleanType>(reuse: Reuse.Singleton);
registrator.Register<PrimType, Utf8CharType>(reuse: Reuse.Singleton);
registrator.Register<PrimType, Int8Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, Int16Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, Int32Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, Int64Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, UInt8Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, UInt16Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, UInt32Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, UInt64Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, Float16Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, Float32Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, Float64Type>(reuse: Reuse.Singleton);
registrator.Register<PrimType, BFloat16Type>(reuse: Reuse.Singleton);
// Value types
registrator.Register<ValueType, QuantParamType>(reuse: Reuse.Singleton);
}
}

View File

@ -6,7 +6,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Autofac;
using DryIoc;
namespace Nncase;
@ -28,27 +28,27 @@ internal class DataTypeServiceProvider : IDataTypeServiceProvider
private readonly Dictionary<RuntimeTypeHandle, PrimType> _primTypes = new();
private readonly Dictionary<Runtime.TypeCode, PrimType> _typeCodeToPrimTypes = new();
private readonly Dictionary<RuntimeTypeHandle, ValueType> _valueTypes = new();
private readonly IComponentContext _componentContext;
private readonly IResolver _resolver;
public DataTypeServiceProvider(PrimType[] primTypes, ValueType[] valueTypes, IComponentContext componentContext)
public DataTypeServiceProvider(PrimType[] primTypes, ValueType[] valueTypes, IResolver resolver)
{
_primTypes = primTypes.ToDictionary(x => x.CLRType.TypeHandle);
_typeCodeToPrimTypes = primTypes.Where(x => x.TypeCode < Runtime.TypeCode.ValueType).ToDictionary(x => x.TypeCode);
_valueTypes = valueTypes.ToDictionary(x => x.CLRType.TypeHandle);
_componentContext = componentContext;
_resolver = resolver;
}
public ISpanConverter GetConverter(Type fromType, Type toType)
{
if (fromType.IsGenericType && fromType.GetGenericTypeDefinition() == typeof(Pointer<>))
{
var converter = _componentContext.Resolve(typeof(IPointerSpanConverter<>).MakeGenericType(toType));
var converter = _resolver.Resolve(typeof(IPointerSpanConverter<>).MakeGenericType(toType));
var wrapperType = typeof(PointerSpanConverter<,>).MakeGenericType(fromType.GenericTypeArguments[0], toType);
return (ISpanConverter)Activator.CreateInstance(wrapperType, converter)!;
}
else
{
return (ISpanConverter)_componentContext.Resolve(typeof(ISpanConverter<,>).MakeGenericType(fromType, toType));
return (ISpanConverter)_resolver.Resolve(typeof(ISpanConverter<,>).MakeGenericType(fromType, toType));
}
}

View File

@ -0,0 +1,72 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.Diagnostics;
/// <summary>
/// Dump flags.
/// </summary>
[Flags]
public enum DumpFlags
{
/// <summary>
/// Nothing need to be dump.
/// </summary>
None = 0,
/// <summary>
/// Dump import ops.
/// </summary>
ImportOps = 1 << 1,
/// <summary>
/// Dump pass pre and post ir.
/// </summary>
PassIR = 1 << 2,
/// <summary>
/// Dump egraph costs.
/// </summary>
EGraphCost = 1 << 3,
/// <summary>
/// Dump rewrite.
/// </summary>
Rewrite = 1 << 4,
/// <summary>
/// Dump calibration.
/// </summary>
Calibration = 1 << 5,
/// <summary>
/// Dump evaluator values.
/// </summary>
Evaluator = 1 << 6,
/// <summary>
/// Dump compile stages.
/// </summary>
Compile = 1 << 7,
/// <summary>
/// Dump tiling.
/// </summary>
Tiling = 1 << 8,
/// <summary>
/// Dump schedule.
/// </summary>
Schedule = 1 << 9,
/// <summary>
/// Dump codegen.
/// </summary>
CodeGen = 1 << 10,
}

View File

@ -0,0 +1,68 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
namespace Nncase.Diagnostics;
/// <summary>
/// <see cref="IDumpper"/> scope.
/// </summary>
public struct DumpScope : IDisposable
{
private static readonly AsyncLocal<IDumpper> _dumpper = new AsyncLocal<IDumpper>();
private readonly bool _initialized;
private readonly IDumpper _originalDumpper;
/// <summary>
/// Initializes a new instance of the <see cref="DumpScope"/> struct.
/// </summary>
/// <param name="subDirectory">Sub directory.</param>
/// <param name="serviceProvider">Service provider.</param>
public DumpScope(string subDirectory, IServiceProvider? serviceProvider = null)
{
_initialized = true;
_originalDumpper = GetCurrent(serviceProvider);
_dumpper.Value = _originalDumpper.CreateSubDummper(subDirectory);
}
/// <summary>
/// Initializes a new instance of the <see cref="DumpScope"/> struct.
/// </summary>
/// <param name="newDumpper">New dumpper.</param>
/// <param name="serviceProvider">Service provider.</param>
public DumpScope(IDumpper newDumpper, IServiceProvider? serviceProvider = null)
{
_initialized = true;
_originalDumpper = GetCurrent(serviceProvider);
_dumpper.Value = newDumpper;
}
/// <summary>
/// Gets current dumpper.
/// </summary>
public static IDumpper Current => GetCurrent(null);
/// <summary>
/// Gets current <see cref="IDumpper"/> or use root of scope.
/// </summary>
/// <param name="serviceProvider">Service provider.</param>
/// <returns>Current dumpper.</returns>
public static IDumpper GetCurrent(IServiceProvider? serviceProvider) =>
_dumpper.Value ??= (serviceProvider ?? CompileSessionScope.Current)?.GetRequiredService<IDumpperFactory>().Root ?? NullDumpper.Instance;
/// <inheritdoc/>
public void Dispose()
{
if (_initialized)
{
_dumpper.Value = _originalDumpper;
}
}
}

View File

@ -0,0 +1,63 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
using Nncase.Utilities;
namespace Nncase.Diagnostics;
/// <summary>
/// Data dumpper.
/// </summary>
public interface IDumpper
{
/// <summary>
/// Gets dump directory.
/// </summary>
string Directory { get; }
/// <summary>
/// Gets a value indicating whether dump is enabled.
/// </summary>
/// <param name="dumpFlags">Dump flags.</param>
/// <returns>Whether dump is enabled.</returns>
bool IsEnabled(DumpFlags dumpFlags);
/// <summary>
/// Create sub dummper.
/// </summary>
/// <param name="subDirectory">Sub directory.</param>
/// <returns>Sub dummper.</returns>
IDumpper CreateSubDummper(string subDirectory);
void DumpIR(Expr expr, string prefix, string? reletivePath = null);
void DumpDotIR(Expr expr, string prefix, string? reletivePath = null);
void DumpModule(IRModule module, string? reletivePath = null);
Stream OpenFile(string reletivePath, FileMode fileMode = FileMode.Create);
}
/// <summary>
/// Dumpper factory.
/// </summary>
public interface IDumpperFactory
{
/// <summary>
/// Gets root dummper.
/// </summary>
IDumpper Root { get; }
/// <summary>
/// Creat dumpper.
/// </summary>
/// <param name="relativePath">Sub directory.</param>
/// <returns>Dumpper.</returns>
IDumpper CreateDummper(string relativePath);
}

View File

@ -0,0 +1,48 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase.Diagnostics;
/// <summary>
/// <see cref="IDumpper"/> with no backing store.
/// </summary>
public sealed class NullDumpper : IDumpper
{
/// <summary>
/// Gets instance.
/// </summary>
public static NullDumpper Instance { get; } = new NullDumpper();
/// <inheritdoc/>
public string Directory => System.IO.Directory.GetCurrentDirectory();
/// <inheritdoc/>
public IDumpper CreateSubDummper(string subDirectory) => this;
/// <inheritdoc/>
public void DumpIR(Expr expr, string prefix, string? reletivePath = null)
{
}
public void DumpDotIR(Expr expr, string prefix, string? reletivePath = null)
{
}
/// <inheritdoc/>
public void DumpModule(IRModule module, string? reletivePath = null)
{
}
/// <inheritdoc/>
public bool IsEnabled(DumpFlags dumpFlags) => false;
/// <inheritdoc/>
public Stream OpenFile(string reletivePath, FileMode fileMode = FileMode.Create) => Stream.Null;
}

View File

@ -1,96 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using Nncase.IR;
namespace Nncase.Hosting;
/// <summary>
/// Application parts helper.
/// </summary>
public class ApplicationParts
{
private const string _appPartsDllPattern = "Nncase.Modules.*.dll";
private const string _targetPathEnvName = "NNCASE_TARGET_PATH";
/// <summary>
/// Load application parts.
/// </summary>
/// <param name="configureAction">Configure action.</param>
/// <returns>Application parts assemblies.</returns>
public static Assembly[] LoadApplicationParts(Action<IList<Assembly>> configureAction)
{
var defaultAssemblies = new List<Assembly>() { Assembly.GetCallingAssembly() };
configureAction(defaultAssemblies);
return defaultAssemblies.Concat(
GetApplicationPartsSearchDirectories().Select(LoadApplicationParts).SelectMany(x => x)
.DistinctBy(Path.GetFileName).Select(Assembly.LoadFrom))
.Distinct().ToArray();
}
private static IEnumerable<string> LoadApplicationParts(string basePath)
{
return Directory.GetFiles(basePath, _appPartsDllPattern, SearchOption.AllDirectories)
.Where(x => !Path.GetDirectoryName(x)!.EndsWith("ref"));
}
private static IEnumerable<string> GetApplicationPartsSearchDirectories()
{
var directories = new List<string>();
// 1. Executable base
var exePath = Path.GetDirectoryName(Assembly.GetCallingAssembly().Location);
if (!string.IsNullOrWhiteSpace(exePath))
{
directories.Add(exePath);
}
// 2. Environment variable
var targetPathEnv = Environment.GetEnvironmentVariable(_targetPathEnvName);
if (string.IsNullOrWhiteSpace(targetPathEnv))
{
// todo:log
Console.WriteLine("NNCASE_TARGET_PATH is not set.");
}
else
{
var targetPaths = from path in targetPathEnv.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
select Environment.ExpandEnvironmentVariables(path);
directories.AddRange(targetPaths);
}
foreach (var directory in directories)
{
Console.WriteLine(directory);
}
return directories.Distinct();
}
}
// Custom comparer for the Product class
internal class PathComparer : IEqualityComparer<string>
{
// Products are equal if their names and product numbers are equal.
public bool Equals(string x, string y)
{
Console.WriteLine("------------------");
Console.WriteLine(x);
Console.WriteLine(y);
Console.WriteLine("------------------");
return Path.GetFileName(x) == Path.GetFileName(y) || x == y;
}
// If Equals() returns true for a pair of objects
// then GetHashCode() must return the same value for these objects.
public int GetHashCode(string s)
{
return s.GetHashCode();
}
}

View File

@ -0,0 +1,53 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
namespace Nncase.Hosting;
/// <summary>
/// Application part.
/// </summary>
public interface IApplicationPart
{
/// <summary>
/// Configure services.
/// </summary>
/// <param name="registrator">Service registrator.</param>
void ConfigureServices(IRegistrator registrator);
}
/// <summary>
/// Application part extensions.
/// </summary>
public static class ApplicationPartExtensions
{
/// <summary>
/// Register module.
/// </summary>
/// <typeparam name="TModule">Module type.</typeparam>
/// <param name="registrator">Service registrator.</param>
/// <returns>Configured service registrator.</returns>
public static IRegistrator RegisterModule<TModule>(this IRegistrator registrator)
where TModule : class, IApplicationPart, new()
{
var module = new TModule();
module.ConfigureServices(registrator);
return registrator;
}
/// <summary>Registers single registration for all implemented public interfaces and base classes.</summary>
/// <typeparam name="TImplementation">Implementation type.</typeparam>
/// <param name="registrator">Service registrator.</param>
/// <param name="reuse">Reuse strategy.</param>
public static void RegisterManyInterface<TImplementation>(this IRegistrator registrator, IReuse? reuse = null)
where TImplementation : class
{
registrator.RegisterMany<TImplementation>(reuse, serviceTypeCondition: t => t.IsInterface);
}
}

View File

@ -0,0 +1,17 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.Hosting;
/// <summary>
/// Plugin.
/// </summary>
public interface IPlugin : IApplicationPart
{
}

View File

@ -0,0 +1,36 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase;
/// <summary>
/// Compiler.
/// </summary>
public interface ICompiler
{
/// <summary>
/// Import DL model as ir module.
/// </summary>
/// <param name="content">Model content.</param>
/// <returns>Imported ir module.</returns>
Task<IRModule> ImportModuleAsync(Stream content);
/// <summary>
/// Compile module.
/// </summary>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
Task CompileAsync();
/// <summary>
/// Generate code to stream.
/// </summary>
/// <param name="output">Stream to be written.</param>
void Gencode(Stream output);
}

View File

@ -8,82 +8,81 @@ using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR
namespace Nncase.IR;
/// <summary>
/// the interface that we can use parameterinfo the parameter.
/// </summary>
/// <typeparam name="T"></typeparam>
public interface IParameterList<T>
{
/// <summary>
/// the interface that we can use parameterinfo the parameter.
/// get parameter info.
/// </summary>
/// <typeparam name="T"></typeparam>
public interface IParameterList<T>
/// <param name="parameter"></param>
/// <returns></returns>
public T this[ParameterInfo parameter] { get; }
}
/// <summary>
/// Call expression.
/// </summary>
public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParameterList<Expr>
{
// /// <summary>
// /// used by fake ir, represents that whether this op permit int 16 quant.
// /// </summary>
// public bool PermitInt16Quant = false;
/// <summary>
/// quant config with cosine, List of DataType represents data types for each input might be quantized, List of QuantParam represents quant params for each input.
/// may be deleted in the future since there is EnodeBestQuantConfigWithCosine, reserve it now for debug and for unexpected usage when EnodeBestQuantConfigWithCosine is not enough.
/// </summary>
public List<Tuple<List<DataType>, List<List<QuantParam>>, float>> EnodeQuantConfigWithCosine;
/// <summary>
/// quant config with cosine, List of DataType represents data types for each input might be quantized, List of QuantParam represents quant params for each input.
/// </summary>
public Tuple<List<DataType>, List<List<QuantParam>>, float> EnodeBestQuantConfigWithCosine;
/// <summary>
/// Initializes a new instance of the <see cref="Call"/> class.
/// </summary>
/// <param name="target">Call target.</param>
/// <param name="parameters">Parameters.</param>
public Call(Expr target, params Expr[] parameters)
: this(target, new IRArray<Expr>(parameters.ToImmutableArray()))
{
/// <summary>
/// get parameter info.
/// </summary>
/// <param name="parameter"></param>
/// <returns></returns>
public T this[ParameterInfo parameter] { get; }
}
/// <summary>
/// Call expression.
/// get param expr.
/// </summary>
public sealed record Call(Expr Target, IRArray<Expr> Parameters) : Expr, IParameterList<Expr>
/// <param name="parameter"></param>
/// <returns></returns>
/// <exception cref="ArgumentOutOfRangeException"></exception>
public Expr this[ParameterInfo parameter]
{
// /// <summary>
// /// used by fake ir, represents that whether this op permit int 16 quant.
// /// </summary>
// public bool PermitInt16Quant = false;
/// <summary>
/// quant config with cosine, List of DataType represents data types for each input might be quantized, List of QuantParam represents quant params for each input.
/// may be deleted in the future since there is EnodeBestQuantConfigWithCosine, reserve it now for debug and for unexpected usage when EnodeBestQuantConfigWithCosine is not enough.
/// </summary>
public List<Tuple<List<DataType>, List<List<QuantParam>>, float>> EnodeQuantConfigWithCosine;
/// <summary>
/// quant config with cosine, List of DataType represents data types for each input might be quantized, List of QuantParam represents quant params for each input.
/// </summary>
public Tuple<List<DataType>, List<List<QuantParam>>, float> EnodeBestQuantConfigWithCosine;
/// <summary>
/// Initializes a new instance of the <see cref="Call"/> class.
/// </summary>
/// <param name="target">Call target.</param>
/// <param name="parameters">Parameters.</param>
public Call(Expr target, params Expr[] parameters)
: this(target, new IRArray<Expr>(parameters.ToImmutableArray()))
get
{
}
/// <summary>
/// get param expr.
/// </summary>
/// <param name="parameter"></param>
/// <returns></returns>
/// <exception cref="ArgumentOutOfRangeException"></exception>
public Expr this[ParameterInfo parameter]
{
get
var type = Target.GetType();
if (type == parameter.OwnerType)
{
var type = Target.GetType();
if (type == parameter.OwnerType)
{
return Parameters[parameter.Index];
}
else
{
throw new ArgumentOutOfRangeException($"Target {Target} doesn't have parameter: {parameter.Name}.");
}
return Parameters[parameter.Index];
}
else
{
throw new ArgumentOutOfRangeException($"Target {Target} doesn't have parameter: {parameter.Name}.");
}
}
}
public void ParametersForeach(Action<Expr, ParameterInfo> f)
public void ParametersForeach(Action<Expr, ParameterInfo> f)
{
var parameterInfos = ((Op)Target).Parameters.ToArray();
for (int i = 0; i < Parameters.Count; i++)
{
var parameterInfos = ((Op)Target).Parameters.ToArray();
for (int i = 0; i < Parameters.Count; i++)
{
f(Parameters[i], parameterInfos[i]);
}
f(Parameters[i], parameterInfos[i]);
}
}
}

View File

@ -0,0 +1,113 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase.IR;
/// <summary>
/// Module.
/// </summary>
public sealed class IRModule
{
private readonly List<BaseFunction> _functions;
private int? _entryIndex;
/// <summary>
/// Initializes a new instance of the <see cref="IRModule"/> class.
/// </summary>
/// <param name="main">main func.</param>
public IRModule(BaseFunction main)
{
_functions = new() { main };
_entryIndex = 0;
}
/// <summary>
/// Initializes a new instance of the <see cref="IRModule"/> class.
/// the default IrModule ctor.
/// </summary>
public IRModule()
{
_functions = new();
}
/// <summary>
/// Gets functions.
/// </summary>
public IReadOnlyList<BaseFunction> Functions => _functions;
/// <summary>
/// Gets or sets entry function.
/// </summary>
public BaseFunction? Entry
{
get => _entryIndex.HasValue ? _functions[_entryIndex.Value] : null;
set => _entryIndex = value != null ? _functions.FindIndex(x => object.ReferenceEquals(x, value)) : null;
}
/// <summary>
/// Add function.
/// </summary>
/// <param name="function">Callable to add.</param>
public void Add(BaseFunction function)
{
_functions.Add(function);
}
/// <summary>
/// Replace the function defination.
/// </summary>
/// <param name="index">function index.</param>
/// <param name="function">the entry function defination.</param>
public void Replace(int index, BaseFunction function)
{
var old = _functions[index];
var replacer = new FunctionReplacer(old, function);
for (int i = 0; i < _functions.Count; i++)
{
replacer.Visit(_functions[i]);
}
for (int i = 0; i < _functions.Count; i++)
{
var originFunc = _functions[i];
if (replacer.ExpressionMemo.TryGetValue(originFunc, out var replace))
{
_functions[i] = (BaseFunction)replace;
}
}
}
/// <summary>
/// Replace the function call dependencer.
/// </summary>
private sealed class FunctionReplacer : DeepExprMutator
{
private readonly BaseFunction _original;
private readonly BaseFunction _replace;
public FunctionReplacer(BaseFunction original, BaseFunction replace)
{
_original = original;
_replace = replace;
}
public override Expr DefaultMutateLeaf(Expr expr)
{
if (expr is BaseFunction baseFunction
&& object.ReferenceEquals(baseFunction, _original))
{
return _replace;
}
return expr;
}
}
}

View File

@ -1,94 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Nncase.IR
{
/// <summary>
/// Module.
/// </summary>
public sealed class IRModule
{
private readonly List<BaseFunction> _functions;
/// <summary>
/// the index of the entry function.
/// </summary>
private int _entryIndex;
/// <summary>
/// Initializes a new instance of the <see cref="IRModule"/> class.
/// </summary>
/// <param name="main"> main func.</param>
public IRModule(BaseFunction main)
{
_functions = new();
_functions.Add(main);
_entryIndex = 0;
}
/// <summary>
/// Initializes a new instance of the <see cref="IRModule"/> class.
/// the default IrModule ctor.
/// </summary>
public IRModule()
{
_functions = new();
_entryIndex = -1;
}
/// <summary>
/// Gets functions.
/// </summary>
public IReadOnlyList<BaseFunction> Functions => _functions;
/// <summary>
/// Gets or sets entry function.
/// </summary>
public BaseFunction? Entry
{
get => _entryIndex == -1 ? null : Functions[_entryIndex];
set
{
if (value is null)
{
_entryIndex = -1;
}
else
{
_entryIndex = _functions.IndexOf(value);
if (_entryIndex == -1)
{
_functions.Add(value);
_entryIndex = _functions.Count - 1;
}
}
}
}
/// <summary>
/// Add function.
/// </summary>
/// <param name="function">Callable to add.</param>
public void Add(BaseFunction function)
{
_functions.Add(function);
}
/// <summary>
/// update the entry function defination.
/// </summary>
/// <param name="i">function index.</param>
/// <param name="function">the entry function defination.</param>
public void Update(int i, BaseFunction function)
{
_functions[i] = function;
}
}
}

View File

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Xml.Linq;
namespace Nncase.IR
{
@ -52,7 +53,7 @@ namespace Nncase.IR
/// <summary>
/// Variable expression.
/// </summary>
public record Var : Expr
public record Var : Expr, IEquatable<Var>
{
private static int _globalVarIndex;
@ -64,9 +65,10 @@ namespace Nncase.IR
/// <param name="typeAnnotation"></param>
public Var(string name, IRType typeAnnotation)
{
Name = name;
TypeAnnotation = typeAnnotation;
CheckedType = TypeAnnotation;
GlobalVarIndex = GetNextId();
Name = name;
}
/// <summary>
@ -74,8 +76,11 @@ namespace Nncase.IR
/// </summary>
/// <param name="typeAnnotation">Type annotation.</param>
public Var(IRType typeAnnotation)
: this($"var_{_globalVarIndex++}", typeAnnotation)
{
TypeAnnotation = typeAnnotation;
CheckedType = TypeAnnotation;
GlobalVarIndex = GetNextId();
Name = $"var_{GlobalVarIndex}";
}
/// <summary>
@ -92,14 +97,14 @@ namespace Nncase.IR
/// Initializes a new instance of the <see cref="Var"/> class.
/// </summary>
public Var()
: this($"var_{_globalVarIndex++}", AnyType.Default)
: this(AnyType.Default)
{
}
/// <summary>
/// Gets get the global var index.
/// Gets the global var index.
/// </summary>
private int GlobalVarIndex => _globalVarIndex;
public int GlobalVarIndex { get; }
/// <summary>
/// Gets name.
@ -140,5 +145,24 @@ namespace Nncase.IR
/// <param name="name"></param>
/// <returns></returns>
public static Var SizeVar(string name) => Scalar(name, DataTypes.Int32);
/// <inheritdoc/>
public virtual bool Equals(Var? other)
{
if (other == null)
{
return false;
}
return GlobalVarIndex == other.GlobalVarIndex;
}
/// <inheritdoc/>
public override int GetHashCode() => GlobalVarIndex.GetHashCode();
private static int GetNextId()
{
return Interlocked.Increment(ref _globalVarIndex);
}
}
}

View File

@ -27,23 +27,21 @@ public interface ITarget
/// Bind Quant Method And Quant Cosine With IR.
/// </summary>
/// <param name="calibrationDataset">calibration dataset.</param>
/// <param name="target">target.</param>
/// <param name="rangeOfs">rangeOf nodes.</param>
/// <param name="childrenOfRangeOfs">rangeOf nodes children.</param>
/// <param name="runPassOptions">options.</param>
/// <param name="quantizeOptions">options.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions);
Task<Dictionary<ENode, List<Tuple<List<DataType>, List<List<QuantParam>>, float>>>> BindQuantMethodCosine(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions);
/// <summary>
/// AdaRound Weights.
/// </summary>
/// <param name="calibrationDataset">calibration dataset.</param>
/// <param name="target">target.</param>
/// <param name="rangeOfs">rangeOf nodes.</param>
/// <param name="childrenOfRangeOfs">rangeOf nodes children.</param>
/// <param name="runPassOptions">options.</param>
/// <param name="quantizeOptions">options.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, ITarget target, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, RunPassOptions runPassOptions);
Task AdaRoundWeights(ICalibrationDatasetProvider calibrationDataset, List<ENode> rangeOfs, List<ENode> childrenOfRangeOfs, QuantizeOptions quantizeOptions);
/// <summary>
/// Parse Target Dependent Options.
@ -56,21 +54,21 @@ public interface ITarget
/// </summary>
/// <param name="passManager">pass manager.</param>
/// <param name="options">compile options.</param>
void RegisterTargetDependentPass(PassManager passManager, CompileOptions options);
void RegisterTargetDependentPass(IPassManager passManager, CompileOptions options);
/// <summary>
/// Register Quantize Pass.
/// </summary>
/// <param name="passManager">pass manager.</param>
/// <param name="options">compile options.</param>
void RegisterQuantizePass(PassManager passManager, CompileOptions options);
void RegisterQuantizePass(IPassManager passManager, CompileOptions options);
/// <summary>
/// Register Target Dependent After Quant Pass.
/// </summary>
/// <param name="passManager"></param>
/// <param name="options">compile options.</param>
void RegisterTargetDependentAfterQuantPass(PassManager passManager, CompileOptions options);
void RegisterTargetDependentAfterQuantPass(IPassManager passManager, CompileOptions options);
/// <summary>
/// Create module builder.

View File

@ -10,7 +10,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Autofac" />
<PackageReference Include="DryIoc.dll" />
<PackageReference Include="libtorch-cpu-linux-x64" />
<PackageReference Include="libtorch-cpu-osx-x64" />
<PackageReference Include="libtorch-cpu-win-x64" />
@ -19,6 +19,7 @@
<!-- <PackageReference Include="libtorch-cuda-11.3-linux-x64" /> -->
<!-- <PackageReference Include="libtorch-cuda-11.3-win-x64" /> -->
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Options" />
<PackageReference Include="Microsoft.Toolkit.HighPerformance" />
<PackageReference Include="NetFabric.Hyperlinq" />

View File

@ -49,7 +49,7 @@ public static partial class Utility
/// <summary>
/// match a call with op type T
/// auto set first param
/// it's always used for Fake to NoFake Rule with ReplaceCall.
/// it's always used for Fake to NoFake Pass with ReplaceCall.
/// </summary>ReplaceParams
/// <param name="callName"></param>
/// <param name="opName"></param>

View File

@ -0,0 +1,12 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
[assembly: InternalsVisibleTo("Nncase.Tests")]
[assembly: InternalsVisibleTo("Nncase.Tests.TestFixture")]

View File

@ -80,4 +80,25 @@ public class QuantizeOptions
/// Gets or sets a value indicating whether enable adaround to fine tune weights.
/// </summary>
public bool UseAdaRound { get; set; }
/// <summary>
/// Gets or sets quant type.
/// </summary>
public DataType QuantType { get; set; } = DataTypes.UInt8;
/// <summary>
/// Gets or sets weights quant type.
/// </summary>
public DataType WQuantType { get; set; } = DataTypes.UInt8;
/// <summary>
/// Gets or sets model quant mode.
/// </summary>
public ModelQuantMode ModelQuantMode { get; set; } = ModelQuantMode.NoQuant;
/// <summary>
/// Creates no quantization options.
/// </summary>
/// <returns>No quant options.</returns>
public static QuantizeOptions CreateNoQuant() => new QuantizeOptions();
}

View File

@ -1,18 +1,18 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using DryIoc;
using Nncase.Hosting;
namespace Nncase.Targets;
/// <summary>
/// Targets module.
/// </summary>
public class TargetsModule : Module
internal class TargetsModule : IApplicationPart
{
/// <inheritdoc/>
protected override void Load(ContainerBuilder builder)
public void ConfigureServices(IRegistrator registrator)
{
builder.RegisterType<TargetProvider>().AsImplementedInterfaces().SingleInstance();
registrator.Register<ITargetProvider, TargetProvider>(reuse: Reuse.Singleton);
}
}

View File

@ -256,6 +256,18 @@ public abstract partial class Tensor : IStructuralComparable, IStructuralEquatab
return new Tensor<T>(memory, dimensions);
}
/// <summary>
/// Create tensor from an array, Set the shape as [n].
/// </summary>
/// <typeparam name="T">CLR type.</typeparam>
/// <param name="array">Array.</param>
/// <returns>Created tensor.</returns>
public static Tensor<T> From<T>(T[] array)
where T : unmanaged, IEquatable<T>
{
return From(array.AsMemory());
}
/// <summary>
/// Create tensor from an array, Set the shape as provided.
/// </summary>
@ -266,7 +278,7 @@ public abstract partial class Tensor : IStructuralComparable, IStructuralEquatab
public static Tensor<T> From<T>(T[] array, ReadOnlySpan<int> dimensions)
where T : unmanaged, IEquatable<T>
{
return new Tensor<T>(array, dimensions);
return From(array.AsMemory(), dimensions);
}
/// <summary>

View File

@ -6,103 +6,45 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.Diagnostics;
using Nncase.IR;
using static TorchSharp.torch.nn;
namespace Nncase.Transform;
/// <summary>
/// the basic pass.
/// </summary>
public abstract class BasePass
{
/// <summary>
/// Initializes a new instance of the <see cref="BasePass"/> class.
/// the base pass ctor.
/// </summary>
/// <param name="name"></param>
public BasePass(string name)
{
Name = name;
}
/// <summary>
/// Gets the pass name.
/// </summary>
public string Name { get; init; }
}
/// <summary>
/// Pass in Callable scope.
/// </summary>
public abstract class FunctionPass : BasePass
public abstract class FunctionPass : Pass<BaseFunction>
{
/// <summary>
/// Initializes a new instance of the <see cref="FunctionPass"/> class.
/// </summary>
/// <param name="name">Name.</param>
public FunctionPass(string name)
: base(name)
public FunctionPass()
{
}
/// <summary>
/// Run current pass for specific function.
/// </summary>
/// <param name="callable">Target function.</param>
/// <param name="options">Options.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
public async Task<BaseFunction> RunAsync(BaseFunction callable, RunPassOptions options)
/// <inheritdoc/>
protected override Task OnPassStartAsync(BaseFunction input, RunPassContext context)
{
var new_options = options.IndentDir(Name).IndentDir(callable.Name);
OnPassStart(callable, new_options);
var post = await RunCoreAsync(callable, new_options);
OnPassEnd(post, new_options);
return post;
}
/// <summary>
/// Run pass implementation for derived class.
/// </summary>
/// <param name="callable">Target function.</param>
/// <param name="options">Options.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
protected abstract Task<BaseFunction> RunCoreAsync(BaseFunction callable, RunPassOptions options);
/// <summary>
/// the callback function you can custom process func with run pass options.
/// </summary>
/// <param name="callable"> func without run pass.</param>
/// <param name="options"></param>
protected virtual void OnPassStart(BaseFunction callable, RunPassOptions options)
{
switch (options.DumpLevel)
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
{
case >= 2:
CompilerServices.DumpIR(callable, "Start", options.DumpDir);
break;
case >= 1:
break;
default:
break;
DumpScope.Current.DumpIR(input, "Start");
}
return Task.CompletedTask;
}
/// <summary>
/// the callback function you can custom process func with run pass options.
/// </summary>
/// <param name="callable"> func with rewrited. </param>
/// <param name="options"></param>
protected virtual void OnPassEnd(BaseFunction callable, RunPassOptions options)
/// <inheritdoc/>
protected override Task OnPassEndAsync(BaseFunction post, RunPassContext context)
{
switch (options.DumpLevel)
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
{
case >= 2:
CompilerServices.DumpIR(callable, "End", options.DumpDir);
break;
case >= 1:
break;
default:
break;
DumpScope.Current.DumpIR(post, "End");
}
return Task.CompletedTask;
}
private protected override string? GetDumpRelativePass(BaseFunction input) => input.Name;
}

View File

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase.Transform;
@ -14,8 +15,45 @@ namespace Nncase.Transform;
/// </summary>
public interface IEGraph
{
/// <summary>
/// Gets version.
/// </summary>
int Version { get; }
/// <summary>
/// Gets eclasses.
/// </summary>
IEnumerable<EClass> Classes { get; }
/// <summary>
/// Gets nodes.
/// </summary>
IEnumerable<ENode> Nodes { get; }
/// <summary>
/// Add expr, get the eclass id.
/// </summary>
/// <param name="expr">Expression.</param>
/// <returns>Eclass of this node.</returns>
EClass Add(Expr expr);
/// <summary>
/// Find eclass of enode.
/// </summary>
/// <param name="node">ENode.</param>
/// <returns>EClass.</returns>
EClass Find(ENode node);
/// <summary>
/// Union two equal Eclass.
/// </summary>
/// <param name="classA">class a.</param>
/// <param name="classB">class b.</param>
/// <returns>If version changed.</returns>
bool Union(EClass classA, EClass classB);
/// <summary>
/// After merge, we use rebuild get new dep information.
/// </summary>
void Rebuild();
}

View File

@ -22,7 +22,7 @@ public interface IRewriteProvider
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options);
Expr Rewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
}
/// <summary>
@ -37,5 +37,14 @@ public interface IEGraphRewriteProvider
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited expression.</returns>
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options);
Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options);
/// <summary>
/// Rewrite egraph.
/// </summary>
/// <param name="eGraph">EGraph.</param>
/// <param name="rules">Rewrite rules.</param>
/// <param name="options">Options.</param>
/// <returns>Rewrited EGraph.</returns>
IEGraph ERewrite(IEGraph eGraph, IEnumerable<IRewriteRule> rules, RunPassContext options);
}

View File

@ -22,16 +22,7 @@ public interface IRewriteRule
/// </summary>
/// <param name="result">Match result.</param>
/// <returns>Replace expression or null if nothing changed.</returns>
Expr? GetReplace(IMatchResult result, RunPassOptions options);
/// <summary>
/// check this pattern can be modify in multi branch.
/// </summary>
/// <returns></returns>
bool IsMultiBranchSafe()
{
return false;
}
Expr? GetReplace(IMatchResult result, RunPassContext options);
}
/// <summary>

View File

@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.Diagnostics;
using Nncase.IR;
namespace Nncase.Transform;
@ -13,72 +14,40 @@ namespace Nncase.Transform;
/// <summary>
/// Pass in Callable scope.
/// </summary>
public abstract class ModulePass : BasePass
public abstract class ModulePass : Pass<IRModule>
{
/// <summary>
/// Initializes a new instance of the <see cref="ModulePass"/> class.
/// </summary>
/// <param name="name">Name.</param>
public ModulePass(string name)
: base(name)
public ModulePass()
{
}
/// <summary>
/// Run current pass for module.
/// </summary>
/// <param name="module">Target module.</param>
/// <param name="options">Options.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
public async Task RunAsync(IRModule module, RunPassOptions options)
/// <inheritdoc/>
protected override Task OnPassStartAsync(IRModule input, RunPassContext context)
{
var new_options = options.IndentDir(Name);
OnPassStart(module, new_options);
await RunCoreAsync(module, new_options);
OnPassEnd(module, new_options);
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
{
foreach (var func in input.Functions)
{
DumpScope.Current.DumpIR(func, func.Name, "Start");
}
}
return Task.CompletedTask;
}
/// <summary>
/// Run pass implementation for derived class.
/// </summary>
/// <param name="module">Target module.</param>
/// <param name="options">Options.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
protected abstract Task RunCoreAsync(IRModule module, RunPassOptions options);
/// <summary>
/// the callback function you can custom process func with run pass options.
/// </summary>
/// <param name="module"> module without run pass.</param>
/// <param name="options"></param>
protected virtual void OnPassStart(IRModule module, RunPassOptions options)
/// <inheritdoc/>
protected override Task OnPassEndAsync(IRModule post, RunPassContext context)
{
if (options.DumpLevel < 3)
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
{
return;
foreach (var func in post.Functions)
{
DumpScope.Current.DumpIR(func, func.Name, "End");
}
}
foreach (var (func, i) in module.Functions.Select((func, i) => (func, i)))
{
CompilerServices.DumpIR(func, $"fn_{i}", Path.Combine(options.DumpDir, "Start"));
}
}
/// <summary>
/// the callback function you can custom process func with run pass options.
/// </summary>
/// <param name="module"> module with rewrited. </param>
/// <param name="options"></param>
protected virtual void OnPassEnd(IRModule module, RunPassOptions options)
{
if (options.DumpLevel < 3)
{
return;
}
foreach (var (func, i) in module.Functions.Select((func, i) => (func, i)))
{
CompilerServices.DumpIR(func, $"fn_{i}", Path.Combine(options.DumpDir, "End"));
}
return Task.CompletedTask;
}
}

View File

@ -0,0 +1,97 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.TIR;
namespace Nncase.Transform;
/// <summary>
/// IR or TIR transformer.
/// </summary>
public interface IPass
{
/// <summary>
/// Gets or sets pass name.
/// </summary>
string Name { get; set; }
}
/// <summary>
/// IR or TIR transformer.
/// </summary>
/// <typeparam name="T">Type to transform.</typeparam>
public abstract class Pass<T> : IPass
where T : class
{
private string? _name;
/// <summary>
/// Initializes a new instance of the <see cref="Pass{T}"/> class.
/// </summary>
internal Pass()
{
CompileSession = CompileSessionScope.GetCurrentThrowIfNull();
}
/// <inheritdoc/>
public string Name
{
get => _name ??= GetType().Name;
set => _name = value;
}
/// <summary>
/// Gets compile session.
/// </summary>
protected CompileSession CompileSession { get; }
/// <summary>
/// Run pass.
/// </summary>
/// <param name="input">Input object.</param>
/// <param name="context">Run pass context.</param>
/// <returns>Output object.</returns>
public async Task<T> RunAsync(T input, RunPassContext context)
{
using var sessionScope = new CompileSessionScope(CompileSession);
using var dumpScope = new DumpScope(Path.Join($"{context.Index}_{Name}", GetDumpRelativePass(input)));
await OnPassStartAsync(input, context);
var output = await RunCoreAsync(input, context);
await OnPassEndAsync(output, context);
return output;
}
/// <summary>
/// Run pass implementation for derived class.
/// </summary>
/// <param name="input">Input object.</param>
/// <param name="context">Run pass context.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
protected abstract Task<T> RunCoreAsync(T input, RunPassContext context);
/// <summary>
/// The callback function you can custom process func with run pass context.
/// </summary>
/// <param name="input">Input object.</param>
/// <param name="context">Run pass context.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
protected abstract Task OnPassStartAsync(T input, RunPassContext context);
/// <summary>
/// The callback function you can custom process func with run pass context.
/// </summary>
/// <param name="post">Post object.</param>
/// <param name="context">Run pass context.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
protected abstract Task OnPassEndAsync(T post, RunPassContext context);
private protected virtual string? GetDumpRelativePass(T input) => null;
}

View File

@ -8,218 +8,231 @@ using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using NetFabric.Hyperlinq;
using Nncase.Diagnostics;
using Nncase.IR;
[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("Nncase.Tests")]
using Nncase.TIR;
namespace Nncase.Transform;
/// <summary>
/// Passes addable.
/// </summary>
public interface IPassesAddable
{
/// <summary>
/// Add pass.
/// </summary>
/// <typeparam name="T">Pass type.</typeparam>
/// <param name="parameters">Pass constructor parameters.</param>
/// <returns>Add result.</returns>
AddPassResult<T> Add<T>(params object[] parameters)
where T : class, IPass;
/// <summary>
/// Add pass with name.
/// </summary>
/// <typeparam name="T">Pass type.</typeparam>
/// <param name="name">Pass name.</param>
/// <param name="parameters">Pass constructor parameters.</param>
/// <returns>Add result.</returns>
AddPassResult<T> AddWithName<T>(string name, params object[] parameters)
where T : class, IPass;
}
/// <summary>
/// Pass manager.
/// </summary>
public class PassManager : IEnumerable<BasePass>
public interface IPassManager : IPassesAddable
{
private readonly IRModule _module;
private readonly RunPassOptions _options;
private readonly List<BasePass> _passes = new List<BasePass>();
private readonly Dictionary<BaseFunction, BaseFunction> _functions_update_map = new(ReferenceEqualityComparer.Instance);
private readonly Dictionary<int, BaseFunction> _functions_mask = new();
/// <summary>
/// Gets name.
/// </summary>
public string Name { get; }
/// <summary>
/// Initializes a new instance of the <see cref="PassManager"/> class.
/// Gets passes.
/// </summary>
/// <param name="module">Module.</param>
/// <param name="options">Options.</param>
public PassManager(IRModule module, RunPassOptions options)
{
_module = module;
_options = options;
}
/// <summary>
/// foreach fix when the call target function has been updated.
/// </summary>
public static void FuncUpdateDependence(IRModule module, Dictionary<BaseFunction, BaseFunction> update_map, RunPassOptions options, string name)
{
var mutator = new DependenceMutator(update_map);
_ = mutator.Visit(module.Entry!);
if (!mutator.IsMutated)
{
return;
}
for (int i = 0; i < module.Functions.Count; i++)
{
if (update_map.TryGetValue(module.Functions[i], out var updated_func))
{
module.Update(i, updated_func);
}
}
if (options.DumpLevel > 2)
{
foreach (var item in module.Functions)
{
CompilerServices.DumpIR(item, string.Empty, Path.Combine(options.DumpDir, name, $"FuncUpdateDependence"));
}
}
}
/// <summary>
/// Add function pass.
/// </summary>
/// <param name="pass">Pass.</param>
public void Add(BasePass pass)
{
_passes.Add(pass);
}
/// <inheritdoc/>
public IEnumerator<BasePass> GetEnumerator()
{
return ((IEnumerable<BasePass>)_passes).GetEnumerator();
}
public IReadOnlyList<IPass> Passes { get; }
/// <summary>
/// Run passes and update the module.
/// </summary>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
public async Task RunAsync()
{
var passes = _passes.AsEnumerable();
while (passes.Any())
{
var type = passes.First().GetType();
Type base_type;
if (type.IsSubclassOf(typeof(FunctionPass)))
{
base_type = typeof(FunctionPass);
}
else if (type.IsSubclassOf(typeof(ModulePass)))
{
base_type = typeof(ModulePass);
}
else
{
throw new ArgumentOutOfRangeException();
}
var candiate = passes.TakeWhile(item => item.GetType().IsSubclassOf(base_type));
passes = passes.Skip(candiate.Count());
if (type.IsSubclassOf(typeof(FunctionPass)))
{
await RunFunctionAsync(candiate.OfType<FunctionPass>());
}
else if (type.IsSubclassOf(typeof(ModulePass)))
{
await RunModuleAsync(candiate.OfType<ModulePass>());
}
else
{
throw new ArgumentOutOfRangeException();
}
}
}
/// <inheritdoc/>
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)_passes).GetEnumerator();
}
private async Task RunFunctionAsync(IEnumerable<FunctionPass> passes)
{
int i = 0;
string name = string.Empty;
while (i < _module.Functions.Count)
{
foreach (var pass in passes)
{
var pre = _module.Functions[i];
var post = await pass.RunAsync(pre, _options);
if (!object.ReferenceEquals(pre, post))
{
FuncUpdateRecord(i, pre, post);
_module.Update(i, post);
}
name = pass.Name;
}
i++;
}
FuncUpdateDependence(_module, _functions_update_map, _options, name);
CleanFuncUpdateRecord();
}
private async Task RunModuleAsync(IEnumerable<ModulePass> passes)
{
foreach (var pass in passes)
{
await pass.RunAsync(_module, _options);
}
}
private void CleanFuncUpdateRecord()
{
_functions_update_map.Clear();
_functions_mask.Clear();
}
private void FuncUpdateRecord(int i, BaseFunction current, BaseFunction updated)
{
// if function[i] has not been update, record it to origin function.
if (!_functions_mask.TryGetValue(i, out var origin))
{
origin = current;
_functions_mask.Add(i, origin);
}
_functions_update_map[origin] = updated;
}
/// <param name="module">Input module.</param>
/// <returns>A <see cref="Task{IRModule}"/> representing the asynchronous operation.</returns>
Task<IRModule> RunAsync(IRModule module);
}
/// <summary>
/// Update the function call dependencer.
/// Add pass result.
/// </summary>
internal sealed class DependenceMutator : DeepExprMutator
/// <typeparam name="T">Pass type.</typeparam>
public struct AddPassResult<T> : IPassesAddable
where T : class, IPass
{
public Dictionary<BaseFunction, BaseFunction> FunctionsUpdated;
private readonly IPassManager _passManager;
private readonly CompileSession _compileSession;
public DependenceMutator(Dictionary<BaseFunction, BaseFunction> functions_updated)
internal AddPassResult(IPassManager passManager, CompileSession compileSession, T pass)
{
FunctionsUpdated = functions_updated;
_passManager = passManager;
_compileSession = compileSession;
Pass = pass;
}
public override Expr DefaultMutateLeaf(Expr expr)
/// <summary>
/// Gets pass.
/// </summary>
public T Pass { get; }
/// <inheritdoc/>
public AddPassResult<TPass> Add<TPass>(params object[] parameters)
where TPass : class, IPass => _passManager.Add<TPass>(parameters);
/// <inheritdoc/>
public AddPassResult<TPass> AddWithName<TPass>(string name, params object[] parameters)
where TPass : class, IPass => _passManager.AddWithName<TPass>(name, parameters);
/// <summary>
/// Configure pass.
/// </summary>
/// <param name="configureRule">Configure pass action.</param>
/// <returns>This add result.</returns>
public AddPassResult<T> Configure(Action<T> configureRule)
{
if (expr is BaseFunction baseFunction && FunctionsUpdated.TryGetValue(baseFunction, out var updated_basefunc))
{
return updated_basefunc;
}
return expr;
}
public override Expr Visit(BaseFunction baseFunction)
{
// first time enter function, mutate
var nexpr = base.Visit(baseFunction);
if (nexpr is BaseFunction updatedBasefunction && !object.ReferenceEquals(baseFunction, updatedBasefunction))
{
if (FunctionsUpdated.ContainsKey(baseFunction))
{
FunctionsUpdated[baseFunction] = updatedBasefunction;
}
else
{
FunctionsUpdated.Add(baseFunction, updatedBasefunction);
}
}
return nexpr;
using var scope = new CompileSessionScope(_compileSession);
configureRule(Pass);
return this;
}
}
internal sealed class PassManager : IPassManager
{
private readonly CompileSession _compileSession;
private readonly IDumpper _dummper;
private readonly List<IPass> _passes = new List<IPass>();
/// <summary>
/// Initializes a new instance of the <see cref="PassManager"/> class.
/// </summary>
/// <param name="name">Pass manager name.</param>
/// <param name="compileSession">Compile session.</param>
public PassManager(string name, CompileSession compileSession)
{
Name = name;
_compileSession = compileSession;
_dummper = DumpScope.GetCurrent(compileSession).CreateSubDummper(name);
}
/// <summary>
/// Gets name.
/// </summary>
public string Name { get; }
/// <summary>
/// Gets passes.
/// </summary>
public IReadOnlyList<IPass> Passes => _passes;
/// <inheritdoc/>
public AddPassResult<T> Add<T>(params object[] parameters)
where T : class, IPass
{
using var scope = new CompileSessionScope(_compileSession);
using var dumpScope = new DumpScope(_dummper);
var pass = ActivatorUtilities.CreateInstance<T>(_compileSession, parameters);
_passes.Add(pass);
return new(this, _compileSession, pass);
}
/// <inheritdoc/>
public AddPassResult<T> AddWithName<T>(string name, params object[] parameters)
where T : class, IPass
{
var result = Add<T>(parameters);
result.Configure(p => p.Name = name);
return result;
}
/// <inheritdoc/>
public async Task<IRModule> RunAsync(IRModule module)
{
using var dumpScope = new DumpScope(_dummper);
for (int i = 0; i < _passes.Count; i++)
{
var task = _passes[i] switch
{
FunctionPass fp => RunAsync(module, i, fp),
PrimFuncPass pfp => RunAsync(module, i, pfp),
ModulePass mp => RunAsync(module, i, mp),
_ => throw new NotSupportedException($"Unsupported pass type: {_passes[i].GetType().AssemblyQualifiedName}"),
};
module = await task;
}
return module;
}
private async Task<IRModule> RunAsync(IRModule module, int passIndex, FunctionPass pass)
{
for (int i = 0; i < module.Functions.Count; i++)
{
var pre = module.Functions[i];
var context = new RunPassContext { Index = passIndex };
var post = await pass.RunAsync(pre, context);
if (!object.ReferenceEquals(pre, post))
{
if (_dummper.IsEnabled(DumpFlags.PassIR))
{
_dummper.DumpModule(module, $"Before_{passIndex}_{pass.Name}");
}
module.Replace(i, post);
if (_dummper.IsEnabled(DumpFlags.PassIR))
{
_dummper.DumpModule(module, $"After_{passIndex}_{pass.Name}");
}
}
}
return module;
}
private async Task<IRModule> RunAsync(IRModule module, int passIndex, PrimFuncPass pass)
{
for (int i = 0; i < module.Functions.Count; i++)
{
var pre = module.Functions[i];
if (pre is PrimFunction pf)
{
var context = new RunPassContext { Index = passIndex };
var post = await pass.RunAsync(pf, context);
if (!object.ReferenceEquals(pre, post))
{
if (_dummper.IsEnabled(DumpFlags.PassIR))
{
_dummper.DumpModule(module, $"Before_{passIndex}_{pass.Name}");
}
module.Replace(i, post);
if (_dummper.IsEnabled(DumpFlags.PassIR))
{
_dummper.DumpModule(module, $"After_{passIndex}_{pass.Name}");
}
}
}
}
return module;
}
private async Task<IRModule> RunAsync(IRModule module, int passIndex, ModulePass pass)
{
var context = new RunPassContext { Index = passIndex };
return await pass.RunAsync(module, context);
}
}

View File

@ -0,0 +1,112 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Microsoft.Extensions.DependencyInjection;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.TIR;
namespace Nncase.Transform;
/// <summary>
/// TIR Mutator Pass.
/// NOTE only apply on prim func
/// Because of we will mutate the expression multiple times, so use MutatorCreator create the new mutator.
/// </summary>
public class PrimFuncPass : Pass<PrimFunction>
{
private readonly List<MutatorDescriptor> _mutatorDescriptors = new();
/// <summary>
/// Initializes a new instance of the <see cref="PrimFuncPass"/> class.
/// </summary>
public PrimFuncPass()
{
}
/// <summary>
/// Add mutator.
/// </summary>
/// <typeparam name="T">Mutator type.</typeparam>
/// <param name="configureMutator">Configure mutator action.</param>
/// <param name="arguments">Mutator's constructor arguments.</param>
/// <returns>This primfunc pass.</returns>
public PrimFuncPass Add<T>(Action<T>? configureMutator, params object[] arguments)
where T : ExprMutator
{
_mutatorDescriptors.Add(new()
{
Factory = ActivatorUtilities.CreateFactory(typeof(T), arguments.Select(x => x.GetType()).ToArray()),
Configure = configureMutator == null ? null : x => configureMutator?.Invoke((T)x),
Arguments = arguments,
});
return this;
}
/// <inheritdoc/>
protected override Task<PrimFunction> RunCoreAsync(PrimFunction input, RunPassContext context)
{
var post = input;
int count = 0;
bool isMutated = false;
do
{
foreach (var descriptor in _mutatorDescriptors)
{
var mutator = descriptor.Activate(CompileSession);
post = (PrimFunction)mutator.Visit(post);
isMutated = mutator.IsMutated;
if (isMutated)
{
var typeInferSuccess = CompilerServices.InferenceType(post);
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
{
DumpScope.Current.DumpIR(post, $"{count++}_{mutator.GetType().Name}");
}
Trace.Assert(typeInferSuccess);
break;
}
}
if (!isMutated)
{
break;
}
}
while (true);
return Task.FromResult(post);
}
/// <inheritdoc/>
protected override Task OnPassStartAsync(PrimFunction input, RunPassContext context) => Task.CompletedTask;
/// <inheritdoc/>
protected override Task OnPassEndAsync(PrimFunction post, RunPassContext context) => Task.CompletedTask;
private protected override string? GetDumpRelativePass(PrimFunction input) => input.Name;
private struct MutatorDescriptor
{
public ObjectFactory Factory;
public Action<ExprMutator>? Configure;
public object[] Arguments;
public ExprMutator Activate(CompileSession compileSession)
{
using var scope = new CompileSessionScope(compileSession);
var mutator = (ExprMutator)Factory(compileSession, Arguments);
Configure?.Invoke(mutator);
return mutator;
}
}
}

View File

@ -20,7 +20,7 @@ public abstract class QuantRule : RewriteRule<Pattern>
/// <summary>
/// NOTE the option will be set by SourceGenerator when the GetReplace called.
/// </summary>
public RunPassOptions Option = null!;
public RunPassContext Option = null!;
/// <summary>
/// the match result
@ -36,22 +36,22 @@ public abstract class QuantRule : RewriteRule<Pattern>
/// <summary>
/// Gets get ModelQuantMode.
/// </summary>
public ModelQuantMode ModelQuantMode => Option.CompileOptions.ModelQuantMode;
public ModelQuantMode ModelQuantMode => CompileSession.CompileOptions.QuantizeOptions.ModelQuantMode;
/// <summary>
/// Gets get QuantType.
/// </summary>
public DataType QuantType => Option.CompileOptions.QuantType;
public DataType QuantType => CompileSession.CompileOptions.QuantizeOptions.QuantType;
/// <summary>
/// Gets get WQuantType.
/// </summary>
public DataType WQuantType => Option.CompileOptions.WQuantType;
public DataType WQuantType => CompileSession.CompileOptions.QuantizeOptions.WQuantType;
/// <summary>
/// Gets a value indicating whether get UseMixQuant flag.
/// </summary>
public bool UseMixQuant => Option.CompileOptions.QuantizeOptions.BindQuantMethod;
public bool UseMixQuant => CompileSession.CompileOptions.QuantizeOptions.BindQuantMethod;
/// <summary>
/// check the datatype is the quant type.

View File

@ -18,6 +18,14 @@ namespace Nncase.Transform;
public abstract class RewriteRule<TPattern> : IRewriteRule
where TPattern : Pattern
{
/// <summary>
/// Initializes a new instance of the <see cref="RewriteRule{TPattern}"/> class.
/// </summary>
public RewriteRule()
{
CompileSession = CompileSessionScope.GetCurrentThrowIfNull();
}
/// <summary>
/// Gets pattern.
/// </summary>
@ -25,12 +33,11 @@ public abstract class RewriteRule<TPattern> : IRewriteRule
IPattern IRewriteRule.Pattern => Pattern;
/// <inheritdoc/>
public bool IsMultiBranchSafe { get; init; }
/// <summary>
/// Gets compile session.
/// </summary>
protected CompileSession CompileSession { get; }
/// <inheritdoc/>
bool IRewriteRule.IsMultiBranchSafe() => IsMultiBranchSafe;
/// <inheritdoc/>
public abstract Expr? GetReplace(IMatchResult result, RunPassOptions options);
public abstract Expr? GetReplace(IMatchResult result, RunPassContext options);
}

View File

@ -7,92 +7,102 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
using Microsoft.Extensions.DependencyInjection;
using Nncase.TIR;
namespace Nncase.Transform;
public abstract class RulesPass : FunctionPass, IEnumerable<IRewriteRule>
/// <summary>
/// Rules addable.
/// </summary>
public interface IRulesAddable
{
private readonly List<IRewriteRule> _rules = new();
/// <summary>
/// Add the rewrite rule.
/// </summary>
/// <typeparam name="T">Rule type.</typeparam>
/// <param name="parameters">Rule's constructor parameters.</param>
/// <returns>Add result.</returns>
RulesPass.AddResult<T> Add<T>(params object[] parameters)
where T : class, IRewriteRule;
/// <summary>
/// Initializes a new instance of the <see cref="RulesPass"/> class.
/// Add the rewrite rule.
/// </summary>
/// <param name="name">Name.</param>
public RulesPass(string name)
: base(name)
{
}
/// <param name="ruleType">Rule type.</param>
/// <param name="parameters">Rule's constructor parameters.</param>
/// <returns>Add result.</returns>
RulesPass.AddResult<IRewriteRule> Add(Type ruleType, params object[] parameters);
}
/// <summary>
/// Pass contains rewrite rules.
/// </summary>
public abstract class RulesPass : FunctionPass, IRulesAddable
{
private readonly List<IRewriteRule> _rules = new();
/// <summary>
/// Gets rules.
/// </summary>
public IReadOnlyList<IRewriteRule> Rules => _rules;
/// <summary>
/// add the pattern rule.
/// </summary>
/// <param name="rule">Rule.</param>
public void Add(IRewriteRule rule) => _rules.Add(rule);
/// <summary>
/// add the pattern rules.
/// </summary>
/// <param name="rules">Rules.</param>
public void Add(params IRewriteRule[] rules) => _rules.AddRange(rules);
/// <summary>
/// <see cref="Add(IRewriteRule[])"/>.
/// </summary>
/// <param name="rules">Rules.</param>
public void Add(IEnumerable<IRewriteRule> rules) => _rules.AddRange(rules);
/// <inheritdoc/>
public AddResult<T> Add<T>(params object[] parameters)
where T : class, IRewriteRule
{
using var scope = new CompileSessionScope(CompileSession);
var rule = ActivatorUtilities.CreateInstance<T>(CompileSession, parameters);
_rules.Add(rule);
return new(this, rule);
}
/// <inheritdoc/>
public IEnumerator<IRewriteRule> GetEnumerator()
public AddResult<IRewriteRule> Add(Type ruleType, params object[] parameters)
{
return _rules.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
using var scope = new CompileSessionScope(CompileSession);
var rule = (IRewriteRule)ActivatorUtilities.CreateInstance(CompileSession, ruleType, parameters);
_rules.Add(rule);
return new(this, rule);
}
/// <summary>
/// the callback function you can custom process func with run pass options.
/// Add rule result.
/// </summary>
/// <param name="callable"> func without run pass.</param>
/// <param name="options">Options.</param>
protected override void OnPassStart(BaseFunction callable, RunPassOptions options)
/// <typeparam name="T">Pass type.</typeparam>
public struct AddResult<T> : IRulesAddable
where T : class, IRewriteRule
{
switch (options.DumpLevel)
private readonly RulesPass _rulesPass;
internal AddResult(RulesPass rulesPass, T rule)
{
case >= 2:
CompilerServices.DumpIR((Expr)callable, "Start", options.DumpDir);
break;
case >= 1:
break;
default:
break;
_rulesPass = rulesPass;
Rule = rule;
}
}
/// <summary>
/// the callback function you can custom process func with run pass options.
/// </summary>
/// <param name="callable"> func with rewrited. </param>
/// <param name="options">Options.</param>
protected override void OnPassEnd(BaseFunction callable, RunPassOptions options)
{
switch (options.DumpLevel)
/// <summary>
/// Gets rule.
/// </summary>
public T Rule { get; }
/// <inheritdoc/>
public AddResult<TRule> Add<TRule>(params object[] parameters)
where TRule : class, IRewriteRule => _rulesPass.Add<TRule>(parameters);
/// <inheritdoc/>
public AddResult<IRewriteRule> Add(Type ruleType, params object[] parameters)
=> _rulesPass.Add(ruleType, parameters);
/// <summary>
/// Configure rule.
/// </summary>
/// <param name="configureRule">Configure rule action.</param>
/// <returns>This add result.</returns>
public AddResult<T> Configure(Action<T> configureRule)
{
case >= 2:
CompilerServices.DumpIR((Expr)callable, "End", options.DumpDir);
break;
case >= 1:
break;
default:
break;
configureRule(Rule);
return this;
}
}
}

View File

@ -0,0 +1,35 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.Diagnostics;
using Nncase.PatternMatch;
namespace Nncase.Transform;
/// <summary>
/// Options for running pass.
/// </summary>
public sealed record RunPassContext
{
/// <summary>
/// Gets or sets pass index in a <see cref="PassManager"/>.
/// </summary>
public int Index { get; set; }
/// <summary>
/// Gets or sets a value indicating whether control rewrite once or not.
/// Default is false.
/// </summary>
public bool RewriteOnce { get; set; }
/// <summary>
/// Gets or sets the match option.
/// </summary>
public MatchOptions MatchOptions { get; set; } = new MatchOptions();
}

View File

@ -1,169 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.PatternMatch;
namespace Nncase.Transform
{
/// <summary>
/// Options for running pass.
/// </summary>
public class RunPassOptions
{
/// <summary>
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
/// </summary>
/// <param name="compileOptions"></param>
public RunPassOptions(CompileOptions compileOptions)
{
Target = CompilerServices.GetTarget(compileOptions.Target);
DumpLevel = compileOptions.DumpLevel;
DumpDir = compileOptions.DumpDir;
CompileOptions = compileOptions;
PassName = string.Empty;
}
/// <summary>
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
/// parameterless ctor.
/// </summary>
public RunPassOptions(ITarget target)
{
Target = target;
DumpLevel = CompilerServices.CompileOptions.DumpLevel;
DumpDir = CompilerServices.CompileOptions.DumpDir;
CompileOptions = CompilerServices.CompileOptions;
PassName = string.Empty;
}
/// <summary>
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
/// constructor.
/// </summary>
/// <param name="target"> target device. </param>
/// <param name="dumpLevel"> int level. </param>
/// <param name="dumpDir"> dir. </param>
public RunPassOptions(ITarget target, int dumpLevel, string dumpDir)
: this(target, dumpLevel, dumpDir, CompilerServices.CompileOptions)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
/// create the run pass options.
/// </summary>
/// <param name="target"></param>
/// <param name="dumpLevel"></param>
/// <param name="dumpDir"></param>
/// <param name="options"></param>
public RunPassOptions(ITarget target, int dumpLevel, string dumpDir, CompileOptions options)
{
Target = target;
DumpLevel = dumpLevel;
DumpDir = dumpDir;
PassName = string.Empty;
RewriteOnce = false;
CompileOptions = options;
}
/// <summary>
/// Initializes a new instance of the <see cref="RunPassOptions"/> class.
/// copy construct.
/// </summary>
/// <param name="other"></param>
public RunPassOptions(RunPassOptions other)
{
Target = other.Target;
DumpLevel = other.DumpLevel;
DumpDir = other.DumpDir;
PassName = other.PassName;
RewriteOnce = other.RewriteOnce;
CompileOptions = other.CompileOptions;
MatchOptions = other.MatchOptions;
}
/// <summary>
/// Gets the invalid pass.
/// </summary>
public static RunPassOptions Invalid => new RunPassOptions(null!, -1, string.Empty);
/// <summary>
/// Gets target.
/// </summary>
public ITarget Target { get; }
/// <summary>
/// Gets dump level 0 = do nothing
/// Dump level 1 = print to std output
/// Dump level 2 = print dump to file.
/// </summary>
public int DumpLevel { get; private set; }
/// <summary>
/// Gets dump dir.
/// </summary>
public string DumpDir { get; private set; }
/// <summary>
/// Gets current pass name.
/// </summary>
public string PassName { get; private set; }
/// <summary>
/// Gets a value indicating whether control rewrite once or not.
/// Default is false.
/// </summary>
public bool RewriteOnce { get; private set; }
/// <summary>
/// Gets or sets get the compile options.
/// </summary>
public CompileOptions CompileOptions { get; set; }
/// <summary>
/// Gets or sets the match option.
/// </summary>
public MatchOptions MatchOptions { get; set; } = new MatchOptions();
/// <summary>
/// set the pass name.
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
public RunPassOptions SetPassName(string name) => new(Target, DumpLevel, DumpDir, CompileOptions) { PassName = name, MatchOptions = MatchOptions };
/// <summary>
/// set the dumpDir.
/// </summary>
/// <param name="path"></param>
/// <returns></returns>
public RunPassOptions SetDumpDir(string path) => new(Target, DumpLevel, path, CompileOptions) { PassName = PassName, MatchOptions = MatchOptions };
/// <summary>
/// set the dump level.
/// </summary>
/// <param name="dumpLevel"></param>
/// <returns></returns>
public RunPassOptions SetDumpLevel(int dumpLevel) => new(Target, dumpLevel, DumpDir, CompileOptions) { PassName = PassName, MatchOptions = MatchOptions };
/// <summary>
/// set the RewriteOnce.
/// </summary>
/// <param name="once"></param>
/// <returns></returns>
public RunPassOptions SetRewriteOnce(bool once) => new(Target, DumpLevel, DumpDir, CompileOptions) { PassName = PassName, RewriteOnce = once, MatchOptions = MatchOptions };
/// <summary>
/// indent the dumpDir.
/// </summary>
/// <param name="path"></param>
/// <returns></returns>
public RunPassOptions IndentDir(string path) => new(Target, DumpLevel, Path.Combine(DumpDir, path), CompileOptions) { PassName = PassName, MatchOptions = MatchOptions };
}
}

View File

@ -1,145 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Nncase.IR;
using Nncase.TIR;
namespace Nncase.Transform
{
/// <summary>
/// TIR Mutator Pass.
/// NOTE only apply on prim func
/// Because of we will mutate the expression multiple times, so use MutatorCreator create the new mutator.
/// </summary>
public class PrimFuncPass : FunctionPass, IEnumerable<Func<ExprMutator>>
{
/// <summary>
/// Save rules.
/// </summary>
public readonly List<Func<ExprMutator>> MutatorCreators = new();
/// <summary>
/// Initializes a new instance of the <see cref="PrimFuncPass"/> class.
/// </summary>
/// <param name="name">Name.</param>
public PrimFuncPass(string name)
: base(name)
{
}
/// <inheritdoc/>
public IEnumerator<Func<ExprMutator>> GetEnumerator()
{
return ((IEnumerable<Func<ExprMutator>>)MutatorCreators).GetEnumerator();
}
/// <inheritdoc/>
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)MutatorCreators).GetEnumerator();
}
/// <summary>
/// add the mutator.
/// </summary>
/// <param name="mutator"></param>
public void Add(Func<ExprMutator> mutator)
{
MutatorCreators.Add(mutator);
}
/// <inheritdoc/>
protected override Task<BaseFunction> RunCoreAsync(BaseFunction callable, RunPassOptions options)
{
if (callable is not PrimFunction)
{
return Task.FromResult(callable);
}
var post = callable;
var last = post;
int count = 0;
var typeinfer_ret = true;
do
{
bool isMutated = false;
foreach (var creator in MutatorCreators)
{
var mutator = creator();
string mutator_name = mutator.GetType().Name;
last = post;
post = (BaseFunction)mutator.Visit(last);
isMutated = mutator.IsMutated;
if (isMutated)
{
typeinfer_ret = CompilerServices.InferenceType(post);
OnMutated(post, $"{count++}_{mutator.GetType().Name}", options);
if (!typeinfer_ret)
{
throw new InvalidOperationException($"{Name}: After Run Mutator {count - 1}_{mutator_name} , The Type Inference Failed!");
}
break;
}
}
if (!isMutated)
{
break;
}
}
while (true);
return Task.FromResult(post);
}
/// <summary>
/// the callback function you can custom process func with run pass options.
/// </summary>
/// <param name="callable"> func without run pass.</param>
/// <param name="options"></param>
protected override void OnPassStart(BaseFunction callable, RunPassOptions options)
{
if (callable is not PrimFunction)
{
return;
}
base.OnPassStart(callable, options);
}
/// <summary>
/// the callback function you can custom process func with run pass options.
/// </summary>
/// <param name="callable"> func with rewrited. </param>
/// <param name="options"></param>
protected override void OnPassEnd(BaseFunction callable, RunPassOptions options)
{
if (callable is not PrimFunction)
{
return;
}
base.OnPassEnd(callable, options);
}
private void OnMutated(BaseFunction callable, string prefix, RunPassOptions options)
{
switch (options.DumpLevel)
{
case >= 2:
CompilerServices.DumpIR((Expr)callable, prefix, options.DumpDir, false);
break;
case >= 1:
break;
default:
break;
}
}
}
}

View File

@ -68,17 +68,6 @@ public static class ValueDumper
DumpTensors(tensorValue.Select(x => x.AsTensor()).ToArray(), sr);
}
}
public static string GetMaybeDumpDir(string dir)
{
var root = Path.Join(CompilerServices.CompileOptions.DumpDir, dir);
if (!Directory.Exists(root))
{
Directory.CreateDirectory(root);
}
return root;
}
}
public static class DumpUtility
@ -169,93 +158,3 @@ public static class DumpUtility
}
}
}
public class DumpManager
{
public static bool Append;
public static int Count = 1;
public static string Dir;
public static bool OpenDump { get; private set; }
public string CountStr => Count.ToString();
public static void RunWithDump(string dir, Action f)
{
RunWithDump<int>(dir, () =>
{
f();
// discard return value
return -1;
});
}
public static T RunWithDump<T>(string dir, Func<T> f)
{
Dir = dir;
Count = 1;
OpenDump = true;
Append = false;
var result = f();
OpenDump = false;
return result;
}
public string GetMaybeDumpDir()
{
return ValueDumper.GetMaybeDumpDir(Dir);
}
protected void UpdateOrder(string root, string target, Shape shape)
{
using (var order = new StreamWriter(Path.Join(root, "!out_shape_list"), Append))
{
order.WriteLine($"{target}: {DumpUtility.SerializeShape(shape)}");
}
}
protected void DumpCallParam(string target, ParameterInfo info, Action<StreamWriter> f)
{
var path = Path.Join(GetMaybeDumpDir(), $"{CountStr}${target}${info.Name}");
using (var sr = new StreamWriter(path))
{
f(sr);
}
}
protected void DumpCall(string target, Shape shape, Action<StreamWriter> f)
{
var path = Path.Join(GetMaybeDumpDir(), $"{CountStr}${target}");
using (var sr = new StreamWriter(path))
{
f(sr);
}
UpdateOrder(GetMaybeDumpDir(), target, shape);
Append = true;
++Count;
}
}
public class Counter
{
private int _count;
public Counter(int count = 0)
{
_count = count;
}
public T Run<T>(Func<int, T> f)
{
return f(_count++);
}
public void Run(Action<int> f)
{
f(_count++);
}
}

View File

@ -44,7 +44,7 @@ public class ReplaceUtility
/// usage:
/// Call(FakeXXX, input, otherArg1, ...)
/// newInput => Call(op, newInput, otherArg1, ...)
/// it's always used for Fake to NoFake Rule with IsWildcardCall.
/// it's always used for Fake to NoFake Pass with IsWildcardCall.
/// </summary>
/// <param name="call"></param>
/// <param name="op"></param>

View File

@ -0,0 +1,19 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using DryIoc;
using Nncase.Diagnostics;
using Nncase.Hosting;
namespace Nncase.Diagnostics;
/// <summary>
/// Diagnostics module.
/// </summary>
internal class DiagnosticsModule : IApplicationPart
{
public void ConfigureServices(IRegistrator registrator)
{
registrator.Register<IDumpperFactory, DumpperFactory>(reuse: Reuse.Scoped);
}
}

View File

@ -0,0 +1,85 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase.Diagnostics;
internal sealed class Dumpper : IDumpper
{
private readonly DumpFlags _dumpFlags;
private readonly string _dumpDirectory;
public Dumpper(DumpFlags dumpFlags, string dumpDirectory)
{
_dumpFlags = dumpFlags;
_dumpDirectory = dumpDirectory;
}
public string Directory => _dumpDirectory;
public IDumpper CreateSubDummper(string subDirectory)
{
return new Dumpper(_dumpFlags, Path.Join(_dumpDirectory, subDirectory));
}
public void DumpIR(Expr expr, string prefix, string? reletivePath = null)
{
var path = Path.Join(_dumpDirectory, reletivePath);
CompilerServices.DumpIR(expr, prefix, EnsureWritable(path));
}
public void DumpDotIR(Expr expr, string prefix, string? reletivePath = null)
{
var path = Path.Join(_dumpDirectory, reletivePath);
CompilerServices.DumpDotIR(expr, prefix, EnsureWritable(path));
}
public void DumpModule(IRModule module, string? reletivePath = null)
{
foreach (var func in module.Functions)
{
DumpIR(func, string.Empty, reletivePath);
}
}
public bool IsEnabled(DumpFlags dumpFlags)
{
return _dumpFlags.HasFlag(dumpFlags);
}
public Stream OpenFile(string reletivePath, FileMode fileMode)
{
var path = Path.Join(_dumpDirectory, reletivePath);
return File.Open(EnsureWritable(path), fileMode);
}
private static string EnsureWritable(string path)
{
var directory = Path.GetDirectoryName(path) ?? throw new ArgumentException($"Invalid path: {path}");
System.IO.Directory.CreateDirectory(directory);
return path;
}
}
internal sealed class DumpperFactory : IDumpperFactory
{
private readonly CompileOptions _compileOptions;
public DumpperFactory(CompileOptions compileOptions)
{
_compileOptions = compileOptions;
}
public IDumpper Root => new Dumpper(_compileOptions.DumpFlags, _compileOptions.DumpDir);
public IDumpper CreateDummper(string relativePath)
{
return new Dumpper(_compileOptions.DumpFlags, Path.Join(_compileOptions.DumpDir, relativePath));
}
}

View File

@ -0,0 +1,32 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using Nncase.Converters;
using Nncase.Diagnostics;
using Nncase.Hosting;
using Nncase.Targets;
namespace Nncase;
/// <summary>
/// Diagnostics application part extensions.
/// </summary>
public static class DiagnosticsApplicationPart
{
/// <summary>
/// Add diagnostics assembly.
/// </summary>
/// <param name="registrator">Service registrator.</param>
/// <returns>Configured service registrator.</returns>
public static IRegistrator AddDiagnostics(this IRegistrator registrator)
{
return registrator.RegisterModule<DiagnosticsModule>();
}
}

View File

@ -0,0 +1,15 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<RootNamespace>Nncase</RootNamespace>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Nncase.Core\Nncase.Core.csproj" />
</ItemGroup>
</Project>

View File

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using System.Linq;
using GiGraph.Dot.Entities.Clusters;
using GiGraph.Dot.Entities.Graphs;
@ -23,12 +24,12 @@ namespace Nncase.Transform;
public partial class EGraphPrinter
{
internal static DotGraph DumpEgraphAsDot(EGraph eGraph, CostModel.EGraphCostModel costModel, EClass entry, string file)
internal static DotGraph DumpEgraphAsDot(EGraph eGraph, CostModel.EGraphCostModel costModel, EClass entry, Stream file)
{
var printer = new EGraphPrinter(eGraph);
printer.ConvertEGraphAsDot();
printer.AttachEGraphCost(costModel, entry);
return printer.SaveToFile(file);
return printer.SaveToStream(file);
}
private DotGraph AttachEGraphCost(CostModel.EGraphCostModel costModel, EClass entry)

View File

@ -7,6 +7,10 @@ using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using DryIoc;
using Nncase.Hosting;
using Nncase.PatternMatch;
using Nncase.Transform;
namespace Nncase;
@ -18,11 +22,11 @@ public static class EGraphApplicationPart
/// <summary>
/// Add egraph assembly.
/// </summary>
/// <param name="assemblies">Assembly collection.</param>
/// <returns>Updated assembly collection.</returns>
public static IList<Assembly> AddEGraph(this IList<Assembly> assemblies)
/// <param name="registrator">Service registrator.</param>
/// <returns>Configured service registrator.</returns>
public static IRegistrator AddEGraph(this IRegistrator registrator)
{
assemblies.Add(typeof(EGraphApplicationPart).Assembly);
return assemblies;
return registrator.RegisterModule<PatternMatchModule>()
.RegisterModule<TransformModule>();
}
}

View File

@ -1,18 +1,19 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using DryIoc;
using Nncase.Hosting;
using Nncase.Transform;
namespace Nncase.PatternMatch;
/// <summary>
/// PatternMatch module.
/// </summary>
public class PatternMatchModule : Module
internal class PatternMatchModule : IApplicationPart
{
/// <inheritdoc/>
protected override void Load(ContainerBuilder builder)
public void ConfigureServices(IRegistrator registrator)
{
builder.RegisterType<EGraphMatchProvider>().AsImplementedInterfaces().SingleInstance();
registrator.RegisterManyInterface<EGraphMatchProvider>(reuse: Reuse.Singleton);
}
}

View File

@ -46,24 +46,16 @@ public sealed partial class EGraph : IEGraph
Add(expr);
}
/// <summary>
/// Gets eclasses.
/// </summary>
/// <inheritdoc/>
public IEnumerable<EClass> Classes => _classes;
/// <inheritdoc/>
public IEnumerable<ENode> Nodes => _nodes.Keys;
/// <summary>
/// Gets version.
/// </summary>
/// <inheritdoc/>
public int Version => _version;
/// <summary>
/// Add expr, get the eclass id.
/// </summary>
/// <param name="expr">Expression.</param>
/// <returns>Eclass of this node.</returns>
/// <inheritdoc/>
public EClass Add(Expr expr)
{
if (expr.CheckedType is null)
@ -75,22 +67,13 @@ public sealed partial class EGraph : IEGraph
return converter.Visit(expr);
}
/// <summary>
/// Find eclass of enode.
/// </summary>
/// <param name="node">ENode.</param>
/// <returns>EClass.</returns>
/// <inheritdoc/>
public EClass Find(ENode node)
{
return _nodes[node].Class;
}
/// <summary>
/// Union two equal Eclass.
/// </summary>
/// <param name="classA">class a.</param>
/// <param name="classB">class b.</param>
/// <returns>If version changed.</returns>
/// <inheritdoc/>
public bool Union(EClass classA, EClass classB)
{
classA = classA.Find();
@ -123,9 +106,7 @@ public sealed partial class EGraph : IEGraph
return true;
}
/// <summary>
/// After merge, we use rebuild get new dep information.
/// </summary>
/// <inheritdoc/>
public void Rebuild()
{
while (_worklist.Count > 0)
@ -190,7 +171,7 @@ public sealed partial class EGraph : IEGraph
_nodes.Remove(enode);
originalClass.RemoveNode(enode);
// 2. Update node's children
// 2. Replace node's children
var newNode = enode.Canonicalize();
if (_nodes.TryGetValue(newNode, out var existingEntry))

View File

@ -6,7 +6,9 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using LanguageExt.ClassInstances;
using Nncase.CostModel;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.PatternMatch.F.Math;
@ -25,9 +27,8 @@ public static class EGraphExtractExtensions
/// <param name="eGraph">eGraph.</param>
/// <param name="root">Root eclass.</param>
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
/// <param name="options">Options.</param>
/// <returns>Extracted root expression.</returns>
public static Expr Extract(this EGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, RunPassOptions options)
public static Expr Extract(this EGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator)
{
// 1. set the all expr checked shape
foreach (var eclass in eGraph.Classes)
@ -43,12 +44,13 @@ public static class EGraphExtractExtensions
// 2. start the cost evaluator
var costModel = new EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator).Evaluate();
if (options.DumpLevel > 2)
if (DumpScope.Current.IsEnabled(DumpFlags.EGraphCost))
{
EGraphPrinter.DumpEgraphAsDot(eGraph, costModel, root.Find(), Path.Combine(options.DumpDir, "Costs", $"V{eGraph.Version}"));
using var fs = DumpScope.Current.OpenFile(Path.Combine("Costs", $"V{eGraph.Version}.dot"));
EGraphPrinter.DumpEgraphAsDot(eGraph, costModel, root.Find(), fs);
}
return new EGraphExtractor(costModel, options).Extract(root.Find());
return new EGraphExtractor(costModel).Extract(root.Find());
}
/// <summary>
@ -93,19 +95,29 @@ public static class EGraphExtractExtensions
internal class EGraphExtractor
{
private readonly EGraphCostModel _costModel;
private readonly RunPassOptions _options;
private readonly Dictionary<EClass, Expr> _eclassMemo = new();
private readonly Dictionary<EClass, Expr> _markerEclassMemo = new();
private StreamWriter? _dumpWriter;
public EGraphExtractor(EGraphCostModel costModel, RunPassOptions options)
public EGraphExtractor(EGraphCostModel costModel)
{
_costModel = costModel;
_options = options;
}
public Expr Extract(EClass root)
{
Visit(root);
_dumpWriter = DumpScope.Current.IsEnabled(DumpFlags.EGraphCost)
? new StreamWriter(DumpScope.Current.OpenFile($"{nameof(EGraphExtractor)}_Class_{root.Id}.txt"))
: null;
try
{
Visit(root);
}
finally
{
_dumpWriter?.Dispose();
}
return _eclassMemo[root];
}
@ -219,17 +231,17 @@ internal class EGraphExtractor
var parameters = children.Skip(1);
// for mix quant debug.
if (call.EnodeQuantConfigWithCosine != null && _options.DumpLevel > 3)
if (call.EnodeQuantConfigWithCosine != null && _dumpWriter != null)
{
Console.WriteLine(call + " " + call.CheckedType);
_dumpWriter.WriteLine(call + " " + call.CheckedType);
for (int i = 0; i < call.EnodeQuantConfigWithCosine.Count; i++)
{
for (int j = 0; j < call.EnodeQuantConfigWithCosine[i].Item1.Count; j++)
{
Console.Write(call.EnodeQuantConfigWithCosine[i].Item1[j] + " ");
_dumpWriter.Write(call.EnodeQuantConfigWithCosine[i].Item1[j] + " ");
}
Console.WriteLine(call.EnodeQuantConfigWithCosine[i].Item3);
_dumpWriter.WriteLine(call.EnodeQuantConfigWithCosine[i].Item3);
}
}

View File

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using GiGraph.Dot.Entities.Clusters;
using GiGraph.Dot.Entities.Graphs;
using GiGraph.Dot.Entities.Nodes;
@ -23,12 +24,12 @@ public partial class EGraphPrinter
{
private readonly Color[] _knownColors = new Color[] { Color.MediumAquamarine, Color.MediumBlue, Color.MediumOrchid, Color.MediumPurple, Color.MediumSeaGreen, Color.MediumSlateBlue, Color.MediumSpringGreen, Color.Maroon, Color.MediumTurquoise, Color.MidnightBlue, Color.MintCream, Color.MistyRose, Color.Moccasin, Color.NavajoWhite, Color.Navy, Color.OldLace, Color.MediumVioletRed, Color.Magenta, Color.Linen, Color.LimeGreen, Color.LavenderBlush, Color.LawnGreen, Color.LemonChiffon, Color.LightBlue, Color.LightCoral, Color.LightCyan, Color.LightGoldenrodYellow, Color.LightGray, Color.LightGreen, Color.LightPink, Color.LightSalmon, Color.LightSeaGreen, Color.LightSkyBlue, Color.LightSlateGray, Color.LightSteelBlue, Color.LightYellow, Color.Lime, Color.Olive, Color.OliveDrab, Color.Orange, Color.OrangeRed, Color.Silver, Color.SkyBlue, Color.SlateBlue, Color.SlateGray, Color.Snow, Color.SpringGreen, Color.SteelBlue, Color.Tan, Color.Teal, Color.Thistle, Color.Tomato, Color.Transparent, Color.Turquoise, Color.Violet, Color.Wheat, Color.White, Color.WhiteSmoke, Color.Sienna, Color.Lavender, Color.SeaShell, Color.SandyBrown, Color.Orchid, Color.PaleGoldenrod, Color.PaleGreen, Color.PaleTurquoise, Color.PaleVioletRed, Color.PapayaWhip, Color.PeachPuff, Color.Peru, Color.Pink, Color.Plum, Color.PowderBlue, Color.Purple, Color.Red, Color.RosyBrown, Color.RoyalBlue, Color.SaddleBrown, Color.Salmon, Color.SeaGreen, Color.Yellow, Color.Khaki, Color.Cyan, Color.DarkMagenta, Color.DarkKhaki, Color.DarkGreen, Color.DarkGray, Color.DarkGoldenrod, Color.DarkCyan, Color.DarkBlue, Color.Ivory, Color.Crimson, Color.Cornsilk, Color.CornflowerBlue, Color.Coral, Color.Chocolate, Color.DarkOliveGreen, Color.Chartreuse, Color.BurlyWood, Color.Brown, Color.BlueViolet, Color.Blue, Color.BlanchedAlmond, Color.Black, Color.Bisque, Color.Beige, Color.Azure, Color.Aquamarine, Color.Aqua, Color.AntiqueWhite, Color.AliceBlue, Color.CadetBlue, Color.DarkOrange, Color.YellowGreen, Color.DarkRed, Color.Indigo, Color.IndianRed, Color.DarkOrchid, Color.Honeydew, Color.GreenYellow, Color.Green, Color.Gray, Color.Goldenrod, Color.Gold, Color.GhostWhite, Color.Gainsboro, Color.Fuchsia, Color.ForestGreen, Color.HotPink, Color.Firebrick, Color.FloralWhite, Color.DodgerBlue, Color.DimGray, Color.DeepSkyBlue, Color.DeepPink, Color.DarkViolet, Color.DarkTurquoise, Color.DarkSlateGray, Color.DarkSlateBlue, Color.DarkSeaGreen, Color.DarkSalmon, };
public static DotGraph DumpEgraphAsDot(EGraph eGraph, IReadOnlyList<IMatchResult>? matches, string file)
public static DotGraph DumpEgraphAsDot(IEGraph eGraph, IReadOnlyList<IMatchResult>? matches, Stream output)
{
var printer = new EGraphPrinter(eGraph);
printer.ConvertEGraphAsDot();
printer.AttachEGraphMatches(matches);
return printer.SaveToFile(file);
return printer.SaveToStream(output);
}
public DotGraph AttachEGraphMatches(IReadOnlyList<IMatchResult>? matches)

View File

@ -7,6 +7,8 @@ using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
@ -17,90 +19,75 @@ namespace Nncase.Transform;
/// </summary>
public class EGraphPass : RulesPass
{
private readonly List<IRewriteRule> _rules = new();
private readonly IEGraphRewriteProvider _rewriteProvider;
private readonly Evaluator.IBaseFuncCostEvaluator? _baseFuncCostEvaluator;
/// <summary>
/// Initializes a new instance of the <see cref="EGraphPass"/> class.
/// </summary>
/// <param name="name">Name.</param>
public EGraphPass(string name)
: base(name)
{
_baseFuncCostEvaluator = null;
}
/// <summary>
/// Initializes a new instance of the <see cref="EGraphPass"/> class.
/// </summary>
/// <param name="name">Pass Name.</param>
/// <param name="baseFuncCostEvaluator">Extenal cost evaluator.</param>
public EGraphPass(string name, Evaluator.IBaseFuncCostEvaluator baseFuncCostEvaluator)
: base(name)
public EGraphPass(Evaluator.IBaseFuncCostEvaluator? baseFuncCostEvaluator = null)
{
_rewriteProvider = CompileSession.GetRequiredService<IEGraphRewriteProvider>();
_baseFuncCostEvaluator = baseFuncCostEvaluator;
}
/// <inheritdoc/>
protected override async Task<BaseFunction> RunCoreAsync(BaseFunction function, RunPassOptions options)
protected override async Task<BaseFunction> RunCoreAsync(BaseFunction function, RunPassContext context)
{
var graph = new EGraph();
var root = graph.Add(function);
EGraphRewriter.Rewrite(graph, Rules, options);
OnPostRewriteStart(graph, options);
await OnPostRewrite(graph, options);
OnPostRewriteEnd(graph, options);
var post = graph.Extract(root, _baseFuncCostEvaluator, options);
_rewriteProvider.ERewrite(graph, Rules, context);
await OnPostRewriteStartAsync(graph, context);
await OnPostRewriteAsync(graph, context);
await OnPostRewriteEndAsync(graph, context);
var post = graph.Extract(root, _baseFuncCostEvaluator);
CompilerServices.InferenceType(post);
return (BaseFunction)post;
}
protected virtual Task OnPostRewrite(EGraph graph, RunPassOptions options)
/// <summary>
/// The callback after egraph rewrite.
/// </summary>
/// <param name="eGraph">EGraph after rewrite.</param>
/// <param name="context">Run pass context.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
protected virtual Task OnPostRewriteAsync(EGraph eGraph, RunPassContext context)
{
return Task.CompletedTask;
}
/// <summary>
/// the callback function you can custom process func with run pass options.
/// The callback function you can custom process func with run pass options.
/// </summary>
/// <param name="eGraph"> egraph without run pass.</param>
/// <param name="options">Options.</param>
protected virtual void OnPostRewriteStart(EGraph eGraph, RunPassOptions options)
/// <param name="eGraph">EGraph after rewrite.</param>
/// <param name="context">Run pass context.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
protected virtual Task OnPostRewriteStartAsync(EGraph eGraph, RunPassContext context)
{
switch (options.DumpLevel)
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
{
case >= 4:
EGraphPrinter.DumpEgraphAsDot(
eGraph,
null,
Path.Combine(options.DumpDir, options.PassName, "PostRewriteStart", $"V{eGraph.Version}"));
break;
case >= 1:
break;
default:
break;
using var fs = DumpScope.Current.OpenFile(Path.Combine("PostRewriteStart", $"V{eGraph.Version}.dot"));
EGraphPrinter.DumpEgraphAsDot(eGraph, null, fs);
}
return Task.CompletedTask;
}
/// <summary>
/// the callback function you can custom process func with run pass options.
/// The callback function you can custom process func with run pass options.
/// </summary>
/// <param name="eGraph"> egraph with rewrited. </param>
/// <param name="options">Options.</param>
protected virtual void OnPostRewriteEnd(EGraph eGraph, RunPassOptions options)
/// <param name="eGraph">EGraph after post rewrite.</param>
/// <param name="context">Run pass context.</param>
/// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
protected virtual Task OnPostRewriteEndAsync(EGraph eGraph, RunPassContext context)
{
switch (options.DumpLevel)
if (DumpScope.Current.IsEnabled(DumpFlags.PassIR))
{
case >= 4:
EGraphPrinter.DumpEgraphAsDot(
eGraph,
null,
Path.Combine(options.DumpDir, options.PassName, "PostRewriteEnd", $"V{eGraph.Version}"));
break;
case >= 1:
break;
default:
break;
using var fs = DumpScope.Current.OpenFile(Path.Combine("PostRewriteEnd", $"V{eGraph.Version}.dot"));
EGraphPrinter.DumpEgraphAsDot(eGraph, null, fs);
}
return Task.CompletedTask;
}
}

View File

@ -44,7 +44,7 @@ public partial class EGraphPrinter
private readonly Dictionary<EClass, string> _opMaps = new();
private readonly EGraph _eGraph;
private readonly IEGraph _eGraph;
private readonly DotDumpVisitor _visitor = new DotDumpVisitor();
@ -55,7 +55,7 @@ public partial class EGraphPrinter
/// ctor for egraph.
/// </summary>
/// <param name="egraph"></param>
public EGraphPrinter(EGraph egraph)
public EGraphPrinter(IEGraph egraph)
{
_idCounter = 0;
DotGraph = new(directed: true);
@ -74,13 +74,26 @@ public partial class EGraphPrinter
}
}
/// <summary>
/// dump egraph as dot graph.
/// </summary>
/// <param name="eGraph">egraph.</param>
/// <param name="output">Output stream.</param>
/// <returns>Converted Graph.</returns>
public static DotGraph DumpEgraphAsDot(IEGraph eGraph, Stream output)
{
var printer = new EGraphPrinter(eGraph);
printer.ConvertEGraphAsDot();
return printer.SaveToStream(output);
}
/// <summary>
/// dump egraph as dot graph.
/// </summary>
/// <param name="eGraph">egraph.</param>
/// <param name="file">path.</param>
/// <returns>Converted Graph.</returns>
public static DotGraph DumpEgraphAsDot(EGraph eGraph, string file)
public static DotGraph DumpEgraphAsDot(IEGraph eGraph, string file)
{
var printer = new EGraphPrinter(eGraph);
printer.ConvertEGraphAsDot();
@ -188,10 +201,21 @@ public partial class EGraphPrinter
return DotGraph;
}
/// <summary>
/// Save the DotGraph into stream.
/// </summary>
/// <param name="output">Output stream.</param>
/// <returns>this dot graph.</returns>
public DotGraph SaveToStream(Stream output)
{
DotGraph.Build(new StreamWriter(output, leaveOpen: true));
return DotGraph;
}
/// <summary>
/// Save the DotGraph into file.
/// </summary>
/// <param name="file">file path.</param>
/// <param name="file">Output file.</param>
/// <returns>this dot graph.</returns>
public DotGraph SaveToFile(string file)
{

View File

@ -1,90 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Nncase.IR;
using Nncase.PatternMatch;
namespace Nncase.Transform;
internal static class EGraphRewriter
{
/// <summary>
/// Run egraph rewrite.
/// </summary>
/// <returns></returns>
public static EGraph Rewrite(EGraph eGraph, IEnumerable<IRewriteRule> rules, RunPassOptions options)
{
var matches = new List<(IRewriteRule, IReadOnlyList<IMatchResult>)> { };
var last_version = eGraph.Version;
int count = 0;
while (true)
{
foreach (var rule in rules)
{
if (EGraphMatcher.TryMatchRoot(eGraph.Nodes, rule.Pattern, out var results))
{
matches.Add((rule, results));
if (options.DumpLevel > 1 && results.Count != 0)
{
EGraphPrinter.DumpEgraphAsDot(
eGraph,
results,
Path.Combine(options.DumpDir, options.PassName, "Matches", $"V{eGraph.Version}_{count++}_{rule.GetType().Name}"));
}
}
}
foreach (var (rule, results) in matches)
{
var replacedExprs = (from result in results
let expr = rule.GetReplace(result, options)
where expr != null
select (eGraph.Find((ENode)result.Root), expr)).ToList();
foreach (var (oldEClass, newExpr) in replacedExprs)
{
if (!CompilerServices.InferenceType(newExpr))
{
CompilerServices.DumpIR(newExpr, "Replaced_Expr", options.DumpDir);
throw new InvalidOperationException("Can't Inference The Replace Expr Type!");
}
var newEClass = eGraph.Add(newExpr);
if (options.DumpLevel > 4)
{
Console.WriteLine($"Version {eGraph.Version} : Merge {{{oldEClass}}} to {{{newEClass}}}");
}
eGraph.Union(newEClass, oldEClass);
}
}
matches.Clear();
if (last_version == eGraph.Version)
{
break;
}
else
{
last_version = eGraph.Version;
}
eGraph.Rebuild();
if (options.DumpLevel > 1)
{
EGraphPrinter.DumpEgraphAsDot(
eGraph,
Path.Combine(options.DumpDir, options.PassName, "Rebuild", $"V{eGraph.Version}"));
}
}
return eGraph;
}
}

View File

@ -3,16 +3,29 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using Tensorflow;
namespace Nncase.Transform;
internal class EGraphRewriteProvider : IEGraphRewriteProvider
{
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassOptions options)
private readonly ILogger _logger;
public EGraphRewriteProvider(ILogger<EGraphRewriteProvider> logger)
{
_logger = logger;
}
public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext options)
{
if (expr.CheckedType is null)
{
@ -21,8 +34,73 @@ internal class EGraphRewriteProvider : IEGraphRewriteProvider
var graph = new EGraph();
var root = graph.Add(expr);
EGraphRewriter.Rewrite(graph, rules, options);
var post = graph.Extract(root, null, options);
ERewrite(graph, rules, options);
var post = graph.Extract(root, null);
return post;
}
public IEGraph ERewrite(IEGraph eGraph, IEnumerable<IRewriteRule> rules, RunPassContext context)
{
var matches = new List<(IRewriteRule, IReadOnlyList<IMatchResult>)> { };
var last_version = eGraph.Version;
int count = 0;
while (true)
{
foreach (var rule in rules)
{
if (EGraphMatcher.TryMatchRoot(eGraph.Nodes, rule.Pattern, out var results))
{
matches.Add((rule, results));
if (DumpScope.Current.IsEnabled(DumpFlags.Rewrite) && results.Count != 0)
{
using var fs = DumpScope.Current.OpenFile(Path.Combine("Matches", $"V{eGraph.Version}_{count++}_{rule.GetType().Name}.dot"));
EGraphPrinter.DumpEgraphAsDot(eGraph, results, fs);
}
}
}
foreach (var (rule, results) in matches)
{
var replacedExprs = (from result in results
let expr = rule.GetReplace(result, context)
where expr != null
select (eGraph.Find((ENode)result.Root), expr)).ToList();
foreach (var (oldEClass, newExpr) in replacedExprs)
{
var typeInferSuccess = CompilerServices.InferenceType(newExpr);
Trace.Assert(typeInferSuccess);
var newEClass = eGraph.Add(newExpr);
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.LogTrace("Version {Version} : Merge {OldClass} to {NewClass}", eGraph.Version, oldEClass, newEClass);
}
eGraph.Union(newEClass, oldEClass);
}
}
matches.Clear();
if (last_version == eGraph.Version)
{
break;
}
else
{
last_version = eGraph.Version;
}
eGraph.Rebuild();
if (DumpScope.Current.IsEnabled(DumpFlags.Rewrite))
{
using var fs = DumpScope.Current.OpenFile(Path.Combine("Rebuild", $"V{eGraph.Version}.dot"));
EGraphPrinter.DumpEgraphAsDot(eGraph, fs);
}
}
return eGraph;
}
}

View File

@ -1,18 +1,18 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using DryIoc;
using Nncase.Hosting;
namespace Nncase.Transform;
/// <summary>
/// Transform module.
/// </summary>
public class TransformModule : Module
internal class TransformModule : IApplicationPart
{
/// <inheritdoc/>
protected override void Load(ContainerBuilder builder)
public void ConfigureServices(IRegistrator registrator)
{
builder.RegisterType<EGraphRewriteProvider>().AsImplementedInterfaces().SingleInstance();
registrator.Register<IEGraphRewriteProvider, EGraphRewriteProvider>(reuse: Reuse.ScopedOrSingleton);
}
}

View File

@ -1,7 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using Nncase.IR;
using Nncase.IR.Buffer;

View File

@ -1,7 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using Nncase.IR;
using Nncase.IR.Buffer;

View File

@ -1,23 +1,23 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using DryIoc;
using Nncase.Evaluator.Tensors;
using Nncase.Hosting;
namespace Nncase.Evaluator.Buffer;
/// <summary>
/// Buffer module.
/// </summary>
public class BufferModule : Module
internal class BufferModule : IApplicationPart
{
/// <inheritdoc/>
protected override void Load(ContainerBuilder builder)
public void ConfigureServices(IRegistrator registrator)
{
builder.RegisterType<DDrOfEvaluator>().AsImplementedInterfaces();
builder.RegisterType<BaseMentOfEvaluator>().AsImplementedInterfaces();
builder.RegisterType<StrideOfEvaluator>().AsImplementedInterfaces();
builder.RegisterType<AllocateEvaluator>().AsImplementedInterfaces();
builder.RegisterType<UninitializedEvaluator>().AsImplementedInterfaces();
registrator.RegisterManyInterface<DDrOfEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<BaseMentOfEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<StrideOfEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<AllocateEvaluator>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<UninitializedEvaluator>(reuse: Reuse.Singleton);
}
}

View File

@ -1,7 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Autofac;
using Nncase.IR;
using Nncase.IR.Buffer;

Some files were not shown because too many files have changed in this diff Show More