Feature/cpu (#1019)

* add layernorm

* pass reduce

* add comment

* add layer norm test

* fix  layernorm

* fix layernorm

* add demo2

* fix build / add view

* update layernorm

* support layernorm of llama

* fix build

* add demo2

* pass ym

* pass demo2

* fix onnx external data importer

* fuse MHA of llama

* add more cpu kernels

* update MHA fusion

* reorder MHA weights

* add demo3

* add demo3 compute statge 1

* fix build

* fix __tdma_all_sync_apply

* add to v35

* dump const

* update demo3 golden

* compiled

* support multiple output compare

* fix MHA kernel

* resplit v2

* fix MHA kernel

* to v26

* push

* fix double free

* fix mha kernel

* fix  V35

* fix all

* support rmsnrom

* fix v22

* fix v22 v10

* pass v28

* pass v43

* remove dump

* add other part

* pass all llama65b decoder layers

* pass gather

* open 32 threads for demo4

* update binary/unary with external op

* fix codes using stdlib

* update kernel inputs

* Fix gather

* refactor cpu cmodel

* update demo head

* pass graph to tir

* fix head main

* pass norm case

* update demo names

* add xpu source gen

* Fix head kernel segment fault

* fix cost evaluator

* Fix head kernel cos similarity

* decoder layer pass input layernorm

* Add uanry demo

* pass v30 of decoder layer

* fix softmax

* Enable ImmOutput

* fix malloc

* remove debug macro

* Add ImmOut

* fix tdma store

* refactor cpu runtime

* refactor method table

* fix cpu test

* refactor auto distributed

* update cpu test

* fix rdata

* update cpu test with rdata

* fix typeinfer

* add XPU Op layernorm

* update layernorm cost

* fix cost evaluator

* fix layernorm

* add partial resplit

* add rvv matmul

* add codegen of cpu gather

* add concat/slice codegen

* merge

* add codegen of cpu softmax

* update slice cpu case

* fix slice

* fix cpu concat

* fix cpu concat

* Apply code-format changes

* fix build

* Apply code-format changes

* add codegen of transpose

* add reshape

* pass reshape2

* update stackvm

* merge

* fix build

* update compile

* fix to slicing

* fix negative axis

* fix matmul evaluator

* add NormAxis

* fix ToSlice

* fix matmul

* add GatherReduceScatter

* fix ToSlice

* refactor auto dist

* fix boxing partial to slice codegen

* softmax support split on axis

* add conv2d cpu kernel

* disable outter split on inner splited axis

* fix binary distributedtype infer

* fixGetPartialCandidateNDSBPs

* pass cpu conv2d

* support dilated conv2d

* add mha pattern

* add combine reshape transpose

* fix mha fusion/ add rules

* fix rdata map dispose

* add xpu reduce arg

* Apply code-format changes

* add VAE fusion

* fuse VAE

* support xpu instance norm

* Apply code-format changes

* add reduce arg

* Apply code-format changes

* fix to tir keep vars order

* Apply code-format changes

* add XPU resize

* Apply code-format changes

* fix resize cpu kernel op

* fix Resize

* Update layernom op for test

* Apply code-format changes

* fix conv2d kernel

* fix boxing with reshape

* fix build

* fix pytest compare

* add gelu kernels

* add xpu cast

* fix swish type infer

* support xpu expand

* Update layernorm rvv codes

* fix binary broadcast with distributed broadcast

* support multi outputs

* fix single output

* fix new linked section

* fuse Unet

* add cos dump

* fix build

* speed up onnx external data load

* add typeinfer case for binary/matmul

* move matmul rvv to kernels

* fix conv2d kernel

* fix Unet Fusion

* optimize dynamic onnx

* change fusion counter

* fix conv2d if split is partial

* split conv to conv+bias+clamp, and add xpu clamp

* update fusion merger

* fix slice with negative axis

* llama-4-decoder pass (x86/rv64)

* text encoder/vae decoder pass (x86/rv64)

* fix conan config

* Fix cpu/test compile

* fix cmake config

* fix synax err

* fix synax err

* normallize axes of slice

* disable module cpu on windows

* donot split softmax on axis

* add softmax kernel test

* add rvv instance norm

* clean modules dir

* add rvv clamp

* Clean modules dir

* fix match result

* Apply code-format changes

* Clean Tests

* fix csproj

* fix buffer schedule

* Add unet pytest

* add target's commands

* fix buffer and memspan hashcode and equals

* fix unitest

* fix unittest

* fix test_cli output dump

* fix command line

* fix type infer

* fix format

* fix all test

* Apply code-format changes

* fix merge

* Apply code-format changes

* fix merge

* fix runtime build

* fix kernel test build

* Apply code-format changes

* fix use mean

* Apply code-format changes

* optimize dot dump

* fix merge

* fix output when test cli

---------

Co-authored-by: xhuohai <xhuohai@users.noreply.github.com>
Co-authored-by: 郑启航 <597323109@qq.com>
Co-authored-by: zhen8838 <zhen8838@users.noreply.github.com>
Co-authored-by: lerenhua <2532375005@qq.com>
Co-authored-by: liuzhiming <liuzhiming@canaan-creative.com>
Co-authored-by: liuzm6217-jianan <liuzm6217-jianan@users.noreply.github.com>
pull/1121/head
huochenghai 2023-11-07 10:13:25 +08:00 committed by GitHub
parent 21eccd21b9
commit 338ba1070d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
243 changed files with 5138 additions and 3345 deletions

4
.gitignore vendored
View File

@ -68,7 +68,7 @@ artifacts/
*.pidb
*.svclog
*.scc
*.bin
# Chutzpah Test files
_Chutzpah*
@ -306,4 +306,4 @@ cmake-build-*
*gmodel_dump_dir*
*.ipynb_checkpoints*
# Auto generated files
# generated/
# generated/

View File

@ -12,7 +12,6 @@ endif()
if(NOT DEFINED NNCASE_VERSION_SUFFIX)
find_package (Git)
execute_process(
COMMAND ${GIT_EXECUTABLE} describe --always --dirty --tag
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
@ -274,5 +273,5 @@ if(BUILD_TESTING)
endif()
# Modules
#add_subdirectory(modules/k210)
#add_subdirectory(modules/vulkan)

View File

@ -49,8 +49,9 @@
<PackageVersion Include="OrtKISharp" Version="0.0.2" />
<PackageVersion Include="RazorLight" Version="2.3.0" />
<PackageVersion Include="Singulink.Collections.Weak" Version="1.0.2" />
<PackageVersion Include="StyleCop.Analyzers" Version="1.2.0-beta.507" />
<PackageVersion Include="System.CommandLine.Hosting" Version="0.3.0-alpha.21216.1" />
<PackageVersion Include="StyleCop.Analyzers" Version="1.2.0-beta.435" />
<PackageVersion Include="System.CommandLine.Hosting" Version="0.4.0-alpha.22272.1" />
<PackageVersion Include="System.CommandLine" Version="2.0.0-beta4.22272.1" />
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
<PackageVersion Include="System.Reactive" Version="5.0.0" />
<PackageVersion Include="Tomlyn.Extensions.Configuration" Version="1.0.5" />

View File

@ -1,295 +0,0 @@
{
"version": 2,
"dependencies": {
"net7.0": {
"StyleCop.Analyzers": {
"type": "Direct",
"requested": "[1.2.0-beta.435, )",
"resolved": "1.2.0-beta.435",
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
"StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"Google.OrTools.runtime.linux-arm64": {
"type": "Transitive",
"resolved": "9.4.1874",
"contentHash": "Z46ndZcZa2Lt5b76xU9kxVYbPLg/LfuMufhUVsu3Qo3L7Bibf7WXd9j7RRldjnuv8RIHWTqb0b+2FwwMxs0c5A=="
},
"Google.OrTools.runtime.linux-x64": {
"type": "Transitive",
"resolved": "9.4.1874",
"contentHash": "zGeDb8FuvP9HXjrsU7krVXtSDFpR+DUGNEsH51k94jL9tzf2vWYI8+WUBRHZ/cGe50dpLr+vIjfcNo3gFyOpkQ=="
},
"Google.OrTools.runtime.osx-arm64": {
"type": "Transitive",
"resolved": "9.4.1874",
"contentHash": "Wo0ZfDaH6DhiQw0jZm4HWJm/oPGPpWNwOLUz+EYaoH3MLtocSxItHGQj/Ta3HyhXnYNOv+TliAH8L+8RCXu/2w=="
},
"Google.OrTools.runtime.osx-x64": {
"type": "Transitive",
"resolved": "9.4.1874",
"contentHash": "IAfGgKR1og6vU87axK1d37Ak/4jy8B4NMoElovG/KZc/2UY+cJEAQDA709UMegtI4lBhuxTWFNUiHQYmRIB9yQ=="
},
"Google.OrTools.runtime.win-x64": {
"type": "Transitive",
"resolved": "9.4.1874",
"contentHash": "fUs5qDnZA6itygolcX6nPuachQkY9CVvQbakIzIiRAWKcaj8umQAbFdGwbkyzp3qp34BKW5mtPVsmMyfQBBjOQ=="
},
"libortki": {
"type": "Transitive",
"resolved": "0.0.2",
"contentHash": "svfuG5mxGY/QC/5DVheHOCELmdSP90RtxQ73j23KarPXZ9ZXW+7v1l5J77hGDyQbEh1BGrnGgKBlyn76RauGHg==",
"dependencies": {
"libortki-linux": "0.0.2",
"libortki-osx": "0.0.2",
"libortki-osx-arm64": "0.0.2",
"libortki-win": "0.0.2"
}
},
"libortki-linux": {
"type": "Transitive",
"resolved": "0.0.2",
"contentHash": "b04LWD4lgGy60tys3hPFhnUpgWDM6dN5r1PI7GOcPj8VupXCaI70LKNQ5/5twbDE6rkowOGanVTw0S2wBGBqBQ=="
},
"libortki-osx": {
"type": "Transitive",
"resolved": "0.0.2",
"contentHash": "O6Q9GLULkDkZEPAZJVKLPH0ROXGVOE7BxuddgOcHNK2oiTEM7wIRnzp2OIlYgLpaOLyxJMisbGOhtWgdzt2Wng=="
},
"libortki-osx-arm64": {
"type": "Transitive",
"resolved": "0.0.2",
"contentHash": "4Qn2dirJmRicnUG945oWpq7HVGwgqCKKxYPMISv/MRvmpZBbXrZ1cVvRaF8WwTu4XXgfKTa1sLv+i8zLifUMeQ=="
},
"libortki-win": {
"type": "Transitive",
"resolved": "0.0.2",
"contentHash": "HAoROgAKn8XBun11X43HZuspKlo5JGy8/OYw5IUPo7FVh5TCaPrLjGmyGYYZ2dqLlv31yv/b6s254PIRGn95cA=="
},
"Microsoft.Extensions.Configuration.Abstractions": {
"type": "Transitive",
"resolved": "6.0.0",
"contentHash": "qWzV9o+ZRWq+pGm+1dF+R7qTgTYoXvbyowRoBxQJGfqTpqDun2eteerjRQhq5PQ/14S+lqto3Ft4gYaRyl4rdQ==",
"dependencies": {
"Microsoft.Extensions.Primitives": "6.0.0"
}
},
"Microsoft.Extensions.DependencyInjection.Abstractions": {
"type": "Transitive",
"resolved": "6.0.0",
"contentHash": "xlzi2IYREJH3/m6+lUrQlujzX8wDitm4QGnUu6kUXTQAWPuZY8i+ticFJbzfqaetLA6KR/rO6Ew/HuYD+bxifg=="
},
"Microsoft.Extensions.FileProviders.Abstractions": {
"type": "Transitive",
"resolved": "6.0.0",
"contentHash": "0pd4/fho0gC12rQswaGQxbU34jOS1TPS8lZPpkFCH68ppQjHNHYle9iRuHeev1LhrJ94YPvzcRd8UmIuFk23Qw==",
"dependencies": {
"Microsoft.Extensions.Primitives": "6.0.0"
}
},
"Microsoft.Extensions.Primitives": {
"type": "Transitive",
"resolved": "6.0.0",
"contentHash": "9+PnzmQFfEFNR9J2aDTfJGGupShHjOuGw4VUv+JB044biSHrnmCIMD+mJHmb2H7YryrfBEXDurxQ47gJZdCKNQ==",
"dependencies": {
"System.Runtime.CompilerServices.Unsafe": "6.0.0"
}
},
"NetFabric.Hyperlinq.Abstractions": {
"type": "Transitive",
"resolved": "1.3.0",
"contentHash": "WXnEcGwmXfa8gW9N2MlcaPNUzM3NLMwnAhacbtH554F8YcoXbIkTB+uGa1Aa+9gyb/9JZgYVHnmADgJUKP52nA=="
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
"resolved": "1.2.0.435",
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
"resolved": "4.5.1",
"contentHash": "Rw7ijyl1qqRS0YQD/WycNst8hUUMgrMH4FCn1nNm27M4VxchZ1js3fVjQaANHO5f3sN4isvP4a+Met9Y4YomAg=="
},
"System.Runtime.CompilerServices.Unsafe": {
"type": "Transitive",
"resolved": "6.0.0",
"contentHash": "/iUeP3tq1S0XdNNoMz5C9twLSrM/TH+qElHkXWaPvuNOt+99G75NrV0OS2EqHx5wMN7popYjpc8oTjC1y16DLg=="
},
"nncase.codegen": {
"type": "Project",
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )"
}
},
"nncase.core": {
"type": "Project",
"dependencies": {
"DryIoc.dll": "[5.3.1, )",
"GiGraph.Dot": "[2.0.0, )",
"Microsoft.Extensions.Hosting.Abstractions": "[6.0.0, )",
"Microsoft.Extensions.Logging.Abstractions": "[6.0.0, )",
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
"System.Reactive": "[5.0.0, )"
}
},
"nncase.diagnostics": {
"type": "Project",
"dependencies": {
"Nncase.Core": "[1.0.0, )"
}
},
"nncase.egraph": {
"type": "Project",
"dependencies": {
"GiGraph.Dot": "[2.0.0, )",
"Google.OrTools": "[9.4.1874, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.Evaluator": "[1.0.0, )",
"Singulink.Collections.Weak": "[1.0.2, )"
}
},
"nncase.evaluator": {
"type": "Project",
"dependencies": {
"Nncase.Core": "[1.0.0, )",
"OrtKISharp": "[0.0.2, )"
}
},
"nncase.graph": {
"type": "Project",
"dependencies": {
"Nncase.Core": "[1.0.0, )",
"Nncase.Evaluator": "[1.0.0, )"
}
},
"nncase.io": {
"type": "Project"
},
"nncase.modules.stackvm": {
"type": "Project",
"dependencies": {
"Nncase.CodeGen": "[1.0.0, )",
"Nncase.Passes": "[1.0.0, )"
}
},
"nncase.passes": {
"type": "Project",
"dependencies": {
"Nncase.Core": "[1.0.0, )",
"Nncase.EGraph": "[1.0.0, )",
"Nncase.Evaluator": "[1.0.0, )",
"Nncase.Graph": "[1.0.0, )"
}
},
"DryIoc.dll": {
"type": "CentralTransitive",
"requested": "[5.3.1, )",
"resolved": "5.3.1",
"contentHash": "E3zclUh2CIBks1t2uBD1k18pyGFJ1YSKCrbCDbB7qCdl2RAB+k68AyDpjeplhF1ot2XPV82AgyCWBXMf0ggL1g=="
},
"Extension.Mathematics": {
"type": "CentralTransitive",
"requested": "[1.2.12, )",
"resolved": "1.2.12",
"contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg=="
},
"GiGraph.Dot": {
"type": "CentralTransitive",
"requested": "[2.0.0, )",
"resolved": "2.0.0",
"contentHash": "ThvS2mQVveSkTMUm04tMbRYzu1XFPV8xBHISrUMp02APjhv9IRbLu3v3upTPCywORx2Ds/c6AqEUL1WU6kPfuQ=="
},
"Google.OrTools": {
"type": "CentralTransitive",
"requested": "[9.4.1874, )",
"resolved": "9.4.1874",
"contentHash": "jqRoI+pYlym+fhoU25u+13oti5h+772bllQ9zDitTVMclDXVTiG6pxzvmYO74wnADBMdpb2SQlgiNQxoNk5dlA==",
"dependencies": {
"Google.OrTools.runtime.linux-arm64": "9.4.1874",
"Google.OrTools.runtime.linux-x64": "9.4.1874",
"Google.OrTools.runtime.osx-arm64": "9.4.1874",
"Google.OrTools.runtime.osx-x64": "9.4.1874",
"Google.OrTools.runtime.win-x64": "9.4.1874",
"Google.Protobuf": "3.19.4"
}
},
"Google.Protobuf": {
"type": "CentralTransitive",
"requested": "[3.19.4, )",
"resolved": "3.19.4",
"contentHash": "fd07/ykL4O4FhqrZIELm5lmiyOHfdPg9+o+hWr6tcfRdS7tHXnImg/2wtogLzlW2eEmr0J7j6ZrZvaWOLiJbxQ=="
},
"Microsoft.Extensions.Hosting.Abstractions": {
"type": "CentralTransitive",
"requested": "[6.0.0, )",
"resolved": "6.0.0",
"contentHash": "GcT5l2CYXL6Sa27KCSh0TixsRfADUgth+ojQSD5EkzisZxmGFh7CwzkcYuGwvmXLjr27uWRNrJ2vuuEjMhU05Q==",
"dependencies": {
"Microsoft.Extensions.Configuration.Abstractions": "6.0.0",
"Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0",
"Microsoft.Extensions.FileProviders.Abstractions": "6.0.0"
}
},
"Microsoft.Extensions.Logging.Abstractions": {
"type": "CentralTransitive",
"requested": "[6.0.0, )",
"resolved": "6.0.0",
"contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA=="
},
"Microsoft.Extensions.Options": {
"type": "CentralTransitive",
"requested": "[6.0.0, )",
"resolved": "6.0.0",
"contentHash": "dzXN0+V1AyjOe2xcJ86Qbo233KHuLEY0njf/P2Kw8SfJU+d45HNS2ctJdnEnrWbM9Ye2eFgaC5Mj9otRMU6IsQ==",
"dependencies": {
"Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0",
"Microsoft.Extensions.Primitives": "6.0.0"
}
},
"Microsoft.Toolkit.HighPerformance": {
"type": "CentralTransitive",
"requested": "[7.1.1, )",
"resolved": "7.1.1",
"contentHash": "TRnvDpZPXO30hTOtjfLw6Y9BtTKtTpzk9lefeh4RMCaUihWrVKQR454nYH4/mMJAh+LXqfAPyk0kfkJs0Amopw=="
},
"NetFabric.Hyperlinq": {
"type": "CentralTransitive",
"requested": "[3.0.0-beta48, )",
"resolved": "3.0.0-beta48",
"contentHash": "oYUhXvxNS8bBJWqNkvx5g8y0P/0LtyqS2pN0w4OWjVDNWEpLbdbvPy9w/9z1n2PrqIjX3jxUsEnoCmxxGnI3gw==",
"dependencies": {
"NetFabric.Hyperlinq.Abstractions": "1.3.0",
"System.Buffers": "4.5.1",
"System.Runtime.CompilerServices.Unsafe": "5.0.0"
}
},
"OrtKISharp": {
"type": "CentralTransitive",
"requested": "[0.0.2, )",
"resolved": "0.0.2",
"contentHash": "q8j0yR5836Zhv9WB9BFkQt1UaEFyibq8bqJcTiULlILF6/sz8z7Wy2N8sgYdDKsdW25zncIz7j6IDbKM5ynePg==",
"dependencies": {
"libortki": "0.0.2"
}
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",
"resolved": "1.0.2",
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
},
"System.Reactive": {
"type": "CentralTransitive",
"requested": "[5.0.0, )",
"resolved": "5.0.0",
"contentHash": "erBZjkQHWL9jpasCE/0qKAryzVBJFxGHVBAvgRN1bzM0q2s1S4oYREEEL0Vb+1kA/6BKb5FjUZMp5VXmy+gzkQ=="
}
}
}
}

View File

@ -1,6 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/18 下午5:04:31 +08:00. */
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM +00:00. */
using System;
using System.Collections.Generic;
@ -59,7 +59,7 @@ internal partial class CodeGenVisitor
Emitter.T.L2Normalization();
break;
case IR.NN.LayerNorm top:
Emitter.T.LayerNorm(top.Axis, top.Epsilon);
Emitter.T.LayerNorm(top.Axis, top.Epsilon, top.UseMean);
break;
case IR.NN.LeakyRelu top:
Emitter.T.LeakyRelu();
@ -176,7 +176,7 @@ internal partial class CodeGenVisitor
Emitter.T.Cast(top.NewType, top.CastMode);
break;
case IR.Tensors.Concat top:
Emitter.T.Concat();
Emitter.T.Concat(top.Axis);
break;
case IR.Tensors.ConstantOfShape top:
Emitter.T.ConstantOfShape();
@ -191,7 +191,7 @@ internal partial class CodeGenVisitor
Emitter.T.Flatten();
break;
case IR.Tensors.Gather top:
Emitter.T.Gather();
Emitter.T.Gather(top.Axis);
break;
case IR.Tensors.GatherElements top:
Emitter.T.GatherElements();
@ -205,9 +205,6 @@ internal partial class CodeGenVisitor
case IR.Tensors.IndexOf top:
Emitter.T.IndexOf();
break;
case IR.Tensors.LSTM top:
Emitter.T.LSTM(top.Direction, top.Layout, top.Activations);
break;
case IR.Tensors.Prod top:
Emitter.T.Prod();
break;
@ -289,6 +286,9 @@ internal partial class CodeGenVisitor
case IR.ShapeExpr.UnsqueezeShape top:
Emitter.T.UnsqueezeShape();
break;
case IR.RNN.LSTM top:
Emitter.T.LSTM(top.Direction, top.Layout, top.Activations);
break;
case IR.Random.Normal top:
Emitter.T.Normal(top.Type);
break;

View File

@ -1,6 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:30 +08:00. */
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM +00:00. */
using System;
using System.Collections.Generic;
@ -723,10 +723,11 @@ public partial class StackVMEmitter
}
///<summary>.</summary>
public void Concat()
public void Concat(int axis)
{
_emitter.Write((byte)100);
_emitter.Write((ushort)11);
_emitter.Write(axis);
}
///<summary>.</summary>
@ -841,10 +842,11 @@ public partial class StackVMEmitter
}
///<summary>.</summary>
public void Gather()
public void Gather(int axis)
{
_emitter.Write((byte)100);
_emitter.Write((ushort)27);
_emitter.Write(axis);
}
///<summary>.</summary>
@ -925,12 +927,13 @@ public partial class StackVMEmitter
}
///<summary>.</summary>
public void LayerNorm(int axis, float epsilon)
public void LayerNorm(int axis, float epsilon, bool useMean)
{
_emitter.Write((byte)100);
_emitter.Write((ushort)39);
_emitter.Write(axis);
_emitter.Write(epsilon);
_emitter.Write(useMean);
}
///<summary>.</summary>

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.CommandLine.Invocation;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
@ -21,13 +22,15 @@ namespace Nncase.Targets;
/// </summary>
public class CPUTarget : ITarget
{
/// <summary>
/// Gets kind.
/// </summary>
public static readonly string Kind = "cpu";
public const string Kind = "cpu";
string ITarget.Kind => Kind;
public (System.CommandLine.Command Command, Func<InvocationContext, System.CommandLine.Command, ITargetCompileOptions> Parser) RegisterCommandAndParser()
{
return (new System.CommandLine.Command(Kind), (_, _) => DefaultTargetCompileOptions.Instance);
}
/// <inheritdoc/>
public void ParseTargetDependentOptions(IConfigurationSection configure)
{

View File

@ -4,11 +4,11 @@
"net7.0": {
"StyleCop.Analyzers": {
"type": "Direct",
"requested": "[1.2.0-beta.507, )",
"resolved": "1.2.0-beta.507",
"contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
"requested": "[1.2.0-beta.435, )",
"resolved": "1.2.0-beta.435",
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
"StyleCop.Analyzers.Unstable": "1.2.0.507"
"StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"Google.OrTools.runtime.linux-arm64": {
@ -103,8 +103,8 @@
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
"resolved": "1.2.0.507",
"contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
"resolved": "1.2.0.435",
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
@ -134,6 +134,7 @@
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
"System.CommandLine": "[2.0.0-beta4.22272.1, )",
"System.Reactive": "[5.0.0, )"
}
},
@ -271,6 +272,12 @@
"resolved": "1.0.2",
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
},
"System.CommandLine": {
"type": "CentralTransitive",
"requested": "[2.0.0-beta4.22272.1, )",
"resolved": "2.0.0-beta4.22272.1",
"contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
},
"System.Reactive": {
"type": "CentralTransitive",
"requested": "[5.0.0, )",

View File

@ -44,8 +44,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.IO", "src\Nncase.IO\
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Schedule", "src\Nncase.Schedule\Nncase.Schedule.csproj", "{8E0E0672-0F96-4EF1-BDCD-D31F96A3DF73}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "targets", "targets", "{A2590531-71C5-4326-88DD-6A9DB2EF0A2B}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Targets", "src\Nncase.Targets\Nncase.Targets.csproj", "{56283378-06E3-4C6E-A8BF-7BD85C92D42C}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Simulator", "src\Nncase.Simulator\Nncase.Simulator.csproj", "{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}"

View File

@ -1,6 +1,7 @@
// https://gist.github.com/asford/544323a5da7dddad2c9174490eb5ed06
#pragma once
#include <cstdint>
#include <nncase/compiler_defs.h>
#include <pybind11/pybind11.h>

View File

@ -298,7 +298,7 @@ class Compiler:
def check_target(target: str):
def test_target(target: str):
return target in ["cpu", "k510", "k230"]
return target in ["cpu", "k510", "k230", "xpu"]
def target_exists(target: str):
return _nncase.Target.exists(target)

View File

@ -1,5 +1,5 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
* +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@ -78,7 +78,7 @@ compare(runtime::stackvm::compare_op_t compare_op, value_t lhs, value_t rhs,
kernel_context &context = default_kernel_context());
NNCASE_API result<value_t>
concat(value_t input, value_t axis, value_t output = nullptr,
concat(int32_t axis, value_t input, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result<value_t>
@ -157,7 +157,7 @@ flatten(value_t input, value_t axis, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result<value_t>
gather(value_t input, value_t axis, value_t index, value_t output = nullptr,
gather(int32_t axis, value_t input, value_t index, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result<value_t>
@ -211,8 +211,8 @@ l2_normalization(value_t input, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result<value_t>
layer_norm(int32_t axis, float epsilon, value_t input, value_t scale,
value_t bias, value_t output = nullptr,
layer_norm(int32_t axis, float epsilon, bool use_mean, value_t input,
value_t scale, value_t bias, value_t output = nullptr,
kernel_context &context = default_kernel_context());
NNCASE_API result<value_t>

View File

@ -73,6 +73,7 @@ class NNCASE_API interpreter {
options_dict &options() noexcept;
result<runtime_module *> find_module_by_id(size_t index) noexcept;
result<size_t> find_id_by_module(runtime_module *module) noexcept;
/* V1 APIs */

View File

@ -58,6 +58,8 @@ class NNCASE_API runtime_module {
result<runtime_function *> find_function_by_id(size_t index) noexcept;
result<size_t> find_id_by_function(runtime_function *function) noexcept;
protected:
virtual result<void>
initialize_before_functions(runtime_module_init_context &context) noexcept;

View File

@ -1,5 +1,5 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
* +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@ -837,6 +837,7 @@ template <> struct tensor_op_reader<tensor_function_t::compare> {
template <> struct tensor_op_reader<tensor_function_t::concat> {
tensor_concat_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
tensor_concat_op_t op;
op.axis = reader.read_unaligned<int32_t>();
return op;
}
};
@ -964,6 +965,7 @@ template <> struct tensor_op_reader<tensor_function_t::flatten> {
template <> struct tensor_op_reader<tensor_function_t::gather> {
tensor_gather_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
tensor_gather_op_t op;
op.axis = reader.read_unaligned<int32_t>();
return op;
}
};
@ -1055,6 +1057,7 @@ template <> struct tensor_op_reader<tensor_function_t::layer_norm> {
tensor_layer_norm_op_t op;
op.axis = reader.read_unaligned<int32_t>();
op.epsilon = reader.read_unaligned<float>();
op.use_mean = reader.read_unaligned<bool>();
return op;
}
};

View File

@ -1,5 +1,5 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
* +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@ -190,7 +190,6 @@ enum class tensor_function_t : uint16_t {
gather_nd = 29,
get_item = 31,
index_of = 36,
lstm = 44,
prod = 52,
range = 55,
rank = 57,
@ -218,6 +217,7 @@ enum class tensor_function_t : uint16_t {
squeeze_shape = 81,
transpose_shape = 87,
unsqueeze_shape = 93,
lstm = 44,
normal = 47,
normal_like = 48,
uniform = 90,
@ -614,7 +614,9 @@ struct tensor_compare_op_t {
compare_op_t compare_op;
};
struct tensor_concat_op_t {};
struct tensor_concat_op_t {
int32_t axis;
};
struct tensor_condition_op_t {
bool can_fold_const_call;
@ -658,7 +660,9 @@ struct tensor_fix_shape_op_t {};
struct tensor_flatten_op_t {};
struct tensor_gather_op_t {};
struct tensor_gather_op_t {
int32_t axis;
};
struct tensor_gather_elements_op_t {};
@ -685,6 +689,7 @@ struct tensor_l2_normalization_op_t {};
struct tensor_layer_norm_op_t {
int32_t axis;
float epsilon;
bool use_mean;
};
struct tensor_leaky_relu_op_t {};
@ -964,8 +969,6 @@ inline std::string to_string(tensor_function_t tensor_funct) {
return "get_item";
case tensor_function_t::index_of:
return "index_of";
case tensor_function_t::lstm:
return "lstm";
case tensor_function_t::prod:
return "prod";
case tensor_function_t::range:
@ -1020,6 +1023,8 @@ inline std::string to_string(tensor_function_t tensor_funct) {
return "transpose_shape";
case tensor_function_t::unsqueeze_shape:
return "unsqueeze_shape";
case tensor_function_t::lstm:
return "lstm";
case tensor_function_t::normal:
return "normal";
case tensor_function_t::normal_like:

View File

@ -47,8 +47,9 @@ result<value_t> nncase::kernels::stackvm::batch_normalization(
}
result<value_t> nncase::kernels::stackvm::layer_norm(
int32_t axis, float epsilon, value_t input, value_t scale, value_t bias,
value_t output, [[maybe_unused]] kernel_context &context) {
int32_t axis, float epsilon, [[maybe_unused]] bool use_mean, value_t input,
value_t scale, value_t bias, value_t output,
[[maybe_unused]] kernel_context &context) {
try_input(input_mem, input);
try_input(scale_mem, scale);
try_input(bias_mem, bias);
@ -124,7 +125,7 @@ nncase::kernels::stackvm::clamp(value_t input, value_t min, value_t max,
KERNEL_FINISH;
}
result<value_t> nncase::kernels::stackvm::concat(value_t input, value_t axis,
result<value_t> nncase::kernels::stackvm::concat(int32_t axis, value_t input,
value_t output,
kernel_context &context) {
try_tuple_input(inputs_mem, input);
@ -132,7 +133,7 @@ result<value_t> nncase::kernels::stackvm::concat(value_t input, value_t axis,
try_var(strides, get_strides(input_tuple));
try_tuple_field0(input0, input_tuple);
auto dtype = input0->dtype();
try_positive_axis_with_rank(axis_value, axis, input0->shape().size());
auto axis_value = positive_index(axis, input0->shape().size());
auto out_shape = concat_infer_shape(shapes, axis_value);
try_output(out_mem, output, dtype, out_shape);
auto concat_dims = dims_t();
@ -293,14 +294,15 @@ nncase::kernels::stackvm::flatten(value_t input, value_t axis, value_t output,
KERNEL_FINISH;
}
result<value_t> nncase::kernels::stackvm::gather(value_t input, value_t axis,
result<value_t> nncase::kernels::stackvm::gather(int32_t axis, value_t input,
value_t index, value_t output,
kernel_context &context) {
try_input(input_mem, input);
try_input(index_mem, index);
auto dtype = input_tensor->dtype();
try_var(typecode, to_typecode(dtype));
try_positive_axis(axis_value, axis, input_tensor);
// try_positive_axis(axis_value, axis, input_tensor);
auto axis_value = positive_index(axis, input_tensor->shape().size());
auto out_shape = gather_infer_shape(input_tensor->shape(),
index_tensor->shape(), axis_value);
try_output(out_mem, output, dtype, out_shape);

View File

@ -54,6 +54,7 @@ else()
add_library(simulator OBJECT ${SRCS})
target_include_directories(simulator PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(simulator PUBLIC gsl::gsl-lite)
target_link_libraries(simulator PUBLIC fmt::fmt)
target_link_libraries(simulator PRIVATE kernels)
target_compile_definitions(simulator PUBLIC -DNNCASE_DLL -DNNCASE_SIMULATOR)
if (DEFAULT_BUILTIN_RUNTIMES)

View File

@ -246,6 +246,17 @@ result<runtime_module *> interpreter::find_module_by_id(size_t index) noexcept {
return ok(modules_[index].get());
}
result<size_t> interpreter::find_id_by_module(runtime_module *module) noexcept {
auto it = std::find_if(modules_.begin(), modules_.end(),
[&module](const std::unique_ptr<runtime_module> &p) {
return p.get() == module;
});
if (it == modules_.end()) {
return err(std::errc::result_out_of_range);
}
return ok((it - modules_.begin()));
}
options_dict &interpreter::options() noexcept { return options_; }
result<runtime_function *> interpreter::entry_function() noexcept {

View File

@ -189,6 +189,19 @@ runtime_module::find_function_by_id(size_t index) noexcept {
return ok(functions_[index].get());
}
result<size_t>
runtime_module::find_id_by_function(runtime_function *function) noexcept {
auto it =
std::find_if(functions_.begin(), functions_.end(),
[&function](const std::unique_ptr<runtime_function> &p) {
return p.get() == function;
});
if (it == functions_.end()) {
return err(std::errc::result_out_of_range);
}
return ok((it - functions_.begin()));
}
result<void> runtime_module::initialize_before_functions(
NNCASE_UNUSED runtime_module_init_context &context) noexcept {
return ok();

View File

@ -1,5 +1,5 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
* +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*

View File

@ -1,5 +1,5 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM
* +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@ -207,9 +207,7 @@ result<void> stackvm_runtime_function::visit(
dump_op("concat");
try_var(input, pop_value());
dump_input(input);
try_var(axis, pop_value());
dump_input(axis);
try_var(output, kernels::stackvm::concat(input, axis, nullptr,
try_var(output, kernels::stackvm::concat(op.axis, input, nullptr,
module().kernel_context()));
dump_output(output);
stack_.push(std::move(output));
@ -491,11 +489,9 @@ result<void> stackvm_runtime_function::visit(
dump_op("gather");
try_var(input, pop_value());
dump_input(input);
try_var(axis, pop_value());
dump_input(axis);
try_var(index, pop_value());
dump_input(index);
try_var(output, kernels::stackvm::gather(input, axis, index, nullptr,
try_var(output, kernels::stackvm::gather(op.axis, input, index, nullptr,
module().kernel_context()));
dump_output(output);
stack_.push(std::move(output));
@ -683,9 +679,9 @@ result<void> stackvm_runtime_function::visit(
dump_input(scale);
try_var(bias, pop_value());
dump_input(bias);
try_var(output, kernels::stackvm::layer_norm(op.axis, op.epsilon, input,
scale, bias, nullptr,
module().kernel_context()));
try_var(output, kernels::stackvm::layer_norm(
op.axis, op.epsilon, op.use_mean, input, scale, bias,
nullptr, module().kernel_context()));
dump_output(output);
stack_.push(std::move(output));
return ok();

View File

@ -1,5 +1,5 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
* +00:00.
*
* Copyright 2019-2021 Canaan Inc.
*

View File

@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <cstring>
#include <iostream>
#include <nncase/io_utils.h>
@ -19,6 +20,8 @@
using namespace nncase;
using namespace nncase::runtime;
// constexpr size_t loop_count = 10;
constexpr size_t loop_count = 1;
#define TRY(x) \
if (x) \
@ -34,8 +37,7 @@ result<void> write_tensor_buffer(value_t value, std::ofstream &of) {
}
result<void> run_core(const std::string &kmodel_path,
const std::vector<std::string> &input_bins,
const std::string &output_bin) {
const std::vector<std::string> &bins) {
auto kmodel = read_file(kmodel_path);
interpreter *interp = new interpreter();
// auto dump_path =
@ -47,16 +49,16 @@ result<void> run_core(const std::string &kmodel_path,
try_var(entry, interp->entry_function());
if (entry->parameters_size() != input_bins.size())
if (entry->parameters_size() > bins.size())
return err(std::errc::argument_list_too_long);
/* create the input parameters tensor
note the input tenosr must be contiguous
*/
std::vector<value_t> parameters;
for (int i = 0; i < input_bins.size(); i++) {
for (int i = 0; i < entry->parameters_size(); i++) {
try_var(type, entry->parameter_type(i));
try_var(ts_type, type.as<tensor_type>());
auto input_pool = read_file(input_bins[i]);
auto input_pool = read_file(bins[i]);
gsl::span<gsl::byte> input_pool_span = {
reinterpret_cast<gsl::byte *>(input_pool.data()),
input_pool.size()};
@ -66,21 +68,40 @@ result<void> run_core(const std::string &kmodel_path,
parameters.push_back(_.impl());
}
try_var(ret, entry->invoke({parameters.data(), parameters.size()}));
double total_time = 0.0;
for (size_t i = 0; i < loop_count; i++) {
auto start_time = std::chrono::steady_clock::now();
try_var(ret, entry->invoke({parameters.data(), parameters.size()}));
auto end_time = std::chrono::steady_clock::now();
total_time += (std::chrono::duration_cast<std::chrono::nanoseconds>(
end_time - start_time)
.count() /
1e6);
std::ofstream output_stream(output_bin, std::ios::binary);
if (ret.is_a<tensor>()) {
try_(write_tensor_buffer(ret, output_stream));
} else if (ret.is_a<tuple>()) {
try_var(tp, ret.as<tuple>());
for (auto &&ret_v : tp->fields()) {
try_(write_tensor_buffer(ret_v, output_stream));
if (i == (loop_count - 1) && (entry->parameters_size() < bins.size())) {
if (ret.is_a<tensor>()) {
auto output_bin = bins.back();
std::ofstream output_stream(output_bin, std::ios::binary);
try_(write_tensor_buffer(ret, output_stream));
output_stream.close();
} else if (ret.is_a<tuple>()) {
try_var(tp, ret.as<tuple>());
auto o = 0;
for (auto &&ret_v : tp->fields()) {
auto output_bin = bins[entry->parameters_size() + (o++)];
std::ofstream output_stream(output_bin, std::ios::binary);
try_(write_tensor_buffer(ret_v, output_stream));
output_stream.close();
}
} else {
return nncase::err(std::errc::bad_message);
}
}
} else {
return nncase::err(std::errc::bad_message);
}
output_stream.close();
std::cout << "interp run: " << (total_time / loop_count)
<< " ms, fps = " << 1000 / (total_time / loop_count) << std::endl;
return ok();
}
@ -92,13 +113,12 @@ result<void> run_core(const std::string &kmodel_path,
* @return int
*/
int main(NNCASE_UNUSED int argc, char **argv) {
assert(argc >= 4);
std::vector<std::string> input_bins;
for (int i = 2; i < argc - 1; i++) {
input_bins.push_back(argv[i]);
assert(argc >= 3);
std::vector<std::string> bins;
for (int i = 2; i < argc; i++) {
bins.push_back(argv[i]);
}
std::string kmodel_bin(argv[1]);
std::string output_bin(argv[argc - 1]);
run_core(kmodel_bin, input_bins, output_bin).unwrap_or_throw();
run_core(kmodel_bin, bins).unwrap_or_throw();
return 0;
}
}

View File

@ -1,291 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.CommandLine;
using System.CommandLine.Invocation;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Nncase.CodeGen;
using Nncase.Compiler;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.Passes;
using Nncase.Quantization;
namespace Nncase.Cli.Commands;
internal enum QuantType
{
UInt8,
Int8,
Int16,
}
internal enum DatasetFormat
{
Image,
Raw,
Pytest,
Random,
}
/// <summary>
/// Compile command.
/// </summary>
public sealed class Compile : Command
{
/// <summary>
/// Initializes a new instance of the <see cref="Compile"/> class.
/// </summary>
public Compile()
: base("compile")
{
AddArgument(new Argument("input-file"));
AddArgument(new Argument("output-file"));
AddOption(new Option<string>(
aliases: new string[] { "-t", "--target" },
description: "target architecture, e.g. cpu, k210"));
AddOption(new Option<string>(
aliases: new[] { "-i", "--input-format" },
description: "input format, e.g. tflite",
getDefaultValue: () => "tflite"));
AddOption(new Option<int>(
alias: "--dump-level",
description: $"dump ir to .il, default is {0}",
getDefaultValue: () => 0));
AddOption(new Option<string>(
alias: "--dump-dir",
description: "dump to directory, default is .",
getDefaultValue: () => "."));
AddOption(new Option<QuantType>(
alias: "--quant-type",
description: $"quant type, default is {QuantType.UInt8}",
getDefaultValue: () => QuantType.UInt8));
AddOption(new Option<QuantType>(
alias: "--wquant-type",
description: $"wquant type, default is {QuantType.UInt8}",
getDefaultValue: () => QuantType.UInt8));
AddOption(new Option<string>(
alias: "--dataset",
description: $"calibration dataset, used in post quantization, default is empty",
getDefaultValue: () => string.Empty));
AddOption(new Option<DatasetFormat>(
alias: "--dataset-format",
description: $"datset format: e.g. Image|Raw|Pytest",
getDefaultValue: () => DatasetFormat.Raw));
AddOption(new Option<Quantization.ModelQuantMode>(
alias: "--model-quant-mode",
description: $"model quant mode, default is {Quantization.ModelQuantMode.NoQuant}",
getDefaultValue: () => Quantization.ModelQuantMode.NoQuant));
AddOption(new Option<Quantization.CalibMethod>(
alias: "--calib-method",
description: $"model quant options, default is {Quantization.CalibMethod.Kld}",
getDefaultValue: () => Quantization.CalibMethod.Kld));
AddOption(new Option<bool>(
alias: "--pre-process",
description: "whether enable pre process, default is False",
getDefaultValue: () => false));
AddOption(new Option(
alias: "--input-layout",
description: "the model input data layout, default is empty. eg. NCHW/NHWC",
getDefaultValue: () => string.Empty));
AddOption(new Option(
alias: "--output-layout",
description: "the model output data layout, default is empty. eg. NCHW/NHWC",
getDefaultValue: () => string.Empty));
AddOption(new Option(
alias: "--input-type",
description: "the model input data value type, default is Float32",
getDefaultValue: () => InputType.Float32));
AddOption(new Option<IEnumerable<int>>(
alias: "--input-shape",
description: "the model input data shape, default is []. eg. `--input-shape 1 2 3 4`",
getDefaultValue: () => Array.Empty<int>()));
AddOption(new Option<IEnumerable<float>>(
alias: "--input-range",
description: "the model input data value range, default is []. eg `--input-range -100.3 200.4`",
getDefaultValue: () => Array.Empty<float>()));
AddOption(new Option<bool>(
alias: "--swap-rb",
description: "whether swap the model input data channel R and B",
getDefaultValue: () => false));
AddOption(new Option(
alias: "--letter-box-value",
description: "letterbox value, default 0.0",
getDefaultValue: () => 0.0f));
AddOption(new Option<IEnumerable<float>>(
alias: "--mean",
description: "the model input data mean, default []",
getDefaultValue: () => Array.Empty<float>()));
AddOption(new Option<IEnumerable<float>>(
alias: "--std",
description: "the model input data std, default []",
getDefaultValue: () => Array.Empty<float>()));
AddOption(new Option(
alias: "--model-layout",
description: "the model's input layout, default is empty. eg. NCHW/NHWC",
getDefaultValue: () => string.Empty));
AddOption(new Option<bool>(
alias: "--benchmark-only",
description: $"benchmark only",
getDefaultValue: () => false));
Handler = CommandHandler.Create<CliCompileOptions, IHost>(RunAsync);
}
private static DumpFlags DumpLevelToFlags(int dumpLevel)
{
return dumpLevel switch
{
0 => DumpFlags.None,
1 => DumpLevelToFlags(0) | DumpFlags.Compile,
2 => DumpLevelToFlags(1) | DumpFlags.PassIR,
3 => DumpLevelToFlags(2) | DumpFlags.Rewrite,
4 => DumpLevelToFlags(3) | DumpFlags.EGraphCost,
5 => DumpLevelToFlags(4) | DumpFlags.Evaluator,
6 => DumpLevelToFlags(5) | DumpFlags.Calibration,
7 => DumpLevelToFlags(6) | DumpFlags.Tiling,
8 => DumpLevelToFlags(7) | DumpFlags.Schedule,
>= 9 => DumpLevelToFlags(8) | DumpFlags.CodeGen,
_ => throw new ArgumentOutOfRangeException(nameof(dumpLevel)),
};
}
private async Task RunAsync(CliCompileOptions cliOptions, IHost host)
{
CompilerServices.Configure(host.Services);
// 1. setup the options
var compileOptions = new CompileOptions
{
InputFile = cliOptions.InputFile,
InputFormat = cliOptions.InputFormat,
DumpFlags = DumpLevelToFlags(cliOptions.DumpLevel),
DumpDir = cliOptions.DumpDir,
QuantizeOptions = new()
{
CalibrationMethod = cliOptions.CalibMethod,
QuantType = cliOptions.QuantType switch
{
QuantType.UInt8 => DataTypes.UInt8,
QuantType.Int8 => DataTypes.Int8,
QuantType.Int16 => DataTypes.Int16,
_ => throw new ArgumentException("Invalid quant type"),
},
WQuantType = cliOptions.WQuantType switch
{
QuantType.UInt8 => DataTypes.UInt8,
QuantType.Int8 => DataTypes.Int8,
QuantType.Int16 => DataTypes.Int16,
_ => throw new ArgumentException("Invalid weights quant type"),
},
ModelQuantMode = cliOptions.ModelQuantMode,
},
PreProcess = cliOptions.PreProcess,
InputLayout = cliOptions.InputLayout,
OutputLayout = cliOptions.OutputLayout,
InputType = cliOptions.InputType,
InputShape = cliOptions.InputShape.ToArray(),
InputRange = cliOptions.InputRange.ToArray(),
SwapRB = cliOptions.SwapRB,
LetterBoxValue = cliOptions.LetterBoxValue,
Mean = cliOptions.Mean.ToArray(),
Std = cliOptions.Std.ToArray(),
ModelLayout = cliOptions.ModelLayout,
IsBenchmarkOnly = cliOptions.BenchmarkOnly,
};
// 2. import the model
var target = CompilerServices.GetTarget(cliOptions.Target);
using var compileSession = CompileSession.Create(target, compileOptions);
var compiler = compileSession.Compiler;
var module = await compiler.ImportModuleAsync(compileOptions.InputFormat, compileOptions.InputFile, compileOptions.IsBenchmarkOnly);
// 3. create the calib dataset
if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
{
if (cliOptions.DatasetFormat == DatasetFormat.Random)
{
compileOptions.QuantizeOptions.CalibrationDataset = new RandomCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), 5);
}
else if (cliOptions.DatasetFormat == DatasetFormat.Pytest)
{
compileOptions.QuantizeOptions.CalibrationDataset = new PytestCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), cliOptions.Dataset);
}
else
{
throw new NotSupportedException(cliOptions.DatasetFormat.ToString());
}
}
// 4. compile
await compiler.CompileAsync();
// 5. code gen
using (var os = File.OpenWrite(cliOptions.OutputFile))
{
compiler.Gencode(os);
}
}
}
// Validate null in command line parser.
#pragma warning disable CS8618
internal sealed class CliCompileOptions
{
public string InputFile { get; set; }
public string InputFormat { get; set; }
public string Target { get; set; }
public int DumpLevel { get; set; }
public string DumpDir { get; set; }
public QuantType QuantType { get; set; }
public QuantType WQuantType { get; set; }
public string OutputFile { get; set; }
public ModelQuantMode ModelQuantMode { get; set; }
public CalibMethod CalibMethod { get; set; }
public string Dataset { get; set; }
public DatasetFormat DatasetFormat { get; set; }
public bool BenchmarkOnly { get; set; }
public bool PreProcess { get; set; }
public string InputLayout { get; set; }
public string OutputLayout { get; set; }
public InputType InputType { get; set; }
public List<int> InputShape { get; set; }
public List<float> InputRange { get; set; }
public bool SwapRB { get; set; }
public float LetterBoxValue { get; set; }
public List<float> Mean { get; set; }
public List<float> Std { get; set; }
public string ModelLayout { get; set; }
}
#pragma warning restore CS8618

217
src/Nncase.Cli/Compile.cs Normal file
View File

@ -0,0 +1,217 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.CommandLine;
using System.Linq;
using Nncase.Diagnostics;
using Nncase.Quantization;
namespace Nncase.Cli;
internal enum QuantType
{
UInt8,
Int8,
Int16,
}
internal enum DatasetFormat
{
Image,
Raw,
Pytest,
Random,
}
/// <summary>
/// Compile command.
/// </summary>
internal sealed class CompileCommand : Command
{
/// <summary>
/// Initializes a new instance of the <see cref="CompileCommand"/> class.
/// </summary>
public CompileCommand()
: base("compile")
{
InputFile = new Argument<string>("input-file");
OutputFile = new Argument<string>("output-file");
InputFormat = new Option<string>(
aliases: new[] { "-i", "--input-format" },
description: "input format, e.g. tflite",
getDefaultValue: () => "tflite");
DumpFlags = new Option<IEnumerable<DumpFlags>>(
name: "--dump-flags",
description: "dump ir flags. \navailable value: None,ImportOps,PassIR,EGraphCost,Rewrite,Calibration,Evaluator,Compile,Tiling,Schedule,CodeGen.")
{
AllowMultipleArgumentsPerToken = true,
};
DumpDir = new Option<string>(
name: "--dump-dir",
description: "dump to directory.",
getDefaultValue: () => ".");
QuantType = new Option<QuantType>(
name: "--quant-type",
description: $"quant type",
getDefaultValue: () => Nncase.Cli.QuantType.UInt8);
WQuantType = new Option<QuantType>(
name: "--wquant-type",
description: $"wquant type",
getDefaultValue: () => Nncase.Cli.QuantType.UInt8);
Dataset = new Option<string>(
name: "--dataset",
description: $"calibration dataset, used in post quantization",
getDefaultValue: () => string.Empty);
DatasetFormat = new Option<DatasetFormat>(
name: "--dataset-format",
description: $"datset format.",
getDefaultValue: () => Nncase.Cli.DatasetFormat.Raw);
ModelQuantMode = new Option<Quantization.ModelQuantMode>(
name: "--model-quant-mode",
description: $"model quant mode",
getDefaultValue: () => Quantization.ModelQuantMode.NoQuant);
CalibMethod = new Option<Quantization.CalibMethod>(
name: "--calib-method",
description: $"model quant options",
getDefaultValue: () => Quantization.CalibMethod.Kld);
FixedVars = new Option<IEnumerable<(string, int)>>(
name: "--fixed-vars",
description: $"dynamic shape fixed vars, default is empty. \nset by `n:123`",
parseArgument: result =>
{
return result.Tokens.
Select(tk => tk.Value.Split(":").ToArray()).
Select(tp => (tp[0].Trim(), int.Parse(tp[1].Trim())));
})
{
AllowMultipleArgumentsPerToken = true,
};
PreProcess = new Option<bool>(
name: "--pre-process",
description: "whether enable pre process",
getDefaultValue: () => false);
InputLayout = new Option<string>(
name: "--input-layout",
description: "the model input data layout",
getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
OutputLayout = new Option<string>(
name: "--output-layout",
description: "the model output data layout.",
getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
InputType = new Option<InputType>(
name: "--input-type",
description: "the model input data value type, default is Float32",
getDefaultValue: () => Nncase.InputType.Float32);
InputShape = new Option<IEnumerable<int>>(
name: "--input-shape",
description: "the model input data shape. eg. `--input-shape 1 2 3 4`",
getDefaultValue: Array.Empty<int>)
{
AllowMultipleArgumentsPerToken = true,
};
InputRange = new Option<IEnumerable<float>>(
name: "--input-range",
description: "the model input data value range. eg `--input-range -100.3 200.4`",
getDefaultValue: Array.Empty<float>)
{
AllowMultipleArgumentsPerToken = true,
};
SwapRB = new Option<bool>(
name: "--swap-rb",
description: "whether swap the model input data channel, like cv2.BGRtoRGB(im)",
getDefaultValue: () => false);
LetterBoxValue = new Option<float>(
name: "--letter-box-value",
description: "letterbox fill value",
getDefaultValue: () => 0.0f);
Mean = new Option<IEnumerable<float>>(
name: "--mean",
description: "the model input data mean, default []",
getDefaultValue: Array.Empty<float>)
{
AllowMultipleArgumentsPerToken = true,
};
Std = new Option<IEnumerable<float>>(
name: "--std",
description: "the model input data std, default []",
getDefaultValue: Array.Empty<float>)
{
AllowMultipleArgumentsPerToken = true,
};
ModelLayout = new Option<string>(
name: "--model-layout",
description: "the model's input layout.",
getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
AddArgument(InputFile);
AddArgument(OutputFile);
AddGlobalOption(InputFormat);
AddGlobalOption(DumpFlags);
AddGlobalOption(DumpDir);
AddGlobalOption(QuantType);
AddGlobalOption(WQuantType);
AddGlobalOption(Dataset);
AddGlobalOption(DatasetFormat);
AddGlobalOption(ModelQuantMode);
AddGlobalOption(CalibMethod);
AddGlobalOption(FixedVars);
AddGlobalOption(PreProcess);
AddGlobalOption(InputLayout);
AddGlobalOption(OutputLayout);
AddGlobalOption(InputType);
AddGlobalOption(InputShape);
AddGlobalOption(InputRange);
AddGlobalOption(SwapRB);
AddGlobalOption(LetterBoxValue);
AddGlobalOption(Mean);
AddGlobalOption(Std);
AddGlobalOption(ModelLayout);
}
public Argument<string> InputFile { get; }
public Argument<string> OutputFile { get; }
public Option<string> InputFormat { get; }
public Option<IEnumerable<DumpFlags>> DumpFlags { get; }
public Option<string> DumpDir { get; }
public Option<QuantType> QuantType { get; }
public Option<QuantType> WQuantType { get; }
public Option<string> Dataset { get; }
public Option<DatasetFormat> DatasetFormat { get; }
public Option<ModelQuantMode> ModelQuantMode { get; }
public Option<CalibMethod> CalibMethod { get; }
public Option<IEnumerable<(string Name, int Value)>> FixedVars { get; }
public Option<bool> PreProcess { get; }
public Option<string> InputLayout { get; }
public Option<string> OutputLayout { get; }
public Option<InputType> InputType { get; }
public Option<IEnumerable<int>> InputShape { get; }
public Option<IEnumerable<float>> InputRange { get; }
public Option<bool> SwapRB { get; }
public Option<float> LetterBoxValue { get; }
public Option<IEnumerable<float>> Mean { get; }
public Option<IEnumerable<float>> Std { get; }
public Option<string> ModelLayout { get; }
}

View File

@ -26,4 +26,8 @@
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
<ItemGroup>
<Folder Include="Properties\" />
</ItemGroup>
</Project>

View File

@ -1,26 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.CommandLine;
using System.CommandLine.Builder;
using System.Linq;
namespace Nncase.Cli;
internal partial class Program
{
private static CommandLineBuilder BuildCommandLine()
{
var commands = from t in typeof(Program).Assembly.ExportedTypes
where t.Namespace == "Nncase.Cli.Commands" && t.IsAssignableTo(typeof(Command))
select (Command)Activator.CreateInstance(t)!;
var root = new RootCommand();
foreach (var command in commands)
{
root.AddCommand(command);
}
return new CommandLineBuilder(root);
}
}

View File

@ -1,10 +1,14 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.CommandLine;
using System.CommandLine.Builder;
using System.CommandLine.Hosting;
using System.CommandLine.Parsing;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Hosting;
@ -16,12 +20,155 @@ internal partial class Program
{
public static async Task<int> Main(string[] args)
{
return await BuildCommandLine()
return await ConfigureCommandLine()
.UseHost(ConfigureHost)
.UseDefaults()
.Build().InvokeAsync(args);
}
private static async Task RunAsync(string targetKind, CompileOptions compileOptions, DatasetFormat datasetFormat, string dataset, string outputFile, IHost host)
{
CompilerServices.Configure(host.Services);
// 2. import the model
var target = CompilerServices.GetTarget(targetKind);
using var compileSession = CompileSession.Create(target, compileOptions);
var compiler = compileSession.Compiler;
IR.IRModule module = await compiler.ImportModuleAsync(Path.GetExtension(compileOptions.InputFile).Trim('.'), compileOptions.InputFile);
// 3. create the calib dataset
if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
{
if (datasetFormat == DatasetFormat.Random)
{
compileOptions.QuantizeOptions.CalibrationDataset = new Quantization.RandomCalibrationDatasetProvider(((Nncase.IR.Function)module.Entry!).Parameters.ToArray(), 5);
}
else if (datasetFormat == DatasetFormat.Pytest)
{
compileOptions.QuantizeOptions.CalibrationDataset = new Quantization.PytestCalibrationDatasetProvider(((IR.Function)module.Entry!).Parameters.ToArray(), dataset);
}
else
{
throw new NotSupportedException(datasetFormat.ToString());
}
}
// 4. compile
await compiler.CompileAsync();
// 5. code gen
using (var os = File.OpenWrite(outputFile))
{
compiler.Gencode(os);
}
}
private static CommandLineBuilder ConfigureCommandLine()
{
var compile = new CompileCommand();
foreach (var target in LoadTargets())
{
var (targetCmd, targetParser) = target.RegisterCommandAndParser();
Action<System.CommandLine.Invocation.InvocationContext> targetHandler = async (System.CommandLine.Invocation.InvocationContext context) =>
{
var options = ParseCompileOptions(context, compile);
options.TargetCompileOptions = targetParser(context, targetCmd);
await RunAsync(targetCmd.Name, options, context.ParseResult.GetValueForOption(compile.DatasetFormat), context.ParseResult.GetValueForOption(compile.Dataset)!, context.ParseResult.GetValueForArgument(compile.OutputFile), context.GetHost());
};
targetCmd.SetHandler(targetHandler);
compile.AddCommand(targetCmd);
}
return new CommandLineBuilder(new RootCommand() { compile });
}
private static CompileOptions ParseCompileOptions(System.CommandLine.Invocation.InvocationContext context, CompileCommand compilecmd)
{
// 1. setup the options
var compileOptions = new CompileOptions
{
InputFile = context.ParseResult.GetValueForArgument(compilecmd.InputFile),
InputFormat = context.ParseResult.GetValueForOption(compilecmd.InputFormat)!,
DumpFlags = context.ParseResult.GetValueForOption(compilecmd.DumpFlags)!.Aggregate(Diagnostics.DumpFlags.None, (a, b) => a | b),
DumpDir = context.ParseResult.GetValueForOption(compilecmd.DumpDir)!,
PreProcess = context.ParseResult.GetValueForOption(compilecmd.PreProcess)!,
InputLayout = context.ParseResult.GetValueForOption(compilecmd.InputLayout)!,
OutputLayout = context.ParseResult.GetValueForOption(compilecmd.OutputLayout)!,
InputType = context.ParseResult.GetValueForOption(compilecmd.InputType)!,
InputShape = context.ParseResult.GetValueForOption(compilecmd.InputShape)!.ToArray(),
InputRange = context.ParseResult.GetValueForOption(compilecmd.InputRange)!.ToArray(),
SwapRB = context.ParseResult.GetValueForOption(compilecmd.SwapRB)!,
LetterBoxValue = context.ParseResult.GetValueForOption(compilecmd.LetterBoxValue)!,
Mean = context.ParseResult.GetValueForOption(compilecmd.Mean)!.ToArray(),
Std = context.ParseResult.GetValueForOption(compilecmd.Std)!.ToArray(),
ModelLayout = context.ParseResult.GetValueForOption(compilecmd.ModelLayout)!,
QuantizeOptions = new()
{
CalibrationMethod = context.ParseResult.GetValueForOption(compilecmd.CalibMethod),
QuantType = context.ParseResult.GetValueForOption(compilecmd.QuantType) switch
{
QuantType.UInt8 => DataTypes.UInt8,
QuantType.Int8 => DataTypes.Int8,
QuantType.Int16 => DataTypes.Int16,
_ => throw new ArgumentException("Invalid quant type"),
},
WQuantType = context.ParseResult.GetValueForOption(compilecmd.WQuantType) switch
{
QuantType.UInt8 => DataTypes.UInt8,
QuantType.Int8 => DataTypes.Int8,
QuantType.Int16 => DataTypes.Int16,
_ => throw new ArgumentException("Invalid weights quant type"),
},
ModelQuantMode = context.ParseResult.GetValueForOption(compilecmd.ModelQuantMode),
},
};
foreach (var item in context.ParseResult.GetValueForOption(compilecmd.FixedVars)!)
{
compileOptions.ShapeBucketOptions.FixVarMap.Add(item.Name, item.Value);
}
return compileOptions;
}
private static IReadOnlyList<ITarget> LoadTargets()
{
var loadContext = System.Runtime.Loader.AssemblyLoadContext.Default;
var pluginAsms = PluginLoader.GetPluginsSearchDirectories(PluginLoader.PluginPathEnvName, null).
Select(PluginLoader.GetPluginAssemblies).
SelectMany(x => x).
DistinctBy(Path.GetFileName).
Select(x => PluginLoader.LoadPluginAssembly(x, loadContext)).
Distinct().
ToList();
pluginAsms.AddRange(new[] { Path.GetDirectoryName(typeof(Program).Assembly.Location)! }.
Select(basePath =>
{
if (Directory.Exists(basePath))
{
return (from filePath in Directory.GetFiles(basePath, PluginLoader.ModulesDllPattern, SearchOption.AllDirectories)
where PluginLoader.IsLoadableAssembly(filePath)
select filePath).Distinct();
}
else
{
return Array.Empty<string>();
}
}).
SelectMany(x => x).
DistinctBy(Path.GetFileName).
Select(x => PluginLoader.LoadPluginAssembly(x, loadContext)).
Distinct());
var targets = (from asm in pluginAsms
from t in asm.ExportedTypes
where t.IsClass
&& t.IsAssignableTo(typeof(ITarget))
let ctor = t.GetConstructor(Type.EmptyTypes)
where ctor != null
select (ITarget)ctor.Invoke(null)).ToList();
return targets;
}
private static void ConfigureHost(IHostBuilder hostBuilder)
{
hostBuilder.ConfigureAppConfiguration(ConfigureAppConfiguration)

View File

@ -33,21 +33,22 @@
},
"StyleCop.Analyzers": {
"type": "Direct",
"requested": "[1.2.0-beta.507, )",
"resolved": "1.2.0-beta.507",
"contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
"requested": "[1.2.0-beta.435, )",
"resolved": "1.2.0-beta.435",
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
"StyleCop.Analyzers.Unstable": "1.2.0.507"
"StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"System.CommandLine.Hosting": {
"type": "Direct",
"requested": "[0.3.0-alpha.21216.1, )",
"resolved": "0.3.0-alpha.21216.1",
"contentHash": "zP8QEUH8dSUYUHdGk6k71kOJy8uFgEPZG2RfhA0cMjDH3/Jov5AjUNaxOvpSNHh+ewu8eIUCYgV8+fEkCPyNlw==",
"requested": "[0.4.0-alpha.22272.1, )",
"resolved": "0.4.0-alpha.22272.1",
"contentHash": "x9JhHxBLxlKyCIZADFYC8q16L9yGHdTakrLFjHabwR7Tk0761aTexiGgMTIS744HGuhc8pk9MoLUzsr/TlRfMQ==",
"dependencies": {
"Microsoft.Extensions.Hosting": "3.1.5",
"System.CommandLine": "2.0.0-beta1.21216.1"
"Microsoft.Extensions.Hosting": "6.0.0",
"System.CommandLine": "2.0.0-beta4.22272.1",
"System.CommandLine.NamingConventionBinder": "2.0.0-beta4.22272.1"
}
},
"Google.OrTools.runtime.linux-arm64": {
@ -344,8 +345,8 @@
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
"resolved": "1.2.0.507",
"contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
"resolved": "1.2.0.435",
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
@ -362,13 +363,12 @@
"System.Runtime": "4.3.0"
}
},
"System.CommandLine": {
"System.CommandLine.NamingConventionBinder": {
"type": "Transitive",
"resolved": "2.0.0-beta1.21216.1",
"contentHash": "Nbv/tW8sbOKN5T+4SSVBMdk4ADSIpJpY4UHMsj3VkcNtOckIT4iyzagjF+W5FEh2YBRvmvVQijOTIZbUJ1+1aA==",
"resolved": "2.0.0-beta4.22272.1",
"contentHash": "ux2eUA/syF+JtlpMDc/Lsd6PBIBuwjH3AvHnestoh5uD0WKT5b+wkQxDWVCqp9qgVjMBTLNhX19ZYFtenunt9A==",
"dependencies": {
"Microsoft.CSharp": "4.4.1",
"system.memory": "4.5.4"
"System.CommandLine": "2.0.0-beta4.22272.1"
}
},
"System.Diagnostics.Contracts": {
@ -696,6 +696,7 @@
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
"System.CommandLine": "[2.0.0-beta4.22272.1, )",
"System.Reactive": "[5.0.0, )"
}
},
@ -937,6 +938,12 @@
"resolved": "1.0.2",
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
},
"System.CommandLine": {
"type": "CentralTransitive",
"requested": "[2.0.0-beta4.22272.1, )",
"resolved": "2.0.0-beta4.22272.1",
"contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
},
"System.Linq.Async": {
"type": "CentralTransitive",
"requested": "[6.0.1, )",

View File

@ -15,7 +15,6 @@ public class LinkedFunction : ILinkedFunction
public LinkedFunction(uint id, Callable sourceFunction, ulong textBegin, ulong textLength, IReadOnlyList<ILinkedSection> sections)
{
Id = id;
CompilerServices.InferenceType(sourceFunction);
ParameterTypes = ((CallableType)sourceFunction.CheckedType).Parameters.ToArray();
ReturnType = ((CallableType)sourceFunction.CheckedType).ReturnType;
TextBegin = textBegin;

View File

@ -10,11 +10,11 @@
},
"StyleCop.Analyzers": {
"type": "Direct",
"requested": "[1.2.0-beta.507, )",
"resolved": "1.2.0-beta.507",
"contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
"requested": "[1.2.0-beta.435, )",
"resolved": "1.2.0-beta.435",
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
"StyleCop.Analyzers.Unstable": "1.2.0.507"
"StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"Microsoft.Extensions.Configuration.Abstractions": {
@ -53,8 +53,8 @@
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
"resolved": "1.2.0.507",
"contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
"resolved": "1.2.0.435",
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
@ -76,6 +76,7 @@
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
"System.CommandLine": "[2.0.0-beta4.22272.1, )",
"System.Reactive": "[5.0.0, )"
}
},
@ -138,6 +139,12 @@
"System.Runtime.CompilerServices.Unsafe": "5.0.0"
}
},
"System.CommandLine": {
"type": "CentralTransitive",
"requested": "[2.0.0-beta4.22272.1, )",
"resolved": "2.0.0-beta4.22272.1",
"contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
},
"System.Reactive": {
"type": "CentralTransitive",
"requested": "[5.0.0, )",

View File

@ -88,13 +88,15 @@ internal class Compiler : ICompiler
public void TargetIndependentPass(IPassManager passManager)
{
passManager.AddWithName<DataflowPass>("ReshapeMatMul").Configure(p =>
passManager.AddWithName<DataflowPass>("NormAxisAndShape").Configure(p =>
{
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
});
passManager.AddWithName<DataflowPass>("SqueezeShape").Configure(p =>
{
p.Add<Passes.Rules.Neutral.NormAxisGather>();
p.Add<Passes.Rules.Neutral.NormAxisConcat>();
p.Add<Passes.Rules.Neutral.NormAxisReduce>();
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
@ -102,6 +104,7 @@ internal class Compiler : ICompiler
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern2>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern3>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern4>();
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern5>();
p.Add<Passes.Rules.Neutral.FoldGeluWithScale>();
p.Add<Passes.Rules.Neutral.FoldGeneralGelu>();
p.Add<Passes.Rules.Neutral.FoldSwishPattern1>();
@ -124,6 +127,7 @@ internal class Compiler : ICompiler
p.Add<Passes.Rules.Neutral.FoldNopCast>();
p.Add<Passes.Rules.Neutral.FoldNopReshape>();
p.Add<Passes.Rules.Neutral.FoldNopSlice>();
p.Add<Passes.Rules.Neutral.FoldPrePostReshapeSoftmax>();
p.Add<Passes.Rules.Neutral.FoldSqueezeUnsqueeze>();
p.Add<Passes.Rules.Neutral.FoldUnsqueezeSqueeze>();
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
@ -157,6 +161,8 @@ internal class Compiler : ICompiler
p.Add<Passes.Rules.Neutral.CombineUnaryReshape>();
p.Add<Passes.Rules.Neutral.CombineActivationsReshape>();
p.Add<Passes.Rules.Neutral.CombineReshapePad>();
p.Add<Passes.Rules.Neutral.CombineReshapeTranspose>();
p.Add<Passes.Rules.Neutral.CombineTransposeReshape>();
p.Add<Passes.Rules.Neutral.FoldNopPad>();
p.Add<Passes.Rules.Neutral.FoldConv2DPads>();
p.Add<Passes.Rules.Neutral.FuseClampConv2D>();
@ -168,6 +174,7 @@ internal class Compiler : ICompiler
p.Add<Passes.Rules.Neutral.ReshapeToTranspose>();
p.Add<Passes.Rules.Neutral.FoldNopReshape>();
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
p.Add<Passes.Rules.Neutral.FoldReshapeBinaryConstReshape>();
p.Add<Passes.Rules.Neutral.ReluToClamp>();
p.Add<Passes.Rules.Neutral.Relu6ToClamp>();
p.Add<Passes.Rules.Neutral.FoldNopSlice>();

View File

@ -19,12 +19,14 @@ namespace Nncase.Hosting;
/// </summary>
public sealed class PluginLoader
{
private const string _modulesDllPattern = "Nncase.Modules.*.dll";
private const string _pluginPathEnvName = "NNCASE_PLUGIN_PATH";
public const string PluginPathEnvName = "NNCASE_PLUGIN_PATH";
public const string ModulesDllPattern = "Nncase.Modules.*.dll";
private static readonly string[] _builtinModules = new[]
{
"Nncase.Modules.StackVM.dll",
"Nncase.Modules.CPU.dll",
"Nncase.Modules.K210.dll",
};
@ -42,26 +44,60 @@ public sealed class PluginLoader
?? AssemblyLoadContext.Default;
}
/// <summary>
/// Load plugins.
/// </summary>
/// <returns>Plugins.</returns>
public IReadOnlyList<IPlugin> LoadPlugins()
public static Assembly LoadPluginAssembly(string assemblyFile, AssemblyLoadContext loadContext)
{
var pluginAsms = GetPluginsSearchDirectories().Select(GetPluginAssemblies).SelectMany(x => x)
.DistinctBy(Path.GetFileName).Select(LoadPluginAssembly).Distinct().ToList();
var plugins = (from asm in pluginAsms
from t in asm.ExportedTypes
where t.IsClass
&& t.IsAssignableTo(typeof(IPlugin))
let ctor = t.GetConstructor(Type.EmptyTypes)
where ctor != null
select (IPlugin)ctor.Invoke(null)).ToList();
return plugins;
return loadContext.LoadFromAssemblyPath(assemblyFile);
}
private static bool IsLoadableAssembly(string filePath)
public static IEnumerable<string> GetPluginAssemblies(string basePath)
{
if (Directory.Exists(basePath))
{
return (from filePath in Directory.GetFiles(basePath, ModulesDllPattern, SearchOption.AllDirectories)
where !_builtinModules.Contains(Path.GetFileName(filePath))
&& IsLoadableAssembly(filePath)
select filePath).Distinct();
}
else
{
return Array.Empty<string>();
}
}
public static IEnumerable<string> GetPluginsSearchDirectories(string pluginPathEnvName, ILogger? logger)
{
var directories = new List<string>();
// 1. Environment variable
var targetPathEnv = Environment.GetEnvironmentVariable(pluginPathEnvName);
if (string.IsNullOrWhiteSpace(targetPathEnv))
{
if (logger is not null)
{
logger.LogWarning($"{pluginPathEnvName} is not set.");
}
}
else
{
var targetPaths = from path in targetPathEnv!.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
select Environment.ExpandEnvironmentVariables(path);
directories.AddRange(targetPaths);
}
// 2. Python nncase modules
var rootPath = Path.GetDirectoryName(typeof(PluginLoader).Assembly.Location)!;
var modulesPath = Path.Combine(rootPath, "modules");
directories.Add(modulesPath);
if (logger is not null && logger.IsEnabled(LogLevel.Trace))
{
logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}.");
}
return directories.Distinct();
}
public static bool IsLoadableAssembly(string filePath)
{
using var fs = File.OpenRead(filePath);
using var peReader = new PEReader(fs);
@ -93,53 +129,22 @@ public sealed class PluginLoader
return true;
}
private Assembly LoadPluginAssembly(string assemblyFile)
/// <summary>
/// Load plugins.
/// </summary>
/// <returns>Plugins.</returns>
public IReadOnlyList<IPlugin> LoadPlugins()
{
return _loadContext.LoadFromAssemblyPath(assemblyFile);
}
var pluginAsms = GetPluginsSearchDirectories(PluginPathEnvName, _logger).Select(GetPluginAssemblies).SelectMany(x => x)
.DistinctBy(Path.GetFileName).Select(x => LoadPluginAssembly(x, _loadContext)).Distinct().ToList();
var plugins = (from asm in pluginAsms
from t in asm.ExportedTypes
where t.IsClass
&& t.IsAssignableTo(typeof(IPlugin))
let ctor = t.GetConstructor(Type.EmptyTypes)
where ctor != null
select (IPlugin)ctor.Invoke(null)).ToList();
private IEnumerable<string> GetPluginAssemblies(string basePath)
{
if (Directory.Exists(basePath))
{
return (from filePath in Directory.GetFiles(basePath, _modulesDllPattern, SearchOption.AllDirectories)
where !_builtinModules.Contains(Path.GetFileName(filePath))
&& IsLoadableAssembly(filePath)
select filePath).Distinct();
}
else
{
return Array.Empty<string>();
}
}
private IEnumerable<string> GetPluginsSearchDirectories()
{
var directories = new List<string>();
// 1. Environment variable
var targetPathEnv = Environment.GetEnvironmentVariable(_pluginPathEnvName);
if (string.IsNullOrWhiteSpace(targetPathEnv))
{
_logger.LogWarning($"{_pluginPathEnvName} is not set.");
}
else
{
var targetPaths = from path in targetPathEnv.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
select Environment.ExpandEnvironmentVariables(path);
directories.AddRange(targetPaths);
}
// 2. Python nncase modules
var rootPath = Path.GetDirectoryName(typeof(PluginLoader).Assembly.Location)!;
var modulesPath = Path.Combine(rootPath, "modules");
directories.Add(modulesPath);
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}.");
}
return directories.Distinct();
return plugins;
}
}

View File

@ -49,11 +49,11 @@
},
"StyleCop.Analyzers": {
"type": "Direct",
"requested": "[1.2.0-beta.507, )",
"resolved": "1.2.0-beta.507",
"contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
"requested": "[1.2.0-beta.435, )",
"resolved": "1.2.0-beta.435",
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
"dependencies": {
"StyleCop.Analyzers.Unstable": "1.2.0.507"
"StyleCop.Analyzers.Unstable": "1.2.0.435"
}
},
"Google.OrTools.runtime.linux-arm64": {
@ -350,8 +350,8 @@
},
"StyleCop.Analyzers.Unstable": {
"type": "Transitive",
"resolved": "1.2.0.507",
"contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
"resolved": "1.2.0.435",
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
},
"System.Buffers": {
"type": "Transitive",
@ -674,6 +674,7 @@
"Microsoft.Extensions.Options": "[6.0.0, )",
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
"System.CommandLine": "[2.0.0-beta4.22272.1, )",
"System.Reactive": "[5.0.0, )"
}
},
@ -885,6 +886,12 @@
"resolved": "1.0.2",
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
},
"System.CommandLine": {
"type": "CentralTransitive",
"requested": "[2.0.0-beta4.22272.1, )",
"resolved": "2.0.0-beta4.22272.1",
"contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
},
"System.Linq.Async": {
"type": "CentralTransitive",
"requested": "[6.0.1, )",

View File

@ -119,4 +119,9 @@ public sealed record CompileOptions
/// Gets or sets a value indicating whether is benchmark only.
/// </summary>
public bool IsBenchmarkOnly { get; set; }
/// <summary>
/// Gets or sets the target compile options.
/// </summary>
public ITargetCompileOptions TargetCompileOptions { get; set; } = null!;
}

View File

@ -73,6 +73,14 @@ public interface ICompilerServicesProvider
/// <param name="randConst">false for save const into bin.</param>
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst);
/// <summary>
/// dump the expr as csharp code.
/// </summary>
/// <param name="expr">expression.</param>
/// <param name="prefix">file prefix.</param>
/// <param name="dumpDir">file dump ir.</param>
public void DumpPatternIR(Expr expr, string prefix, string dumpDir);
/// <summary>
/// print ir type.
/// </summary>
@ -468,6 +476,15 @@ public static class CompilerServices
public static void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst = true) =>
Provider.DumpCSharpIR(expr, prefix, dumpDir, randConst);
/// <summary>
/// dump the expr as csharp code.
/// </summary>
/// <param name="expr">expression.</param>
/// <param name="prefix">file prefix.</param>
/// <param name="dumpDir">file dump ir.</param>
public static void DumpPatternIR(Expr expr, string prefix, string dumpDir) =>
Provider.DumpPatternIR(expr, prefix, dumpDir);
public static string Print(IRType type) => Provider.Print(type);
public static string Print(Expr expr, bool useScript = false) => Provider.Print(expr, useScript);
@ -583,6 +600,10 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst) =>
_irprinterProvider.DumpCSharpIR(expr, prefix, dumpDir, randConst);
/// <inheritdoc/>
public void DumpPatternIR(Expr expr, string prefix, string dumpDir) =>
_irprinterProvider.DumpPatternIR(expr, prefix, dumpDir);
/// <inheritdoc/>
public string Print(IRType type) => _irprinterProvider.Print(type);

View File

@ -28,5 +28,6 @@ internal class ConvertersModule : IApplicationPart
registrator.RegisterManyInterface<UInt64Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<UInt8Converters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PointerConverters>(reuse: Reuse.Singleton);
registrator.RegisterManyInterface<PointerIntConverters>(reuse: Reuse.Singleton);
}
}

View File

@ -30,3 +30,25 @@ internal class PointerConverters : IPointerSpanConverter<ulong>
}
}
}
internal class PointerIntConverters : IPointerSpanConverter<int>
{
public void ConvertTo<T>(ReadOnlySpan<Pointer<T>> source, Span<int> dest, CastMode castMode)
where T : unmanaged, IEquatable<T>
{
if (castMode != CastMode.KDefault)
{
throw new InvalidCastException();
}
if (dest.Length < source.Length)
{
throw new ArgumentException("Dest buffer is not sufficient.");
}
for (int i = 0; i < source.Length; i++)
{
dest[i] = checked((int)source[i].Value);
}
}
}

View File

@ -204,6 +204,7 @@ public static class CostUtility
{
TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * t.DType.SizeInBytes),
TupleType t => t.Fields.Sum(GetMemoryAccess),
DistributedType t => GetMemoryAccess(Utilities.DistributedUtility.GetDividedTensorType(t)),
_ => 0,
};
}
@ -229,6 +230,7 @@ public static class CostUtility
{
TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * cyclesPerElement),
TupleType t => t.Fields.Sum(GetMemoryAccess),
DistributedType t => GetCPUCycles(Utilities.DistributedUtility.GetDividedTensorType(t)),
_ => 0,
};
}
@ -328,7 +330,7 @@ public static class CostUtility
}
// cost for op similar to broadcast
public static Cost GetBroadcastCost(TensorType input, TensorType ret)
public static Cost GetBroadcastCost(IRType input, IRType ret)
{
return new()
{

View File

@ -114,7 +114,7 @@ public static class DataTypes
/// <returns> datatype name.</returns>
public static string GetDisplayName(this DataType dataType) => dataType switch
{
PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)}*)",
PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)} *)",
PrimType primType => primType.ShortName,
ValueType => dataType.ToString(),
_ => throw new ArgumentOutOfRangeException(dataType.GetType().Name),

View File

@ -42,6 +42,8 @@ public interface IDumpper
void DumpCSharpIR(Expr expr, string prefix, string? reletivePath = null);
void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null);
void DumpModule(IRModule module, string? reletivePath = null);
Stream OpenFile(string reletivePath, FileMode fileMode = FileMode.Create);

View File

@ -46,6 +46,11 @@ public sealed class NullDumpper : IDumpper
{
}
/// <inheritdoc/>
public void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null)
{
}
/// <inheritdoc/>
public bool IsEnabled(DumpFlags dumpFlags) => false;

View File

@ -0,0 +1,68 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using DryIoc.ImTools;
namespace Nncase.IR;
public abstract record SBP
{
public static SBPPartialSum P => SBPPartialSum.Instance;
public static SBPBroadCast B => SBPBroadCast.Instance;
public static SBPSplit S(int axis) => new SBPSplit(axis);
}
public sealed record SBPSplit(int Axis) : SBP
{
public override string ToString() => $"S({Axis})";
}
public sealed record SBPPartialSum : SBP
{
public static readonly SBPPartialSum Instance = new SBPPartialSum();
private SBPPartialSum()
{
}
public override string ToString() => "P";
}
public sealed record SBPBroadCast : SBP
{
public static readonly SBPBroadCast Instance = new SBPBroadCast();
private SBPBroadCast()
{
}
public override string ToString() => "B";
}
// public sealed record Placement(Placement.DeviceKind Kind, IRArray<int> Hierarchy, string Name)
public sealed record Placement(IRArray<int> Hierarchy, string Name)
{
// public enum DeviceKind : uint
// {
// CPU = 0,
// }
public int Rank => Hierarchy.Count;
// public override string ToString() => $"@{Kind} [{string.Join(',', Hierarchy.Zip(Name).Select(t => t.First.ToString() + '@' + t.Second.ToString()))}]";
public override string ToString() => $"@ [{string.Join(',', Hierarchy.Zip(Name).Select(t => t.First.ToString() + '@' + t.Second.ToString()))}]";
}
public sealed record DistributedType(TensorType TensorType, IRArray<SBP> NdSBP, Placement Placement) : IRType
{
public override string ToString() => $"{TensorType}, ({string.Join(',', NdSBP)}), {Placement}";
}

View File

@ -17,7 +17,7 @@ namespace Nncase
public HashSet<Function> Functions => _functions;
protected override int VisitLeafFunction(Function expr, Unit context)
protected override int VisitLeafFunction(Function expr)
{
_functions.Add(expr);
return 0;

View File

@ -0,0 +1,28 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Buffers;
/// <summary>
/// BufferLoad expression.
/// </summary>
[PatternFunctionalGenerator]
public sealed partial class BufferLoad : Op
{
/// <summary>
/// Get the input parameter.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(BufferLoad), 0, "input", IsTensor());
/// <summary>
/// Get the indices.
/// </summary>
public static readonly ParameterInfo Indices = new(typeof(BufferLoad), 1, "indices", IsTuple());
/// <inheritdoc/>
public override bool CanFoldConstCall => false;
}

View File

@ -16,7 +16,7 @@ public sealed partial class BufferOf : Op
/// </summary>
public static readonly ParameterInfo Input = new(typeof(BufferOf), 0, "input", IsTensor());
public Schedule.MemoryLocation MemoryLocation { get; }
public TIR.MemoryLocation MemoryLocation { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"Schedule.MemoryLocation.{MemoryLocation}";

View File

@ -0,0 +1,33 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Buffers;
/// <summary>
/// BufferStore op.
/// </summary>
[PatternFunctionalGenerator]
public sealed partial class BufferStore : Op
{
/// <summary>
/// Get the input parameter.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(BufferStore), 0, "input", IsTensor());
/// <summary>
/// Get the indices parameter.
/// </summary>
public static readonly ParameterInfo Indices = new(typeof(BufferStore), 1, "indices", IsTuple());
/// <summary>
/// Get the value parameter.
/// </summary>
public static readonly ParameterInfo Value = new(typeof(BufferStore), 2, "value", IsScalar());
/// <inheritdoc/>
public override bool CanFoldConstCall => false;
}

View File

@ -17,4 +17,7 @@ public sealed partial class DDrOf : Op
/// Get the input parameter.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(DDrOf), 0, "input", IsTensor());
/// <inheritdoc/>
public override bool CanFoldConstCall => false;
}

View File

@ -41,5 +41,5 @@ public static class Buffer
/// <summary>
/// create the uninitialized buffer.
/// </summary>
public static Call Uninitialized(DataType dataType, Schedule.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);
public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);
}

View File

@ -0,0 +1,21 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Buffers;
/// <summary>
/// MatchBuffer op.
/// todo maybe need united matchbuffer and allocatebuffer.
/// </summary>
[PatternFunctionalGenerator]
public sealed partial class MatchBuffer : Op
{
public static readonly ParameterInfo Input = new(typeof(MatchBuffer), 0, "input", IsTensor());
/// <inheritdoc/>
public override bool CanFoldConstCall => false;
}

View File

@ -19,11 +19,11 @@ public sealed partial class Uninitialized : Op
public DataType DType { get; }
public Schedule.MemoryLocation MemoryLocation { get; }
public TIR.MemoryLocation MemoryLocation { get; }
/// <inheritdoc/>
public override bool CanFoldConstCall => false;
/// <inheritdoc/>
public override string DisplayProperty() => $"{DType.GetCSharpName()}, Schedule.MemoryLocation.{MemoryLocation}";
public override string DisplayProperty() => $"{DType.GetCSharpName()}, MemoryLocation.{MemoryLocation}";
}

View File

@ -17,7 +17,7 @@ public abstract class Callable : Expr
/// <summary>
/// StackVM module kind.
/// </summary>
public static readonly string StackVMModuleKind = "stackvm";
public const string StackVMModuleKind = "stackvm";
public Callable(string name, string moduleKind, Expr[] operands)
: base(operands)

View File

@ -1,4 +1,3 @@

//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
@ -57,8 +56,7 @@ public partial class ExprCloner<TContext>
return expr.With(
condition: Clone(expr.Condition, context),
then: Clone(expr.Then, context),
@else: Clone(expr.Else, context),
paramList: expr.ParamList.Select(p => Clone(p, context)).ToArray()
@else: Clone(expr.Else, context)
);
}
@ -141,31 +139,6 @@ public partial class ExprCloner<TContext>
);
}
/// <inheritdoc />
protected override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
{
return expr.With(
dimensions: CloneArray(expr.Dimensions, context),
strides: CloneArray(expr.Strides, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
{
return expr.With(
);
}
/// <inheritdoc />
protected override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
{
return expr.With(
buffer: Clone(expr.Buffer, context),
indices: CloneArray(expr.Indices, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
{
@ -175,16 +148,6 @@ public partial class ExprCloner<TContext>
);
}
/// <inheritdoc />
protected override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
{
return expr.With(
buffer: Clone(expr.Buffer, context),
indices: CloneArray(expr.Indices, context),
value: Clone(expr.Value, context)
);
}
/// <inheritdoc />
protected override Expr VisitLeafFor(TIR.For expr, TContext context)
{

View File

@ -102,6 +102,13 @@ public partial class ExprFunctor<TExprResult, TTypeResult> : ExprFunctor<TExprRe
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(TensorType type) => base.VisitType(type, default);
/// <summary>
/// Visit point type.
/// </summary>
/// <param name="type">pointer type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(PointerType type) => base.VisitType(type, default);
/// <summary>
/// Visit tuple type.
/// </summary>
@ -116,6 +123,13 @@ public partial class ExprFunctor<TExprResult, TTypeResult> : ExprFunctor<TExprRe
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(CallableType type) => base.VisitType(type, default);
/// <summary>
/// Visit callable type.
/// </summary>
/// <param name="type">Callable type.</param>
/// <returns>Result.</returns>
public virtual TTypeResult VisitType(DistributedType type) => base.VisitType(type, default);
/// <summary>
/// Default visit routine.
/// </summary>
@ -135,12 +149,18 @@ public partial class ExprFunctor<TExprResult, TTypeResult> : ExprFunctor<TExprRe
/// <inheritdoc/>
public sealed override TTypeResult VisitType(TensorType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(PointerType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(TupleType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(CallableType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult VisitType(DistributedType type, Unit context) => VisitType(type);
/// <inheritdoc/>
public sealed override TTypeResult DefaultVisitType(IRType type, Unit context) => DefaultVisitType(type);

View File

@ -1,4 +1,3 @@

//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
@ -79,6 +78,11 @@ public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
/// </summary>
internal protected virtual TExprResult VisitTupleConst(TupleConst expr, TContext context) => VisitConst(expr, context);
/// <summary>
/// Visit <see cref="TIR.MemSpan"/>.
/// </summary>
internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="Var"/>.
/// </summary>
@ -94,31 +98,11 @@ public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
/// </summary>
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.LogicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitBuffer(expr, context);
/// <summary>
/// Visit <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitBuffer(expr, context);
/// <summary>
/// Visit <see cref="TIR.BufferLoad"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.BufferRegion"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.BufferStore"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisit(expr, context);
/// <summary>
/// Visit <see cref="TIR.For"/>.
/// </summary>
@ -250,6 +234,13 @@ public partial class ExprFunctor<TExprResult, TTypeResult>
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
/// <summary>
/// Visit <see cref="TIR.MemSpan"/>.
/// </summary>
internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr);
/// <summary>
/// Visit <see cref="Var"/>.
/// </summary>
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
@ -271,27 +262,6 @@ public partial class ExprFunctor<TExprResult, TTypeResult>
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.LogicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.BufferLoad"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
/// <summary>
/// Visit <see cref="TIR.BufferRegion"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
@ -299,13 +269,6 @@ public partial class ExprFunctor<TExprResult, TTypeResult>
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
/// <summary>
/// Visit <see cref="TIR.BufferStore"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
/// <summary>
/// Visit <see cref="TIR.For"/>.
/// </summary>
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);

View File

@ -1,4 +1,3 @@

//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
@ -92,6 +91,12 @@ public partial class ExprRewriter<TContext>
return RewriteLeafTupleConst(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafMemSpan(TIR.MemSpan expr, TContext context)
{
return RewriteLeafMemSpan(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafVar(Var expr, TContext context)
{
@ -110,36 +115,12 @@ public partial class ExprRewriter<TContext>
return RewriteLeafBuffer(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
{
return RewriteLeafLogicalBuffer(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
{
return RewriteLeafPhysicalBuffer(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
{
return RewriteLeafBufferLoad(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
{
return RewriteLeafBufferRegion(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
{
return RewriteLeafBufferStore(expr, context);
}
/// <inheritdoc/>
protected sealed override Expr VisitLeafFor(TIR.For expr, TContext context)
{
@ -247,6 +228,11 @@ public partial class ExprRewriter<TContext>
/// </summary>
protected virtual Expr RewriteLeafTupleConst(TupleConst expr, TContext context) => RewriteLeafConst(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.MemSpan"/>.
/// </summary>
protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="Var"/>.
/// </summary>
@ -262,31 +248,11 @@ public partial class ExprRewriter<TContext>
/// </summary>
protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.LogicalBuffer"/>.
/// </summary>
protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferLoad"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferRegion"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferStore"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultRewriteLeaf(expr, context);
/// <summary>
/// Rewrite leaf <see cref="TIR.For"/>.
/// </summary>
@ -430,6 +396,14 @@ public partial class ExprRewriter
/// <inheritdoc />
protected sealed override Expr RewriteLeafTupleConst(TupleConst expr, Unit context) => RewriteLeafTupleConst(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.MemSpan"/>.
/// </summary>
protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafMemSpan(TIR.MemSpan expr, Unit context) => RewriteLeafMemSpan(expr);
/// <summary>
/// Rewrite leaf <see cref="Var"/>.
/// </summary>
@ -454,30 +428,6 @@ public partial class ExprRewriter
/// <inheritdoc />
protected sealed override Expr RewriteLeafBuffer(TIR.Buffer expr, Unit context) => RewriteLeafBuffer(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.LogicalBuffer"/>.
/// </summary>
protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr) => RewriteLeafBuffer(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => RewriteLeafLogicalBuffer(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => RewriteLeafBuffer(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => RewriteLeafPhysicalBuffer(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferLoad"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, Unit context) => RewriteLeafBufferLoad(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferRegion"/>.
/// </summary>
@ -486,14 +436,6 @@ public partial class ExprRewriter
/// <inheritdoc />
protected sealed override Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, Unit context) => RewriteLeafBufferRegion(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.BufferStore"/>.
/// </summary>
protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr) => DefaultRewriteLeaf(expr);
/// <inheritdoc />
protected sealed override Expr RewriteLeafBufferStore(TIR.BufferStore expr, Unit context) => RewriteLeafBufferStore(expr);
/// <summary>
/// Rewrite leaf <see cref="TIR.For"/>.
/// </summary>

View File

@ -1,4 +1,3 @@

//---------------------------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by T4 template.
@ -103,6 +102,13 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
return VisitLeafTupleConst(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafMemSpan(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitVar(Var expr, TContext context)
{
@ -118,24 +124,10 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
}
/// <inheritdoc />
protected internal override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
protected internal override TExprResult VisitBuffer(TIR.Buffer expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafLogicalBuffer(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafPhysicalBuffer(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafBufferLoad(expr, context);
return VisitLeafBuffer(expr, context);
}
/// <inheritdoc />
@ -145,13 +137,6 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
return VisitLeafBufferRegion(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context)
{
VisitOperands(expr, context);
return VisitLeafBufferStore(expr, context);
}
/// <inheritdoc />
protected internal override TExprResult VisitFor(TIR.For expr, TContext context)
{
@ -270,6 +255,11 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
/// </summary>
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr, TContext context) => VisitLeafConst(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.MemSpan"/>.
/// </summary>
protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="Var"/>.
/// </summary>
@ -285,31 +275,11 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
/// </summary>
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.LogicalBuffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.BufferLoad"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.BufferRegion"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.BufferStore"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisitLeaf(expr, context);
/// <summary>
/// Visit leaf <see cref="TIR.For"/>.
/// </summary>
@ -353,182 +323,168 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit <see cref="Call"/>.
/// </summary>
internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr);
/// <summary>
/// Visit <see cref="Function"/>.
/// </summary>
internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr);
/// <summary>
/// Visit <see cref="Fusion"/>.
/// </summary>
internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr);
/// <summary>
/// Visit <see cref="If"/>.
/// </summary>
internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr);
/// <summary>
/// Visit <see cref="Marker"/>.
/// </summary>
internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr);
/// <summary>
/// Visit <see cref="None"/>.
/// </summary>
internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr);
/// <summary>
/// Visit <see cref="Op"/>.
/// </summary>
internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr);
/// <summary>
/// Visit <see cref="PrimFunctionWrapper"/>.
/// </summary>
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr);
/// <summary>
/// Visit <see cref="TensorConst"/>.
/// </summary>
internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
/// <summary>
/// Visit <see cref="IR.Tuple"/>.
/// </summary>
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr);
/// <summary>
/// Visit <see cref="TupleConst"/>.
/// </summary>
internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
/// <summary>
/// Visit <see cref="TIR.MemSpan"/>.
/// </summary>
internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr);
/// <summary>
/// Visit <see cref="Var"/>.
/// </summary>
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr);
/// <summary>
/// Visit <see cref="TIR.Block"/>.
/// </summary>
internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr);
/// <summary>
/// Visit <see cref="TIR.LogicalBuffer"/>.
/// Visit <see cref="TIR.Buffer"/>.
/// </summary>
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr) => base.VisitBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.BufferLoad"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr);
/// <summary>
/// Visit <see cref="TIR.BufferRegion"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
/// <summary>
/// Visit <see cref="TIR.BufferStore"/>.
/// </summary>
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
/// <summary>
/// Visit <see cref="TIR.For"/>.
/// </summary>
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr);
/// <summary>
/// Visit <see cref="TIR.IfThenElse"/>.
/// </summary>
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr);
/// <summary>
/// Visit <see cref="TIR.Let"/>.
/// </summary>
internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr);
/// <summary>
/// Visit <see cref="TIR.PrimFunction"/>.
/// </summary>
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr);
/// <summary>
/// Visit <see cref="TIR.Sequential"/>.
/// </summary>
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr);
/// <summary>
/// Visit <see cref="TIR.Range"/>.
/// </summary>
internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr);
/// <summary>
/// Visit <see cref="TIR.IterVar"/>.
/// </summary>
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default);
/// <inheritdoc/>
internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr);
/// <summary>
/// Visit leaf <see cref="BaseFunction"/>.
/// </summary>
protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr) => base.VisitLeafBaseFunction(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBaseFunction(BaseFunction expr, Unit context) => VisitLeafBaseFunction(expr);
@ -536,7 +492,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="Call"/>.
/// </summary>
protected virtual TExprResult VisitLeafCall(Call expr) => base.VisitLeafCall(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafCall(Call expr, Unit context) => VisitLeafCall(expr);
@ -544,7 +500,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="Const"/>.
/// </summary>
protected virtual TExprResult VisitLeafConst(Const expr) => base.VisitLeafConst(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafConst(Const expr, Unit context) => VisitLeafConst(expr);
@ -552,15 +508,15 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="Function"/>.
/// </summary>
protected virtual TExprResult VisitLeafFunction(Function expr) => base.VisitLeafFunction(expr, default);
/// <inheritdoc/>
protected override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr);
protected sealed override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr);
/// <summary>
/// Visit leaf <see cref="Fusion"/>.
/// </summary>
protected virtual TExprResult VisitLeafFusion(Fusion expr) => base.VisitLeafFusion(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafFusion(Fusion expr, Unit context) => VisitLeafFusion(expr);
@ -568,7 +524,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="If"/>.
/// </summary>
protected virtual TExprResult VisitLeafIf(If expr) => base.VisitLeafIf(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafIf(If expr, Unit context) => VisitLeafIf(expr);
@ -576,7 +532,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="Marker"/>.
/// </summary>
protected virtual TExprResult VisitLeafMarker(Marker expr) => base.VisitLeafMarker(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafMarker(Marker expr, Unit context) => VisitLeafMarker(expr);
@ -584,7 +540,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="None"/>.
/// </summary>
protected virtual TExprResult VisitLeafNone(None expr) => base.VisitLeafNone(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafNone(None expr, Unit context) => VisitLeafNone(expr);
@ -592,7 +548,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="Op"/>.
/// </summary>
protected virtual TExprResult VisitLeafOp(Op expr) => base.VisitLeafOp(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafOp(Op expr, Unit context) => VisitLeafOp(expr);
@ -600,7 +556,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="PrimFunctionWrapper"/>.
/// </summary>
protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitLeafPrimFunctionWrapper(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitLeafPrimFunctionWrapper(expr);
@ -608,7 +564,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TensorConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr);
@ -616,7 +572,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="IR.Tuple"/>.
/// </summary>
protected virtual TExprResult VisitLeafTuple(IR.Tuple expr) => base.VisitLeafTuple(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafTuple(IR.Tuple expr, Unit context) => VisitLeafTuple(expr);
@ -624,15 +580,23 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TupleConst"/>.
/// </summary>
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr) => base.VisitLeafTupleConst(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafTupleConst(TupleConst expr, Unit context) => VisitLeafTupleConst(expr);
/// <summary>
/// Visit leaf <see cref="TIR.MemSpan"/>.
/// </summary>
protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr) => base.VisitLeafMemSpan(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafMemSpan(TIR.MemSpan expr, Unit context) => VisitLeafMemSpan(expr);
/// <summary>
/// Visit leaf <see cref="Var"/>.
/// </summary>
protected virtual TExprResult VisitLeafVar(Var expr) => base.VisitLeafVar(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafVar(Var expr, Unit context) => VisitLeafVar(expr);
@ -640,7 +604,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TIR.Block"/>.
/// </summary>
protected virtual TExprResult VisitLeafBlock(TIR.Block expr) => base.VisitLeafBlock(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBlock(TIR.Block expr, Unit context) => VisitLeafBlock(expr);
@ -648,55 +612,23 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TIR.Buffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr) => base.VisitLeafBuffer(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBuffer(TIR.Buffer expr, Unit context) => VisitLeafBuffer(expr);
/// <summary>
/// Visit leaf <see cref="TIR.LogicalBuffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLeafLogicalBuffer(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLeafLogicalBuffer(expr);
/// <summary>
/// Visit leaf <see cref="TIR.PhysicalBuffer"/>.
/// </summary>
protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitLeafPhysicalBuffer(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitLeafPhysicalBuffer(expr);
/// <summary>
/// Visit leaf <see cref="TIR.BufferLoad"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr) => base.VisitLeafBufferLoad(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, Unit context) => VisitLeafBufferLoad(expr);
/// <summary>
/// Visit leaf <see cref="TIR.BufferRegion"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr) => base.VisitLeafBufferRegion(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, Unit context) => VisitLeafBufferRegion(expr);
/// <summary>
/// Visit leaf <see cref="TIR.BufferStore"/>.
/// </summary>
protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr) => base.VisitLeafBufferStore(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafBufferStore(TIR.BufferStore expr, Unit context) => VisitLeafBufferStore(expr);
/// <summary>
/// Visit leaf <see cref="TIR.For"/>.
/// </summary>
protected virtual TExprResult VisitLeafFor(TIR.For expr) => base.VisitLeafFor(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafFor(TIR.For expr, Unit context) => VisitLeafFor(expr);
@ -704,7 +636,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TIR.IfThenElse"/>.
/// </summary>
protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr) => base.VisitLeafIfThenElse(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, Unit context) => VisitLeafIfThenElse(expr);
@ -712,7 +644,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TIR.Let"/>.
/// </summary>
protected virtual TExprResult VisitLeafLet(TIR.Let expr) => base.VisitLeafLet(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafLet(TIR.Let expr, Unit context) => VisitLeafLet(expr);
@ -720,7 +652,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TIR.PrimFunction"/>.
/// </summary>
protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr) => base.VisitLeafPrimFunction(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, Unit context) => VisitLeafPrimFunction(expr);
@ -728,7 +660,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TIR.Sequential"/>.
/// </summary>
protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr) => base.VisitLeafSequential(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafSequential(TIR.Sequential expr, Unit context) => VisitLeafSequential(expr);
@ -736,7 +668,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TIR.Range"/>.
/// </summary>
protected virtual TExprResult VisitLeafRange(TIR.Range expr) => base.VisitLeafRange(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafRange(TIR.Range expr, Unit context) => VisitLeafRange(expr);
@ -744,7 +676,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
/// Visit leaf <see cref="TIR.IterVar"/>.
/// </summary>
protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr) => base.VisitLeafIterVar(expr, default);
/// <inheritdoc/>
protected sealed override TExprResult VisitLeafIterVar(TIR.IterVar expr, Unit context) => VisitLeafIterVar(expr);

View File

@ -73,6 +73,14 @@ public interface IIRPrinterProvider
/// <param name="randConst">randConst = false will save the const into bin.</param>
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst);
/// <summary>
/// dump the expr as csharp code.
/// </summary>
/// <param name="expr">expression.</param>
/// <param name="prefix">file prefix.</param>
/// <param name="dumpDir">file dump ir.</param>
public void DumpPatternIR(Expr expr, string prefix, string dumpDir);
/// <summary>
/// print ir type.
/// </summary>

View File

@ -11,14 +11,11 @@ PrimFunctionWrapper,true,true,BaseFunction,,Target
TensorConst,true,false,Const,,
Tuple,true,false,Default,IR.,@Fields
TupleConst,true,false,Const,,
MemSpan,true,false,Default,TIR.,Start;Size;
Var,true,false,Default,,
Block,true,false,Default,TIR.,Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate
Buffer,false,false,Default,TIR.,
LogicalBuffer,true,false,Buffer,TIR.,@Dimensions;@Strides
PhysicalBuffer,true,false,Buffer,TIR.,
BufferLoad,true,false,Default,TIR.,Buffer;@Indices
Buffer,true,false,Default,TIR.,MemSpan;@Dimensions;@Strides;
BufferRegion,true,false,Default,TIR.,Buffer;@Region
BufferStore,true,false,Default,TIR.,Buffer;@Indices;Value
For,true,false,Default,TIR.,LoopVar;Domain;Body
IfThenElse,true,false,Default,TIR.,Condition;Then;Else
Let,true,false,Default,TIR.,Var;Expression;Body

1 BaseFunction false true Default
11 TensorConst true false Const
12 Tuple true false Default IR. @Fields
13 TupleConst true false Const
14 MemSpan true false Default TIR. Start;Size;
15 Var true false Default
16 Block true false Default TIR. Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate
17 Buffer false true false Default TIR. MemSpan;@Dimensions;@Strides;
LogicalBuffer true false Buffer TIR. @Dimensions;@Strides
PhysicalBuffer true false Buffer TIR.
BufferLoad true false Default TIR. Buffer;@Indices
18 BufferRegion true false Default TIR. Buffer;@Region
BufferStore true false Default TIR. Buffer;@Indices;Value
19 For true false Default TIR. LoopVar;Domain;Body
20 IfThenElse true false Default TIR. Condition;Then;Else
21 Let true false Default TIR. Var;Expression;Body

View File

@ -139,6 +139,15 @@ public sealed record TensorType(DataType DType, Shape Shape) : IRType
/// <param name="elemType"> the Pointed Element Type.</param>
/// <returns>the pointer tensor type.</returns>
public static TensorType Pointer(DataType elemType) => new(new PointerType(elemType), Shape.Scalar);
/// <inheritdoc/>
public override string ToString() => DType switch
{
PrimType ptype => ptype.GetDisplayName() + (Shape.IsScalar ? string.Empty : Shape.ToString()),
PointerType { ElemType: PrimType etype } => $"*{etype.GetDisplayName()}",
ValueType => $"{DType}",
_ => throw new NotSupportedException(DType.GetType().Name),
};
}
/// <summary>

View File

@ -21,7 +21,7 @@ public sealed partial class ResizeImage : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"));
public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"), ParameterKind.Input);
/// <summary>
/// Gets roi.

View File

@ -20,12 +20,12 @@ public sealed partial class Binary : Op
/// <summary>
/// Gets lhs.
/// </summary>
public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs");
public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs", ParameterKind.Input);
/// <summary>
/// Gets rhs.
/// </summary>
public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs");
public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs", ParameterKind.Input);
public BinaryOp BinaryOp { get; }

View File

@ -21,7 +21,7 @@ public sealed partial class Clamp : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Clamp), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Clamp), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets min.

View File

@ -20,10 +20,10 @@ public sealed partial class MatMul : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Lhs = new(typeof(MatMul), 0, "lhs");
public static readonly ParameterInfo Lhs = new(typeof(MatMul), 0, "lhs", ParameterKind.Input);
/// <summary>
/// Gets Other.
/// </summary>
public static readonly ParameterInfo Rhs = new(typeof(MatMul), 1, "rhs");
public static readonly ParameterInfo Rhs = new(typeof(MatMul), 1, "rhs", ParameterKind.Input);
}

View File

@ -21,7 +21,7 @@ public sealed partial class ReduceArg : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(ReduceArg), 0, "input");
public static readonly ParameterInfo Input = new(typeof(ReduceArg), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets Axis.
@ -42,8 +42,8 @@ public sealed partial class ReduceArg : Op
public ReduceArgOp ReduceArgOp { get; }
public DataType DestType { get; }
public PrimType DestType { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}";
public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}, DestType: {DestType}";
}

View File

@ -20,7 +20,7 @@ public sealed partial class Unary : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input", ParameterKind.Input);
public UnaryOp UnaryOp { get; }

View File

@ -154,7 +154,12 @@ public sealed partial class Swish : ActivationOp
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets beta.
/// </summary>
public static readonly ParameterInfo Beta = new(typeof(Swish), 1, "beta", IsFloatScalar());
}
/// <summary>

View File

@ -21,17 +21,17 @@ public sealed partial class Conv2D : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets Weights.
/// </summary>
public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4));
public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4), ParameterKind.Input);
/// <summary>
/// Gets Bias.
/// </summary>
public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1));
public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1), ParameterKind.Input);
/// <summary>
/// Gets Stride.

View File

@ -34,7 +34,7 @@ public static class NN
public static Call BatchNormalization(Expr input, Expr scale, Expr bias, Expr input_mean, Expr input_var, Expr epsilon, Expr momentum) => new Call(new BatchNormalization(), input, scale, bias, input_mean, input_var, epsilon, momentum);
public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias) => new Call(new LayerNorm(axis, epsilon), input, scale, bias);
public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true) => new Call(new LayerNorm(axis, epsilon, hasMean), input, scale, bias);
public static Call BatchToSpace(Expr input, Expr blockShape, Expr crops) => new Call(new BatchToSpace(), input, blockShape, crops);
@ -103,5 +103,10 @@ public static class NN
/// <summary>
/// create Swish call.
/// </summary>
public static Call Swish(Expr input) => new Call(new Swish(), input);
public static Call Swish(Expr input) => new Call(new Swish(), input, 1f);
/// <summary>
/// create Swish call.
/// </summary>
public static Call Swish(Expr input, Expr beta) => new Call(new Swish(), input, beta);
}

View File

@ -21,19 +21,23 @@ public sealed partial class LayerNorm : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(LayerNorm), 0, "input");
public static readonly ParameterInfo Input = new(typeof(LayerNorm), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets scale.
/// </summary>
public static readonly ParameterInfo Scale = new(typeof(LayerNorm), 1, "scale");
public static readonly ParameterInfo Scale = new(typeof(LayerNorm), 1, "scale", ParameterKind.Input);
/// <summary>
/// Gets bias.
/// </summary>
public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias");
public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias", ParameterKind.Input);
public int Axis { get; }
public float Epsilon { get; }
public bool UseMean { get; }
public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}";
}

View File

@ -61,17 +61,17 @@ public sealed partial class InstanceNormalization : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(InstanceNormalization), 0, "input");
public static readonly ParameterInfo Input = new(typeof(InstanceNormalization), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Scale = new(typeof(InstanceNormalization), 1, "scale");
public static readonly ParameterInfo Scale = new(typeof(InstanceNormalization), 1, "scale", ParameterKind.Input);
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Bias = new(typeof(InstanceNormalization), 2, "bias");
public static readonly ParameterInfo Bias = new(typeof(InstanceNormalization), 2, "bias", ParameterKind.Input);
/// <summary>
/// Gets Epsilon.

View File

@ -33,7 +33,7 @@ public sealed partial class Softmax : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Softmax), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Softmax), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets axis.

View File

@ -12,6 +12,12 @@ using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR;
public enum ParameterKind : int
{
Input,
Attribute,
}
/// <summary>
/// Parameter information.
/// </summary>
@ -24,11 +30,13 @@ public sealed class ParameterInfo
/// <param name="ownerType">this op type.</param>
/// <param name="index">param index.</param>
/// <param name="name">param name.</param>
public ParameterInfo(Type ownerType, int index, string name)
/// <param name="parameterKind">kind.</param>
public ParameterInfo(Type ownerType, int index, string name, ParameterKind parameterKind = ParameterKind.Attribute)
{
OwnerType = ownerType;
Index = index;
Name = name;
ParameterKind = parameterKind;
}
/// <summary>
@ -39,8 +47,9 @@ public sealed class ParameterInfo
/// <param name="index">param index.</param>
/// <param name="name">param name.</param>
/// <param name="pattern">the param condition.</param>
public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern)
: this(ownerType, index, name)
/// <param name="parameterKind">kind.</param>
public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern, ParameterKind parameterKind = ParameterKind.Attribute)
: this(ownerType, index, name, parameterKind)
{
Pattern = pattern;
}
@ -60,6 +69,11 @@ public sealed class ParameterInfo
/// </summary>
public string Name { get; }
/// <summary>
/// Gets parameter kind.
/// </summary>
public ParameterKind ParameterKind { get; }
/// <summary>
/// Gets this paramter's type condition.
/// </summary>
@ -90,7 +104,7 @@ public abstract class Op : Expr
/// <summary>
/// Gets get the parameters.
/// </summary>
public IEnumerable<ParameterInfo> Parameters =>
public virtual IEnumerable<ParameterInfo> Parameters =>
_parameters ??= (from p in GetType().GetFields(BindingFlags.Public | BindingFlags.Static)
where p.FieldType == typeof(ParameterInfo)
let param = (ParameterInfo)(p.GetValue(null) ?? throw new InvalidOperationException())

View File

@ -19,5 +19,5 @@ namespace Nncase.IR.F;
public static class RNN
{
public static Call LSTM(LSTMDirection direction, LSTMLayout layout, string[] acts, Expr x, Expr w, Expr r, Expr b, Expr seqLens, Expr initH, Expr initC, Expr p, Expr actAlpha, Expr actBeta, Expr clip, Expr hiddenSize, Expr inputForget, Expr outputSize) =>
new Call(new IR.Tensors.LSTM(direction, layout, acts), x, w, r, b, seqLens, initH, initC, p, actAlpha, actBeta, clip, hiddenSize, inputForget, outputSize);
new Call(new IR.RNN.LSTM(direction, layout, acts), x, w, r, b, seqLens, initH, initC, p, actAlpha, actBeta, clip, hiddenSize, inputForget, outputSize);
}

View File

@ -5,7 +5,7 @@ using System.Collections.Immutable;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
namespace Nncase.IR.Tensors;
namespace Nncase.IR.RNN;
/// <summary>
/// LSTM expression.

View File

@ -146,7 +146,7 @@ public sealed class TensorConst : Const, IEquatable<TensorConst?>
public override bool Equals(object? obj) => Equals(obj as TensorConst);
/// <inheritdoc/>
public bool Equals(TensorConst? other) => other is not null && base.Equals(other) && EqualityComparer<Tensor>.Default.Equals(Value, other.Value);
public bool Equals(TensorConst? other) => other is not null && (ReferenceEquals(this, other) || GetHashCode() == other.GetHashCode()) && EqualityComparer<Tensor>.Default.Equals(Value, other.Value);
/// <inheritdoc/>
protected override int GetHashCodeCore() => HashCode.Combine(Value);

View File

@ -20,7 +20,7 @@ public sealed partial class Cast : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Cast), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Cast), 0, "input", ParameterKind.Input);
public DataType NewType { get; }

View File

@ -20,10 +20,13 @@ public sealed partial class Concat : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Concat), 0, "inputs");
public static readonly ParameterInfo Input = new(typeof(Concat), 0, "inputs", ParameterKind.Input);
/// <summary>
/// Gets axis.
/// </summary>
public static readonly ParameterInfo Axis = new(typeof(Concat), 1, "axis");
public int Axis { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"Axis: {Axis}";
}

View File

@ -21,7 +21,7 @@ public sealed partial class Expand : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Expand), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Expand), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets shape.

View File

@ -70,7 +70,7 @@ public static class Tensors
public static Call Cast(Expr input, DataType newType, CastMode castMode = CastMode.KDefault) =>
new Call(new Cast(newType, castMode), input);
public static Call Concat(Expr input, Expr axis) => new Call(new Concat(), input, axis);
public static Call Concat(Expr input, int axis) => new Call(new Concat(axis), input);
public static Call ConstantOfShape(Expr shape, Expr value) => new Call(new ConstantOfShape(), shape, value);
@ -89,7 +89,7 @@ public static class Tensors
public static Call Flatten(Expr input, Expr axis) => new Call(new Flatten(), input, axis);
public static Call Gather(Expr input, Expr axis, Expr index) => new Call(new Gather(), input, axis, index);
public static Call Gather(Expr input, int axis, Expr index) => new Call(new Gather(axis), input, index);
public static Call GatherElements(Expr input, Expr axis, Expr indices) =>
new Call(new GatherElements(), input, axis, indices);

View File

@ -22,15 +22,18 @@ public sealed partial class Gather : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Gather), 0, "input");
/// <summary>
/// Gets axis.
/// </summary>
public static readonly ParameterInfo Axis = new(typeof(Gather), 1, "axis", IsIntegralScalar());
public static readonly ParameterInfo Input = new(typeof(Gather), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets index.
/// </summary>
public static readonly ParameterInfo Index = new(typeof(Gather), 2, "index", IsIntegral());
public static readonly ParameterInfo Index = new(typeof(Gather), 1, "index", IsIntegral(), ParameterKind.Input);
/// <summary>
/// Gets axis.
/// </summary>
public int Axis { get; }
/// <inheritdoc/>
public override string DisplayProperty() => $"Axis: {Axis}";
}

View File

@ -22,7 +22,7 @@ public sealed partial class Reshape : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Reshape), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Reshape), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets shape.

View File

@ -21,7 +21,7 @@ public sealed partial class Slice : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Slice), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Slice), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets begins.

View File

@ -15,7 +15,7 @@ public sealed partial class Transpose : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Transpose), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Transpose), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets perm.

View File

@ -23,7 +23,7 @@ public sealed partial class Unsqueeze : Op
/// <summary>
/// Gets input.
/// </summary>
public static readonly ParameterInfo Input = new(typeof(Unsqueeze), 0, "input");
public static readonly ParameterInfo Input = new(typeof(Unsqueeze), 0, "input", ParameterKind.Input);
/// <summary>
/// Gets dimension.

View File

@ -32,6 +32,7 @@ public abstract class TypeFunctor<TResult, TContext>
TensorType t => VisitType(t, context),
TupleType t => VisitType(t, context),
CallableType t => VisitType(t, context),
DistributedType t => VisitType(t, context),
_ => DefaultVisitType(type, context),
};
}
@ -68,6 +69,14 @@ public abstract class TypeFunctor<TResult, TContext>
/// <returns>Result.</returns>
public virtual TResult VisitType(TensorType type, TContext context) => DefaultVisitType(type, context);
/// <summary>
/// Visit pointer type.
/// </summary>
/// <param name="type">Pointer type.</param>
/// <param name="context">Context.</param>
/// <returns>Result.</returns>
public virtual TResult VisitType(PointerType type, TContext context) => DefaultVisitType(type, context);
/// <summary>
/// Visit tuple type.
/// </summary>
@ -84,6 +93,14 @@ public abstract class TypeFunctor<TResult, TContext>
/// <returns>Result.</returns>
public virtual TResult VisitType(CallableType type, TContext context) => DefaultVisitType(type, context);
/// <summary>
/// Visit dist tensor type.
/// </summary>
/// <param name="type">dist tensor type.</param>
/// <param name="context">Context.</param>
/// <returns>Result.</returns>
public virtual TResult VisitType(DistributedType type, TContext context) => DefaultVisitType(type, context);
/// <summary>
/// Default visit routine.
/// </summary>

View File

@ -57,12 +57,12 @@ public record TypePattern(Func<IRType, bool> Cond, string Reason)
public T Check<T>(T valueType, string fieldName)
where T : IRType
{
if (valueType is TensorType tensorValueType && tensorValueType.Shape.IsUnranked)
if (valueType is TensorType { Shape: { IsUnranked: true } } || valueType is DistributedType { TensorType: { Shape: { IsUnranked: true } } })
{
return valueType;
}
if (valueType == null || !MatchLeaf(valueType))
if (valueType == null || (valueType is TensorType t && !MatchLeaf(t)) || (valueType is DistributedType d && !MatchLeaf(d.TensorType)))
{
var cur = valueType is null ? "None" : CompilerServices.Print(valueType);
throw new InvalidOperationException($"{fieldName} Requrie <{Reason}>, But {cur}!");
@ -187,6 +187,7 @@ public static partial class TypePatternUtility
x => x switch
{
TensorType ttype => DataTypes.IsIntegral(ttype.DType),
DistributedType distributedType => DataTypes.IsIntegral(distributedType.TensorType.DType),
_ => false,
},
"IsIntegral");

View File

@ -13,6 +13,13 @@ using Nncase.Quantization;
namespace Nncase;
/// <summary>
/// The targets own compile options.
/// </summary>
public interface ITargetCompileOptions
{
}
/// <summary>
/// Target.
/// </summary>
@ -23,6 +30,12 @@ public interface ITarget
/// </summary>
string Kind { get; }
/// <summary>
/// create the current target's command and parser.
/// </summary>
/// <returns>command.</returns>
(System.CommandLine.Command Command, Func<System.CommandLine.Invocation.InvocationContext, System.CommandLine.Command, ITargetCompileOptions> Parser) RegisterCommandAndParser();
/// <summary>
/// Bind Quant Method And Quant Cosine With IR.
/// </summary>
@ -91,3 +104,12 @@ public interface ITarget
/// <returns>Module builder.</returns>
IModuleBuilder CreateModuleBuilder(string moduleKind, CompileOptions options);
}
public sealed class DefaultTargetCompileOptions : ITargetCompileOptions
{
public static readonly DefaultTargetCompileOptions Instance = new();
private DefaultTargetCompileOptions()
{
}
}

View File

@ -14,6 +14,21 @@ namespace Nncase;
/// </summary>
public static class LinqExtensions
{
/// <summary>
/// Get the ranges from range desc.
/// </summary>
/// <param name="stride">stride.</param>
/// <param name="start">start.</param>
/// <param name="stop">stop.</param>
/// <returns>Ranges.</returns>
public static IEnumerable<Range> Ranges(this int stride, int start, int stop)
{
for (int i = start; i < stop; i += stride)
{
yield return new Range(i, Math.Min(stop, i + stride));
}
}
/// <summary>
/// Get cartesian product.
/// </summary>
@ -31,6 +46,23 @@ public static class LinqExtensions
select accseq.Concat(new[] { item }));
}
/// <summary>
/// Get the permutation of the source.
/// </summary>
/// <typeparam name="T">Element type.</typeparam>
/// <param name="source">Source sequences.</param>
/// <returns>Permutated sequences.</returns>
public static IEnumerable<T[]> Permutate<T>(this IEnumerable<T> source)
{
return Permutation(source, Enumerable.Empty<T>());
IEnumerable<T[]> Permutation(IEnumerable<T> reminder, IEnumerable<T> prefix) =>
!reminder.Any() ? new[] { prefix.ToArray() } :
reminder.SelectMany((c, i) => Permutation(
reminder.Take(i).Concat(reminder.Skip(i + 1)).ToArray(),
prefix.Append(c)));
}
/// <summary>
/// take or default.
/// </summary>

View File

@ -4,7 +4,7 @@
<Nullable>enable</Nullable>
<ImplicitUsings>enable</ImplicitUsings>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
@ -21,6 +21,7 @@
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
<PackageReference Include="Microsoft.Extensions.Options" />
<PackageReference Include="Microsoft.Toolkit.HighPerformance" />
<PackageReference Include="System.CommandLine" />
<PackageReference Include="NetFabric.Hyperlinq" />
<PackageReference Include="System.Reactive" />
<PackageReference Include="GiGraph.Dot" />

View File

@ -12,35 +12,43 @@ using Nncase.TIR;
namespace Nncase.Passes.Mutators;
/// <summary>
/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store. Also remove Block to ensure that the flattened TIR can not be scheduled again.
/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store.
/// </summary>
public sealed class FlattenBuffer : ExprRewriter
{
/// <inheritdoc/>
protected override Expr RewriteLeafBlock(Block expr)
{
if (!expr.IterVars.IsEmpty)
{
throw new InvalidOperationException("Non-opaque blocks are not allowed in FlattenBuffer. Please call pass ConvertBlocksToOpaque before.");
}
// 1. Visit the body
var predicate = expr.Predicate;
if (predicate is TensorConst { Value: { Length: 1 } t }
&& t.ToScalar<bool>())
// TODO: put the unfold block into this.
if (expr.Predicate is TensorConst tc && tc.Value.ToScalar<bool>() == true)
{
return expr.Body;
}
else
return T.Nop();
}
/// <inheritdoc/>
protected override Expr RewriteLeafCall(Call expr)
{
if (expr.Target is IR.Buffers.BufferLoad)
{
return new IfThenElse(predicate, expr.Body);
var indices = (IR.Tuple)expr[IR.Buffers.BufferLoad.Indices];
var input = (TIR.Buffer)expr[IR.Buffers.BufferLoad.Input];
return T.Load(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])));
}
else if (expr.Target is IR.Buffers.BufferStore)
{
var indices = (IR.Tuple)expr[IR.Buffers.BufferStore.Indices];
var input = (TIR.Buffer)expr[IR.Buffers.BufferStore.Input];
return T.Store(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]);
}
else if (expr.Target is IR.Buffers.MatchBuffer && expr.Arguments[0] is TIR.Buffer { MemSpan: { Start: Const or Var } })
{
// remove the all fixed match operation.
return T.Nop();
}
// Step 3. Handle allocations in reverse order
// TODO add the alloc buffers.
// for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
// const Buffer& buffer = new_block->alloc_buffers[i - 1];
// body = MakeAllocStmt(buffer, std::move(body));
// }
return expr;
}
}

View File

@ -0,0 +1,51 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System.Reactive;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.Passes;
using Nncase.TIR;
namespace Nncase.Passes.Mutators;
/// <summary>
/// remove buffer BaseMentOf/DDrOf/MmuOF.
/// </summary>
public sealed class FoldBufferSlot : ExprRewriter
{
protected internal override Expr VisitPrimFunction(TIR.PrimFunction expr, Unit context)
{
if (expr.SchedResult.IsScheduled == true)
{
return base.VisitPrimFunction(expr, context);
}
return expr;
}
protected override Expr RewriteLeafCall(Call expr)
{
if (expr.Target is IR.Buffers.BaseMentOf)
{
var locate = ((TIR.MemSpan)expr.Arguments[0]).Location;
return locate switch
{
MemoryLocation.Input => 0,
MemoryLocation.Output => 1,
MemoryLocation.Rdata => 2,
MemoryLocation.Data => 3,
_ => throw new ArgumentOutOfRangeException($"You Can't Assgin The BaseMent For {locate}!"),
};
}
else if (expr.Target is IR.Buffers.DDrOf)
{
if (expr.Arguments[0] is TIR.MemSpan buf)
{
return buf.Start;
}
}
return expr;
}
}

View File

@ -1,30 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System.Reactive;
using NetFabric.Hyperlinq;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.Passes;
namespace Nncase.Passes.Mutators;
/// <summary>
/// fold math calc operator.
/// </summary>
public sealed class FoldMathCall : ExprRewriter
{
/// <inheritdoc/>
protected override Expr RewriteLeafCall(Call expr)
{
if (expr.Target is Op op && op.GetType().Namespace is string @namespace
&& @namespace.StartsWith("Nncase.IR.Math"))
{
return expr.Arguments.AsValueEnumerable().All(x => x is Const)
? Const.FromValue(CompilerServices.Evaluate(expr))
: expr;
}
return expr;
}
}

View File

@ -50,10 +50,4 @@ public static class Mutator
/// </summary>
/// <returns>RemoveNop.</returns>
public static Func<ExprRewriter> RemoveNop() => () => new Mutators.RemoveNop();
/// <summary>
/// fold math calc operator.
/// </summary>
/// <returns>FoldMathCall.</returns>
public static Func<ExprRewriter> FoldMathCall() => () => new Mutators.FoldMathCall();
}

View File

@ -144,7 +144,10 @@ public sealed class UnRollLoopSequential : ExprRewriter
}
}
protected override Expr VisitLeafPhysicalBuffer(PhysicalBuffer expr, Unit context) => expr;
protected override Expr VisitLeafMemSpan(MemSpan expr, Unit context)
{
return expr.With(Clone(expr.Start, context), Clone(expr.Size, context));
}
protected override Expr VisitLeafVar(Var expr, Unit context)
{
@ -189,9 +192,10 @@ public sealed class UnRollLoopSequential : ExprRewriter
return CSE(expr.With(start: Clone(expr.Start, context), stop: Clone(expr.Stop, context), step: Clone(expr.Step, context)));
}
protected override Expr VisitLeafLogicalBuffer(LogicalBuffer expr, Unit context)
protected override Expr VisitLeafBuffer(TIR.Buffer expr, Unit context)
{
return expr.With(
memSpan: Clone<MemSpan>(expr.MemSpan, context),
dimensions: CloneArray(expr.Dimensions, context).Select(e => CSE(e)).ToArray(),
strides: CloneArray(expr.Strides, context));
}

View File

@ -10,52 +10,6 @@ using Nncase.TIR;
namespace Nncase.Schedule;
/// <summary>
/// the memory type.
/// </summary>
public enum MemoryLocation : byte
{
/// <summary>
/// input.
/// </summary>
Input = 0,
/// <summary>
/// output.
/// </summary>
Output = 1,
/// <summary>
/// constant data.
/// </summary>
Rdata = 2,
/// <summary>
/// compute temp data.
/// </summary>
Data = 3,
/// <summary>
/// shared data.
/// </summary>
SharedData = 4,
/// <summary>
/// l2 data.
/// </summary>
L2Data = 5,
/// <summary>
/// L1 data.
/// </summary>
L1Data = 6,
/// <summary>
/// base addr.
/// </summary>
PrivateBase = 64,
}
/// <summary>
/// the scheduler interface.
/// </summary>
@ -261,12 +215,12 @@ public sealed class SchedFunctionResult
/// <summary>
/// Gets the buffer allocation.
/// </summary>
public HashSet<TIR.PhysicalBuffer> Rdatas { get; }
public Dictionary<IR.Const, ValueRange<long>> Rdatas { get; }
/// <summary>
/// Gets or sets the data section length.
/// </summary>
public int DataUsage { get; set; }
public long DataUsage { get; set; }
/// <summary>
/// Gets or sets a value indicating whether the Scheduled status.
@ -296,8 +250,8 @@ public sealed class SchedFunctionResult
return true;
}
return EqualityComparer<HashSet<TIR.PhysicalBuffer>>.Default.Equals(Rdatas, result.Rdatas) &&
EqualityComparer<int>.Default.Equals(DataUsage, result.DataUsage);
return EqualityComparer<Dictionary<IR.Const, ValueRange<long>>>.Default.Equals(Rdatas, result.Rdatas) &&
EqualityComparer<long>.Default.Equals(DataUsage, result.DataUsage);
}
/// <inheritdoc/>

View File

@ -267,31 +267,34 @@ public record SelectedRange(int Start, int End, Padding Padding)
/// <summary>
/// buffer.
/// </summary>
public abstract class Buffer : Expr
public sealed class Buffer : Expr
{
public Buffer(string name, DataType elemType, Schedule.MemoryLocation memoryLocation, Expr[] operands)
: base(operands.ToArray())
public Buffer(string name, DataType elemType, MemSpan memSpan, Expr[] dimensions, Expr[] strides)
: base(new[] { memSpan }.Concat(dimensions).Concat(strides))
{
Name = name;
ElemType = elemType;
MemLocation = memoryLocation;
Rank = dimensions.Length;
}
public string Name { get; }
public DataType ElemType { get; }
public Schedule.MemoryLocation MemLocation { get; }
/// <summary>
/// Gets if this buffer from the constant !.
/// </summary>
public TensorConst? Const { get; init; }
/// <summary>
/// Gets rank of the tensor: number of dimensions.
/// </summary>
public abstract int Rank { get; }
public int Rank { get; }
/// <summary>
/// Gets the shape.
/// </summary>
public MemSpan MemSpan => (MemSpan)Operands[0];
/// <summary>
/// Gets the shape.
/// </summary>
public ReadOnlySpan<Expr> Dimensions => Operands[1..(1 + Rank)];
/// <summary>
/// Gets the strides.
@ -299,201 +302,23 @@ public abstract class Buffer : Expr
/// This Strides is by elements not by bytes!
/// </remarks>
/// </summary>
public abstract ReadOnlySpan<Expr> Strides { get; }
public ReadOnlySpan<Expr> Strides => Operands[(1 + Rank)..(1 + Rank + Rank)];
/// <summary>
/// Gets the shape.
/// </summary>
public abstract ReadOnlySpan<Expr> Dimensions { get; }
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context) => functor.VisitBuffer(this, context);
public Buffer With(MemSpan? memSpan = null, Expr[]? dimensions = null, Expr[]? strides = null)
=> new Buffer(Name, ElemType, memSpan ?? MemSpan, dimensions ?? Dimensions.ToArray(), strides ?? Strides.ToArray());
/// <inheritdoc/>
public override bool Equals(object? obj)
{
if (obj is not Buffer other)
if (ReferenceEquals(this, obj))
{
return false;
return true;
}
if (Const is not null && !Const.Equals(other.Const))
{
return false;
}
return string.Equals(Name, other.Name, StringComparison.Ordinal) &&
ElemType.Equals(other.ElemType) &&
MemLocation.Equals(other.MemLocation) &&
Rank.Equals(other.Rank) &&
base.Equals(obj);
return obj is TIR.Buffer other && GetHashCode() == other.GetHashCode() && Name == other.Name && ElemType == other.ElemType && Rank == other.Rank && Operands.SequenceEqual(other.Operands);
}
}
/// <summary>
/// the logical buffer.
/// </summary>
public sealed class LogicalBuffer : Buffer
{
/// <summary>
/// Initializes a new instance of the <see cref="LogicalBuffer"/> class.
/// create from the IRType.
/// </summary>
/// <param name="name">the name.</param>
/// <param name="location">the location.</param>
/// <param name="elemType">prim type.</param>
/// <param name="dimensions">the shape.</param>
/// <param name="strides">the strides.</param>
public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan<Expr> dimensions, ReadOnlySpan<Expr> strides)
: base(name, elemType, location, ArrayUtility.Concat(dimensions, strides))
{
Rank = dimensions.Length;
}
/// <summary>
/// Initializes a new instance of the <see cref="LogicalBuffer"/> class.
/// <see cref="LogicalBuffer"/>.
/// </summary>
public LogicalBuffer(string name, Schedule.MemoryLocation location, TensorConst tensor)
: this(name, tensor.Value.ElementType, location, ArrayUtility.ToExprArray(tensor.Value.Dimensions), ArrayUtility.ToExprArray(tensor.Value.Strides))
{
Const = tensor;
}
/// <summary>
/// Initializes a new instance of the <see cref="LogicalBuffer"/> class.
/// <seealso cref="LogicalBuffer"/>
/// </summary>
public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan<Expr> dimensions)
: this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions))
{
}
/// <summary>
/// Gets get the total length.
/// </summary>
public Expr Length => TensorUtilities.GetProduct(Dimensions);
/// <summary>
/// Gets the shape.
/// </summary>
public override ReadOnlySpan<Expr> Dimensions => Operands[0..Rank];
/// <summary>
/// Gets the strides.
/// </summary>
public override ReadOnlySpan<Expr> Strides => Operands[Rank..];
/// <inheritdoc/>
public override int Rank { get; }
/// <inheritdoc/>
public override string ToString()
{
return $"LogicalBuffer({Name}, {ElemType}, {nameof(MemLocation)})";
}
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitLogicalBuffer(this, context);
public LogicalBuffer With(string? name = null, DataType? elemType = null, Schedule.MemoryLocation? location = null, Expr[]? dimensions = null, Expr[]? strides = null)
=> new LogicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? Dimensions, strides ?? Strides) { Const = Const };
}
/// <summary>
/// the physical buffer.
/// </summary>
public sealed class PhysicalBuffer : Buffer
{
private readonly int[] _fixedDimensions;
private readonly int[] _fixedStrides;
/// <summary>
/// Initializes a new instance of the <see cref="PhysicalBuffer"/> class.
/// ctor for physical buffer.
/// </summary>
public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan<int> dimensions, ReadOnlySpan<int> strides, int start, int size)
: base(name, elemType, location, Array.Empty<Expr>())
{
Start = start;
Size = size;
_fixedDimensions = dimensions.ToArray();
_fixedStrides = strides.ToArray();
}
/// <summary>
/// Initializes a new instance of the <see cref="PhysicalBuffer"/> class.
/// <see cref="PhysicalBuffer"/>.
/// </summary>
public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan<int> dimensions, int start, int size)
: this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions), start, size)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="PhysicalBuffer"/> class.
/// <see cref="PhysicalBuffer"/>.
/// </summary>
public PhysicalBuffer(string name, Schedule.MemoryLocation location, TensorConst tensor, int start, int size)
: this(name, tensor.Value.ElementType, location, tensor.Value.Dimensions, tensor.Value.Strides, start, size)
{
Const = tensor;
}
/// <summary>
/// Gets fixed dimensions.
/// </summary>
public ReadOnlySpan<int> FixedDimensions => _fixedDimensions;
/// <summary>
/// Gets fixed strides.
/// </summary>
public ReadOnlySpan<int> FixedStrides => _fixedStrides;
/// <summary>
/// Gets or sets start.
/// </summary>
public int Start { get; set; }
/// <summary>
/// Gets total size in bytes.
/// </summary>
public int Size { get; init; }
/// <summary>
/// Gets dimensions.
/// </summary>
public override ReadOnlySpan<Expr> Dimensions => ArrayUtility.ToExprArray(FixedDimensions);
/// <summary>
/// Gets strides.
/// </summary>
public override ReadOnlySpan<Expr> Strides => ArrayUtility.ToExprArray(FixedStrides);
/// <summary>
/// Gets shape.
/// </summary>
public Shape Shape => new Shape(FixedDimensions);
/// <inheritdoc/>
public override int Rank => FixedDimensions.Length;
/// <inheritdoc/>
public override string ToString()
{
return $"PhysicalBuffer({Name}, {ElemType}, {nameof(MemLocation)})";
}
/// <inheritdoc/>
public override bool Equals(object? obj)
{
return base.Equals(obj) && obj is PhysicalBuffer other &&
FixedDimensions.SequenceEqual(other.FixedDimensions) &&
FixedStrides.SequenceEqual(other.FixedStrides);
}
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitPhysicalBuffer(this, context);
public PhysicalBuffer With(string? name = null, DataType? elemType = null, Schedule.MemoryLocation? location = null, int[]? dimensions = null, int[]? strides = null, int? start = null, int? size = null)
=> new PhysicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? FixedDimensions, strides ?? FixedStrides, start ?? Start, size ?? Size) { Const = Const };
protected override int GetHashCodeCore() => HashCode.Combine(Name, ElemType, Rank, base.GetHashCodeCore());
}

View File

@ -1,40 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
using Nncase.Utilities;
namespace Nncase.TIR;
/// <summary>
/// Buffer load node.
/// </summary>
public sealed class BufferLoad : Expr
{
public BufferLoad(PhysicalBuffer buffer, ReadOnlySpan<Expr> indices)
: base(ArrayUtility.Concat(buffer, indices))
{
}
/// <summary>
/// Gets the buffer to be loaded.
/// </summary>
public PhysicalBuffer Buffer => (PhysicalBuffer)Operands[0];
/// <summary>
/// Gets the buffer indices.
/// </summary>
public ReadOnlySpan<Expr> Indices => Operands.Slice(1);
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitBufferLoad(this, context);
public BufferLoad With(PhysicalBuffer? buffer = null, Expr[]? indices = null)
=> new BufferLoad(buffer ?? Buffer, indices ?? Indices);
}

View File

@ -1,47 +0,0 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
namespace Nncase.TIR;
/// <summary>
/// Buffer store node.
/// </summary>
public sealed class BufferStore : Expr
{
private readonly int _indicesCount;
public BufferStore(PhysicalBuffer buffer, ReadOnlySpan<Expr> indices, Expr value)
: base(new Expr[] { buffer }.Concat(indices.ToArray()).Append(value).ToArray())
{
_indicesCount = indices.Length;
}
/// <summary>
/// Gets the buffer.
/// </summary>
public PhysicalBuffer Buffer => (PhysicalBuffer)Operands[0];
/// <summary>
/// Gets the value we to be stored.
/// </summary>
public ReadOnlySpan<Expr> Indices => Operands[1.._indicesCount];
/// <summary>
/// Gets the indices location to be stored.
/// </summary>
public Expr Value => Operands[_indicesCount + 1];
/// <inheritdoc/>
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
=> functor.VisitBufferStore(this, context);
public BufferStore With(PhysicalBuffer? buffer = null, Expr[]? indices = null, Expr? value = null)
=> new BufferStore(buffer ?? Buffer, indices ?? Indices, value ?? Value);
}

Some files were not shown because too many files have changed in this diff Show More