mirror of https://github.com/kendryte/nncase.git
Feature/cpu (#1019)
* add layernorm * pass reduce * add comment * add layer norm test * fix layernorm * fix layernorm * add demo2 * fix build / add view * update layernorm * support layernorm of llama * fix build * add demo2 * pass ym * pass demo2 * fix onnx external data importer * fuse MHA of llama * add more cpu kernels * update MHA fusion * reorder MHA weights * add demo3 * add demo3 compute statge 1 * fix build * fix __tdma_all_sync_apply * add to v35 * dump const * update demo3 golden * compiled * support multiple output compare * fix MHA kernel * resplit v2 * fix MHA kernel * to v26 * push * fix double free * fix mha kernel * fix V35 * fix all * support rmsnrom * fix v22 * fix v22 v10 * pass v28 * pass v43 * remove dump * add other part * pass all llama65b decoder layers * pass gather * open 32 threads for demo4 * update binary/unary with external op * fix codes using stdlib * update kernel inputs * Fix gather * refactor cpu cmodel * update demo head * pass graph to tir * fix head main * pass norm case * update demo names * add xpu source gen * Fix head kernel segment fault * fix cost evaluator * Fix head kernel cos similarity * decoder layer pass input layernorm * Add uanry demo * pass v30 of decoder layer * fix softmax * Enable ImmOutput * fix malloc * remove debug macro * Add ImmOut * fix tdma store * refactor cpu runtime * refactor method table * fix cpu test * refactor auto distributed * update cpu test * fix rdata * update cpu test with rdata * fix typeinfer * add XPU Op layernorm * update layernorm cost * fix cost evaluator * fix layernorm * add partial resplit * add rvv matmul * add codegen of cpu gather * add concat/slice codegen * merge * add codegen of cpu softmax * update slice cpu case * fix slice * fix cpu concat * fix cpu concat * Apply code-format changes * fix build * Apply code-format changes * add codegen of transpose * add reshape * pass reshape2 * update stackvm * merge * fix build * update compile * fix to slicing * fix negative axis * fix matmul evaluator * add NormAxis * fix ToSlice * fix matmul * add GatherReduceScatter * fix ToSlice * refactor auto dist * fix boxing partial to slice codegen * softmax support split on axis * add conv2d cpu kernel * disable outter split on inner splited axis * fix binary distributedtype infer * fixGetPartialCandidateNDSBPs * pass cpu conv2d * support dilated conv2d * add mha pattern * add combine reshape transpose * fix mha fusion/ add rules * fix rdata map dispose * add xpu reduce arg * Apply code-format changes * add VAE fusion * fuse VAE * support xpu instance norm * Apply code-format changes * add reduce arg * Apply code-format changes * fix to tir keep vars order * Apply code-format changes * add XPU resize * Apply code-format changes * fix resize cpu kernel op * fix Resize * Update layernom op for test * Apply code-format changes * fix conv2d kernel * fix boxing with reshape * fix build * fix pytest compare * add gelu kernels * add xpu cast * fix swish type infer * support xpu expand * Update layernorm rvv codes * fix binary broadcast with distributed broadcast * support multi outputs * fix single output * fix new linked section * fuse Unet * add cos dump * fix build * speed up onnx external data load * add typeinfer case for binary/matmul * move matmul rvv to kernels * fix conv2d kernel * fix Unet Fusion * optimize dynamic onnx * change fusion counter * fix conv2d if split is partial * split conv to conv+bias+clamp, and add xpu clamp * update fusion merger * fix slice with negative axis * llama-4-decoder pass (x86/rv64) * text encoder/vae decoder pass (x86/rv64) * fix conan config * Fix cpu/test compile * fix cmake config * fix synax err * fix synax err * normallize axes of slice * disable module cpu on windows * donot split softmax on axis * add softmax kernel test * add rvv instance norm * clean modules dir * add rvv clamp * Clean modules dir * fix match result * Apply code-format changes * Clean Tests * fix csproj * fix buffer schedule * Add unet pytest * add target's commands * fix buffer and memspan hashcode and equals * fix unitest * fix unittest * fix test_cli output dump * fix command line * fix type infer * fix format * fix all test * Apply code-format changes * fix merge * Apply code-format changes * fix merge * fix runtime build * fix kernel test build * Apply code-format changes * fix use mean * Apply code-format changes * optimize dot dump * fix merge * fix output when test cli --------- Co-authored-by: xhuohai <xhuohai@users.noreply.github.com> Co-authored-by: 郑启航 <597323109@qq.com> Co-authored-by: zhen8838 <zhen8838@users.noreply.github.com> Co-authored-by: lerenhua <2532375005@qq.com> Co-authored-by: liuzhiming <liuzhiming@canaan-creative.com> Co-authored-by: liuzm6217-jianan <liuzm6217-jianan@users.noreply.github.com>pull/1121/head
parent
21eccd21b9
commit
338ba1070d
|
@ -68,7 +68,7 @@ artifacts/
|
|||
*.pidb
|
||||
*.svclog
|
||||
*.scc
|
||||
|
||||
*.bin
|
||||
# Chutzpah Test files
|
||||
_Chutzpah*
|
||||
|
||||
|
@ -306,4 +306,4 @@ cmake-build-*
|
|||
*gmodel_dump_dir*
|
||||
*.ipynb_checkpoints*
|
||||
# Auto generated files
|
||||
# generated/
|
||||
# generated/
|
||||
|
|
|
@ -12,7 +12,6 @@ endif()
|
|||
|
||||
if(NOT DEFINED NNCASE_VERSION_SUFFIX)
|
||||
find_package (Git)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${GIT_EXECUTABLE} describe --always --dirty --tag
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||
|
@ -274,5 +273,5 @@ if(BUILD_TESTING)
|
|||
endif()
|
||||
|
||||
# Modules
|
||||
#add_subdirectory(modules/k210)
|
||||
|
||||
#add_subdirectory(modules/vulkan)
|
||||
|
|
|
@ -49,8 +49,9 @@
|
|||
<PackageVersion Include="OrtKISharp" Version="0.0.2" />
|
||||
<PackageVersion Include="RazorLight" Version="2.3.0" />
|
||||
<PackageVersion Include="Singulink.Collections.Weak" Version="1.0.2" />
|
||||
<PackageVersion Include="StyleCop.Analyzers" Version="1.2.0-beta.507" />
|
||||
<PackageVersion Include="System.CommandLine.Hosting" Version="0.3.0-alpha.21216.1" />
|
||||
<PackageVersion Include="StyleCop.Analyzers" Version="1.2.0-beta.435" />
|
||||
<PackageVersion Include="System.CommandLine.Hosting" Version="0.4.0-alpha.22272.1" />
|
||||
<PackageVersion Include="System.CommandLine" Version="2.0.0-beta4.22272.1" />
|
||||
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
|
||||
<PackageVersion Include="System.Reactive" Version="5.0.0" />
|
||||
<PackageVersion Include="Tomlyn.Extensions.Configuration" Version="1.0.5" />
|
||||
|
|
|
@ -1,295 +0,0 @@
|
|||
{
|
||||
"version": 2,
|
||||
"dependencies": {
|
||||
"net7.0": {
|
||||
"StyleCop.Analyzers": {
|
||||
"type": "Direct",
|
||||
"requested": "[1.2.0-beta.435, )",
|
||||
"resolved": "1.2.0-beta.435",
|
||||
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
|
||||
"dependencies": {
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.435"
|
||||
}
|
||||
},
|
||||
"Google.OrTools.runtime.linux-arm64": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.4.1874",
|
||||
"contentHash": "Z46ndZcZa2Lt5b76xU9kxVYbPLg/LfuMufhUVsu3Qo3L7Bibf7WXd9j7RRldjnuv8RIHWTqb0b+2FwwMxs0c5A=="
|
||||
},
|
||||
"Google.OrTools.runtime.linux-x64": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.4.1874",
|
||||
"contentHash": "zGeDb8FuvP9HXjrsU7krVXtSDFpR+DUGNEsH51k94jL9tzf2vWYI8+WUBRHZ/cGe50dpLr+vIjfcNo3gFyOpkQ=="
|
||||
},
|
||||
"Google.OrTools.runtime.osx-arm64": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.4.1874",
|
||||
"contentHash": "Wo0ZfDaH6DhiQw0jZm4HWJm/oPGPpWNwOLUz+EYaoH3MLtocSxItHGQj/Ta3HyhXnYNOv+TliAH8L+8RCXu/2w=="
|
||||
},
|
||||
"Google.OrTools.runtime.osx-x64": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.4.1874",
|
||||
"contentHash": "IAfGgKR1og6vU87axK1d37Ak/4jy8B4NMoElovG/KZc/2UY+cJEAQDA709UMegtI4lBhuxTWFNUiHQYmRIB9yQ=="
|
||||
},
|
||||
"Google.OrTools.runtime.win-x64": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.4.1874",
|
||||
"contentHash": "fUs5qDnZA6itygolcX6nPuachQkY9CVvQbakIzIiRAWKcaj8umQAbFdGwbkyzp3qp34BKW5mtPVsmMyfQBBjOQ=="
|
||||
},
|
||||
"libortki": {
|
||||
"type": "Transitive",
|
||||
"resolved": "0.0.2",
|
||||
"contentHash": "svfuG5mxGY/QC/5DVheHOCELmdSP90RtxQ73j23KarPXZ9ZXW+7v1l5J77hGDyQbEh1BGrnGgKBlyn76RauGHg==",
|
||||
"dependencies": {
|
||||
"libortki-linux": "0.0.2",
|
||||
"libortki-osx": "0.0.2",
|
||||
"libortki-osx-arm64": "0.0.2",
|
||||
"libortki-win": "0.0.2"
|
||||
}
|
||||
},
|
||||
"libortki-linux": {
|
||||
"type": "Transitive",
|
||||
"resolved": "0.0.2",
|
||||
"contentHash": "b04LWD4lgGy60tys3hPFhnUpgWDM6dN5r1PI7GOcPj8VupXCaI70LKNQ5/5twbDE6rkowOGanVTw0S2wBGBqBQ=="
|
||||
},
|
||||
"libortki-osx": {
|
||||
"type": "Transitive",
|
||||
"resolved": "0.0.2",
|
||||
"contentHash": "O6Q9GLULkDkZEPAZJVKLPH0ROXGVOE7BxuddgOcHNK2oiTEM7wIRnzp2OIlYgLpaOLyxJMisbGOhtWgdzt2Wng=="
|
||||
},
|
||||
"libortki-osx-arm64": {
|
||||
"type": "Transitive",
|
||||
"resolved": "0.0.2",
|
||||
"contentHash": "4Qn2dirJmRicnUG945oWpq7HVGwgqCKKxYPMISv/MRvmpZBbXrZ1cVvRaF8WwTu4XXgfKTa1sLv+i8zLifUMeQ=="
|
||||
},
|
||||
"libortki-win": {
|
||||
"type": "Transitive",
|
||||
"resolved": "0.0.2",
|
||||
"contentHash": "HAoROgAKn8XBun11X43HZuspKlo5JGy8/OYw5IUPo7FVh5TCaPrLjGmyGYYZ2dqLlv31yv/b6s254PIRGn95cA=="
|
||||
},
|
||||
"Microsoft.Extensions.Configuration.Abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "6.0.0",
|
||||
"contentHash": "qWzV9o+ZRWq+pGm+1dF+R7qTgTYoXvbyowRoBxQJGfqTpqDun2eteerjRQhq5PQ/14S+lqto3Ft4gYaRyl4rdQ==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Primitives": "6.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "6.0.0",
|
||||
"contentHash": "xlzi2IYREJH3/m6+lUrQlujzX8wDitm4QGnUu6kUXTQAWPuZY8i+ticFJbzfqaetLA6KR/rO6Ew/HuYD+bxifg=="
|
||||
},
|
||||
"Microsoft.Extensions.FileProviders.Abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "6.0.0",
|
||||
"contentHash": "0pd4/fho0gC12rQswaGQxbU34jOS1TPS8lZPpkFCH68ppQjHNHYle9iRuHeev1LhrJ94YPvzcRd8UmIuFk23Qw==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Primitives": "6.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.Primitives": {
|
||||
"type": "Transitive",
|
||||
"resolved": "6.0.0",
|
||||
"contentHash": "9+PnzmQFfEFNR9J2aDTfJGGupShHjOuGw4VUv+JB044biSHrnmCIMD+mJHmb2H7YryrfBEXDurxQ47gJZdCKNQ==",
|
||||
"dependencies": {
|
||||
"System.Runtime.CompilerServices.Unsafe": "6.0.0"
|
||||
}
|
||||
},
|
||||
"NetFabric.Hyperlinq.Abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.3.0",
|
||||
"contentHash": "WXnEcGwmXfa8gW9N2MlcaPNUzM3NLMwnAhacbtH554F8YcoXbIkTB+uGa1Aa+9gyb/9JZgYVHnmADgJUKP52nA=="
|
||||
},
|
||||
"StyleCop.Analyzers.Unstable": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.2.0.435",
|
||||
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
|
||||
},
|
||||
"System.Buffers": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.5.1",
|
||||
"contentHash": "Rw7ijyl1qqRS0YQD/WycNst8hUUMgrMH4FCn1nNm27M4VxchZ1js3fVjQaANHO5f3sN4isvP4a+Met9Y4YomAg=="
|
||||
},
|
||||
"System.Runtime.CompilerServices.Unsafe": {
|
||||
"type": "Transitive",
|
||||
"resolved": "6.0.0",
|
||||
"contentHash": "/iUeP3tq1S0XdNNoMz5C9twLSrM/TH+qElHkXWaPvuNOt+99G75NrV0OS2EqHx5wMN7popYjpc8oTjC1y16DLg=="
|
||||
},
|
||||
"nncase.codegen": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"Extension.Mathematics": "[1.2.12, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.IO": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.core": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"DryIoc.dll": "[5.3.1, )",
|
||||
"GiGraph.Dot": "[2.0.0, )",
|
||||
"Microsoft.Extensions.Hosting.Abstractions": "[6.0.0, )",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "[6.0.0, )",
|
||||
"Microsoft.Extensions.Options": "[6.0.0, )",
|
||||
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
|
||||
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
|
||||
"System.Reactive": "[5.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.diagnostics": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"Nncase.Core": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.egraph": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"GiGraph.Dot": "[2.0.0, )",
|
||||
"Google.OrTools": "[9.4.1874, )",
|
||||
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.Evaluator": "[1.0.0, )",
|
||||
"Singulink.Collections.Weak": "[1.0.2, )"
|
||||
}
|
||||
},
|
||||
"nncase.evaluator": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"OrtKISharp": "[0.0.2, )"
|
||||
}
|
||||
},
|
||||
"nncase.graph": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.Evaluator": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.io": {
|
||||
"type": "Project"
|
||||
},
|
||||
"nncase.modules.stackvm": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"Nncase.CodeGen": "[1.0.0, )",
|
||||
"Nncase.Passes": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"nncase.passes": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"Nncase.Core": "[1.0.0, )",
|
||||
"Nncase.EGraph": "[1.0.0, )",
|
||||
"Nncase.Evaluator": "[1.0.0, )",
|
||||
"Nncase.Graph": "[1.0.0, )"
|
||||
}
|
||||
},
|
||||
"DryIoc.dll": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[5.3.1, )",
|
||||
"resolved": "5.3.1",
|
||||
"contentHash": "E3zclUh2CIBks1t2uBD1k18pyGFJ1YSKCrbCDbB7qCdl2RAB+k68AyDpjeplhF1ot2XPV82AgyCWBXMf0ggL1g=="
|
||||
},
|
||||
"Extension.Mathematics": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.2.12, )",
|
||||
"resolved": "1.2.12",
|
||||
"contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg=="
|
||||
},
|
||||
"GiGraph.Dot": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[2.0.0, )",
|
||||
"resolved": "2.0.0",
|
||||
"contentHash": "ThvS2mQVveSkTMUm04tMbRYzu1XFPV8xBHISrUMp02APjhv9IRbLu3v3upTPCywORx2Ds/c6AqEUL1WU6kPfuQ=="
|
||||
},
|
||||
"Google.OrTools": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[9.4.1874, )",
|
||||
"resolved": "9.4.1874",
|
||||
"contentHash": "jqRoI+pYlym+fhoU25u+13oti5h+772bllQ9zDitTVMclDXVTiG6pxzvmYO74wnADBMdpb2SQlgiNQxoNk5dlA==",
|
||||
"dependencies": {
|
||||
"Google.OrTools.runtime.linux-arm64": "9.4.1874",
|
||||
"Google.OrTools.runtime.linux-x64": "9.4.1874",
|
||||
"Google.OrTools.runtime.osx-arm64": "9.4.1874",
|
||||
"Google.OrTools.runtime.osx-x64": "9.4.1874",
|
||||
"Google.OrTools.runtime.win-x64": "9.4.1874",
|
||||
"Google.Protobuf": "3.19.4"
|
||||
}
|
||||
},
|
||||
"Google.Protobuf": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[3.19.4, )",
|
||||
"resolved": "3.19.4",
|
||||
"contentHash": "fd07/ykL4O4FhqrZIELm5lmiyOHfdPg9+o+hWr6tcfRdS7tHXnImg/2wtogLzlW2eEmr0J7j6ZrZvaWOLiJbxQ=="
|
||||
},
|
||||
"Microsoft.Extensions.Hosting.Abstractions": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[6.0.0, )",
|
||||
"resolved": "6.0.0",
|
||||
"contentHash": "GcT5l2CYXL6Sa27KCSh0TixsRfADUgth+ojQSD5EkzisZxmGFh7CwzkcYuGwvmXLjr27uWRNrJ2vuuEjMhU05Q==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Configuration.Abstractions": "6.0.0",
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0",
|
||||
"Microsoft.Extensions.FileProviders.Abstractions": "6.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.Logging.Abstractions": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[6.0.0, )",
|
||||
"resolved": "6.0.0",
|
||||
"contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA=="
|
||||
},
|
||||
"Microsoft.Extensions.Options": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[6.0.0, )",
|
||||
"resolved": "6.0.0",
|
||||
"contentHash": "dzXN0+V1AyjOe2xcJ86Qbo233KHuLEY0njf/P2Kw8SfJU+d45HNS2ctJdnEnrWbM9Ye2eFgaC5Mj9otRMU6IsQ==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "6.0.0",
|
||||
"Microsoft.Extensions.Primitives": "6.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Toolkit.HighPerformance": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[7.1.1, )",
|
||||
"resolved": "7.1.1",
|
||||
"contentHash": "TRnvDpZPXO30hTOtjfLw6Y9BtTKtTpzk9lefeh4RMCaUihWrVKQR454nYH4/mMJAh+LXqfAPyk0kfkJs0Amopw=="
|
||||
},
|
||||
"NetFabric.Hyperlinq": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[3.0.0-beta48, )",
|
||||
"resolved": "3.0.0-beta48",
|
||||
"contentHash": "oYUhXvxNS8bBJWqNkvx5g8y0P/0LtyqS2pN0w4OWjVDNWEpLbdbvPy9w/9z1n2PrqIjX3jxUsEnoCmxxGnI3gw==",
|
||||
"dependencies": {
|
||||
"NetFabric.Hyperlinq.Abstractions": "1.3.0",
|
||||
"System.Buffers": "4.5.1",
|
||||
"System.Runtime.CompilerServices.Unsafe": "5.0.0"
|
||||
}
|
||||
},
|
||||
"OrtKISharp": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[0.0.2, )",
|
||||
"resolved": "0.0.2",
|
||||
"contentHash": "q8j0yR5836Zhv9WB9BFkQt1UaEFyibq8bqJcTiULlILF6/sz8z7Wy2N8sgYdDKsdW25zncIz7j6IDbKM5ynePg==",
|
||||
"dependencies": {
|
||||
"libortki": "0.0.2"
|
||||
}
|
||||
},
|
||||
"Singulink.Collections.Weak": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[1.0.2, )",
|
||||
"resolved": "1.0.2",
|
||||
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
|
||||
},
|
||||
"System.Reactive": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[5.0.0, )",
|
||||
"resolved": "5.0.0",
|
||||
"contentHash": "erBZjkQHWL9jpasCE/0qKAryzVBJFxGHVBAvgRN1bzM0q2s1S4oYREEEL0Vb+1kA/6BKb5FjUZMp5VXmy+gzkQ=="
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/18 下午5:04:31 +08:00. */
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM +00:00. */
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
@ -59,7 +59,7 @@ internal partial class CodeGenVisitor
|
|||
Emitter.T.L2Normalization();
|
||||
break;
|
||||
case IR.NN.LayerNorm top:
|
||||
Emitter.T.LayerNorm(top.Axis, top.Epsilon);
|
||||
Emitter.T.LayerNorm(top.Axis, top.Epsilon, top.UseMean);
|
||||
break;
|
||||
case IR.NN.LeakyRelu top:
|
||||
Emitter.T.LeakyRelu();
|
||||
|
@ -176,7 +176,7 @@ internal partial class CodeGenVisitor
|
|||
Emitter.T.Cast(top.NewType, top.CastMode);
|
||||
break;
|
||||
case IR.Tensors.Concat top:
|
||||
Emitter.T.Concat();
|
||||
Emitter.T.Concat(top.Axis);
|
||||
break;
|
||||
case IR.Tensors.ConstantOfShape top:
|
||||
Emitter.T.ConstantOfShape();
|
||||
|
@ -191,7 +191,7 @@ internal partial class CodeGenVisitor
|
|||
Emitter.T.Flatten();
|
||||
break;
|
||||
case IR.Tensors.Gather top:
|
||||
Emitter.T.Gather();
|
||||
Emitter.T.Gather(top.Axis);
|
||||
break;
|
||||
case IR.Tensors.GatherElements top:
|
||||
Emitter.T.GatherElements();
|
||||
|
@ -205,9 +205,6 @@ internal partial class CodeGenVisitor
|
|||
case IR.Tensors.IndexOf top:
|
||||
Emitter.T.IndexOf();
|
||||
break;
|
||||
case IR.Tensors.LSTM top:
|
||||
Emitter.T.LSTM(top.Direction, top.Layout, top.Activations);
|
||||
break;
|
||||
case IR.Tensors.Prod top:
|
||||
Emitter.T.Prod();
|
||||
break;
|
||||
|
@ -289,6 +286,9 @@ internal partial class CodeGenVisitor
|
|||
case IR.ShapeExpr.UnsqueezeShape top:
|
||||
Emitter.T.UnsqueezeShape();
|
||||
break;
|
||||
case IR.RNN.LSTM top:
|
||||
Emitter.T.LSTM(top.Direction, top.Layout, top.Activations);
|
||||
break;
|
||||
case IR.Random.Normal top:
|
||||
Emitter.T.Normal(top.Type);
|
||||
break;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:30 +08:00. */
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM +00:00. */
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
@ -723,10 +723,11 @@ public partial class StackVMEmitter
|
|||
}
|
||||
|
||||
///<summary>.</summary>
|
||||
public void Concat()
|
||||
public void Concat(int axis)
|
||||
{
|
||||
_emitter.Write((byte)100);
|
||||
_emitter.Write((ushort)11);
|
||||
_emitter.Write(axis);
|
||||
}
|
||||
|
||||
///<summary>.</summary>
|
||||
|
@ -841,10 +842,11 @@ public partial class StackVMEmitter
|
|||
}
|
||||
|
||||
///<summary>.</summary>
|
||||
public void Gather()
|
||||
public void Gather(int axis)
|
||||
{
|
||||
_emitter.Write((byte)100);
|
||||
_emitter.Write((ushort)27);
|
||||
_emitter.Write(axis);
|
||||
}
|
||||
|
||||
///<summary>.</summary>
|
||||
|
@ -925,12 +927,13 @@ public partial class StackVMEmitter
|
|||
}
|
||||
|
||||
///<summary>.</summary>
|
||||
public void LayerNorm(int axis, float epsilon)
|
||||
public void LayerNorm(int axis, float epsilon, bool useMean)
|
||||
{
|
||||
_emitter.Write((byte)100);
|
||||
_emitter.Write((ushort)39);
|
||||
_emitter.Write(axis);
|
||||
_emitter.Write(epsilon);
|
||||
_emitter.Write(useMean);
|
||||
}
|
||||
|
||||
///<summary>.</summary>
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.CommandLine.Invocation;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
@ -21,13 +22,15 @@ namespace Nncase.Targets;
|
|||
/// </summary>
|
||||
public class CPUTarget : ITarget
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets kind.
|
||||
/// </summary>
|
||||
public static readonly string Kind = "cpu";
|
||||
public const string Kind = "cpu";
|
||||
|
||||
string ITarget.Kind => Kind;
|
||||
|
||||
public (System.CommandLine.Command Command, Func<InvocationContext, System.CommandLine.Command, ITargetCompileOptions> Parser) RegisterCommandAndParser()
|
||||
{
|
||||
return (new System.CommandLine.Command(Kind), (_, _) => DefaultTargetCompileOptions.Instance);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void ParseTargetDependentOptions(IConfigurationSection configure)
|
||||
{
|
||||
|
|
|
@ -4,11 +4,11 @@
|
|||
"net7.0": {
|
||||
"StyleCop.Analyzers": {
|
||||
"type": "Direct",
|
||||
"requested": "[1.2.0-beta.507, )",
|
||||
"resolved": "1.2.0-beta.507",
|
||||
"contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
|
||||
"requested": "[1.2.0-beta.435, )",
|
||||
"resolved": "1.2.0-beta.435",
|
||||
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
|
||||
"dependencies": {
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.507"
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.435"
|
||||
}
|
||||
},
|
||||
"Google.OrTools.runtime.linux-arm64": {
|
||||
|
@ -103,8 +103,8 @@
|
|||
},
|
||||
"StyleCop.Analyzers.Unstable": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.2.0.507",
|
||||
"contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
|
||||
"resolved": "1.2.0.435",
|
||||
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
|
||||
},
|
||||
"System.Buffers": {
|
||||
"type": "Transitive",
|
||||
|
@ -134,6 +134,7 @@
|
|||
"Microsoft.Extensions.Options": "[6.0.0, )",
|
||||
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
|
||||
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
|
||||
"System.CommandLine": "[2.0.0-beta4.22272.1, )",
|
||||
"System.Reactive": "[5.0.0, )"
|
||||
}
|
||||
},
|
||||
|
@ -271,6 +272,12 @@
|
|||
"resolved": "1.0.2",
|
||||
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
|
||||
},
|
||||
"System.CommandLine": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[2.0.0-beta4.22272.1, )",
|
||||
"resolved": "2.0.0-beta4.22272.1",
|
||||
"contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
|
||||
},
|
||||
"System.Reactive": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[5.0.0, )",
|
||||
|
|
|
@ -44,8 +44,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.IO", "src\Nncase.IO\
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Schedule", "src\Nncase.Schedule\Nncase.Schedule.csproj", "{8E0E0672-0F96-4EF1-BDCD-D31F96A3DF73}"
|
||||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "targets", "targets", "{A2590531-71C5-4326-88DD-6A9DB2EF0A2B}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Targets", "src\Nncase.Targets\Nncase.Targets.csproj", "{56283378-06E3-4C6E-A8BF-7BD85C92D42C}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Nncase.Simulator", "src\Nncase.Simulator\Nncase.Simulator.csproj", "{901AC17C-7B53-4B10-A2AC-EA7AEA6DC614}"
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// https://gist.github.com/asford/544323a5da7dddad2c9174490eb5ed06
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <nncase/compiler_defs.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
|
|
|
@ -298,7 +298,7 @@ class Compiler:
|
|||
|
||||
def check_target(target: str):
|
||||
def test_target(target: str):
|
||||
return target in ["cpu", "k510", "k230"]
|
||||
return target in ["cpu", "k510", "k230", "xpu"]
|
||||
|
||||
def target_exists(target: str):
|
||||
return _nncase.Target.exists(target)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
|
||||
* +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
|
||||
* +00:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -78,7 +78,7 @@ compare(runtime::stackvm::compare_op_t compare_op, value_t lhs, value_t rhs,
|
|||
kernel_context &context = default_kernel_context());
|
||||
|
||||
NNCASE_API result<value_t>
|
||||
concat(value_t input, value_t axis, value_t output = nullptr,
|
||||
concat(int32_t axis, value_t input, value_t output = nullptr,
|
||||
kernel_context &context = default_kernel_context());
|
||||
|
||||
NNCASE_API result<value_t>
|
||||
|
@ -157,7 +157,7 @@ flatten(value_t input, value_t axis, value_t output = nullptr,
|
|||
kernel_context &context = default_kernel_context());
|
||||
|
||||
NNCASE_API result<value_t>
|
||||
gather(value_t input, value_t axis, value_t index, value_t output = nullptr,
|
||||
gather(int32_t axis, value_t input, value_t index, value_t output = nullptr,
|
||||
kernel_context &context = default_kernel_context());
|
||||
|
||||
NNCASE_API result<value_t>
|
||||
|
@ -211,8 +211,8 @@ l2_normalization(value_t input, value_t output = nullptr,
|
|||
kernel_context &context = default_kernel_context());
|
||||
|
||||
NNCASE_API result<value_t>
|
||||
layer_norm(int32_t axis, float epsilon, value_t input, value_t scale,
|
||||
value_t bias, value_t output = nullptr,
|
||||
layer_norm(int32_t axis, float epsilon, bool use_mean, value_t input,
|
||||
value_t scale, value_t bias, value_t output = nullptr,
|
||||
kernel_context &context = default_kernel_context());
|
||||
|
||||
NNCASE_API result<value_t>
|
||||
|
|
|
@ -73,6 +73,7 @@ class NNCASE_API interpreter {
|
|||
|
||||
options_dict &options() noexcept;
|
||||
result<runtime_module *> find_module_by_id(size_t index) noexcept;
|
||||
result<size_t> find_id_by_module(runtime_module *module) noexcept;
|
||||
|
||||
/* V1 APIs */
|
||||
|
||||
|
|
|
@ -58,6 +58,8 @@ class NNCASE_API runtime_module {
|
|||
|
||||
result<runtime_function *> find_function_by_id(size_t index) noexcept;
|
||||
|
||||
result<size_t> find_id_by_function(runtime_function *function) noexcept;
|
||||
|
||||
protected:
|
||||
virtual result<void>
|
||||
initialize_before_functions(runtime_module_init_context &context) noexcept;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
|
||||
* +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
|
||||
* +00:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -837,6 +837,7 @@ template <> struct tensor_op_reader<tensor_function_t::compare> {
|
|||
template <> struct tensor_op_reader<tensor_function_t::concat> {
|
||||
tensor_concat_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
|
||||
tensor_concat_op_t op;
|
||||
op.axis = reader.read_unaligned<int32_t>();
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
@ -964,6 +965,7 @@ template <> struct tensor_op_reader<tensor_function_t::flatten> {
|
|||
template <> struct tensor_op_reader<tensor_function_t::gather> {
|
||||
tensor_gather_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
|
||||
tensor_gather_op_t op;
|
||||
op.axis = reader.read_unaligned<int32_t>();
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
@ -1055,6 +1057,7 @@ template <> struct tensor_op_reader<tensor_function_t::layer_norm> {
|
|||
tensor_layer_norm_op_t op;
|
||||
op.axis = reader.read_unaligned<int32_t>();
|
||||
op.epsilon = reader.read_unaligned<float>();
|
||||
op.use_mean = reader.read_unaligned<bool>();
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
|
||||
* +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
|
||||
* +00:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -190,7 +190,6 @@ enum class tensor_function_t : uint16_t {
|
|||
gather_nd = 29,
|
||||
get_item = 31,
|
||||
index_of = 36,
|
||||
lstm = 44,
|
||||
prod = 52,
|
||||
range = 55,
|
||||
rank = 57,
|
||||
|
@ -218,6 +217,7 @@ enum class tensor_function_t : uint16_t {
|
|||
squeeze_shape = 81,
|
||||
transpose_shape = 87,
|
||||
unsqueeze_shape = 93,
|
||||
lstm = 44,
|
||||
normal = 47,
|
||||
normal_like = 48,
|
||||
uniform = 90,
|
||||
|
@ -614,7 +614,9 @@ struct tensor_compare_op_t {
|
|||
compare_op_t compare_op;
|
||||
};
|
||||
|
||||
struct tensor_concat_op_t {};
|
||||
struct tensor_concat_op_t {
|
||||
int32_t axis;
|
||||
};
|
||||
|
||||
struct tensor_condition_op_t {
|
||||
bool can_fold_const_call;
|
||||
|
@ -658,7 +660,9 @@ struct tensor_fix_shape_op_t {};
|
|||
|
||||
struct tensor_flatten_op_t {};
|
||||
|
||||
struct tensor_gather_op_t {};
|
||||
struct tensor_gather_op_t {
|
||||
int32_t axis;
|
||||
};
|
||||
|
||||
struct tensor_gather_elements_op_t {};
|
||||
|
||||
|
@ -685,6 +689,7 @@ struct tensor_l2_normalization_op_t {};
|
|||
struct tensor_layer_norm_op_t {
|
||||
int32_t axis;
|
||||
float epsilon;
|
||||
bool use_mean;
|
||||
};
|
||||
|
||||
struct tensor_leaky_relu_op_t {};
|
||||
|
@ -964,8 +969,6 @@ inline std::string to_string(tensor_function_t tensor_funct) {
|
|||
return "get_item";
|
||||
case tensor_function_t::index_of:
|
||||
return "index_of";
|
||||
case tensor_function_t::lstm:
|
||||
return "lstm";
|
||||
case tensor_function_t::prod:
|
||||
return "prod";
|
||||
case tensor_function_t::range:
|
||||
|
@ -1020,6 +1023,8 @@ inline std::string to_string(tensor_function_t tensor_funct) {
|
|||
return "transpose_shape";
|
||||
case tensor_function_t::unsqueeze_shape:
|
||||
return "unsqueeze_shape";
|
||||
case tensor_function_t::lstm:
|
||||
return "lstm";
|
||||
case tensor_function_t::normal:
|
||||
return "normal";
|
||||
case tensor_function_t::normal_like:
|
||||
|
|
|
@ -47,8 +47,9 @@ result<value_t> nncase::kernels::stackvm::batch_normalization(
|
|||
}
|
||||
|
||||
result<value_t> nncase::kernels::stackvm::layer_norm(
|
||||
int32_t axis, float epsilon, value_t input, value_t scale, value_t bias,
|
||||
value_t output, [[maybe_unused]] kernel_context &context) {
|
||||
int32_t axis, float epsilon, [[maybe_unused]] bool use_mean, value_t input,
|
||||
value_t scale, value_t bias, value_t output,
|
||||
[[maybe_unused]] kernel_context &context) {
|
||||
try_input(input_mem, input);
|
||||
try_input(scale_mem, scale);
|
||||
try_input(bias_mem, bias);
|
||||
|
@ -124,7 +125,7 @@ nncase::kernels::stackvm::clamp(value_t input, value_t min, value_t max,
|
|||
KERNEL_FINISH;
|
||||
}
|
||||
|
||||
result<value_t> nncase::kernels::stackvm::concat(value_t input, value_t axis,
|
||||
result<value_t> nncase::kernels::stackvm::concat(int32_t axis, value_t input,
|
||||
value_t output,
|
||||
kernel_context &context) {
|
||||
try_tuple_input(inputs_mem, input);
|
||||
|
@ -132,7 +133,7 @@ result<value_t> nncase::kernels::stackvm::concat(value_t input, value_t axis,
|
|||
try_var(strides, get_strides(input_tuple));
|
||||
try_tuple_field0(input0, input_tuple);
|
||||
auto dtype = input0->dtype();
|
||||
try_positive_axis_with_rank(axis_value, axis, input0->shape().size());
|
||||
auto axis_value = positive_index(axis, input0->shape().size());
|
||||
auto out_shape = concat_infer_shape(shapes, axis_value);
|
||||
try_output(out_mem, output, dtype, out_shape);
|
||||
auto concat_dims = dims_t();
|
||||
|
@ -293,14 +294,15 @@ nncase::kernels::stackvm::flatten(value_t input, value_t axis, value_t output,
|
|||
KERNEL_FINISH;
|
||||
}
|
||||
|
||||
result<value_t> nncase::kernels::stackvm::gather(value_t input, value_t axis,
|
||||
result<value_t> nncase::kernels::stackvm::gather(int32_t axis, value_t input,
|
||||
value_t index, value_t output,
|
||||
kernel_context &context) {
|
||||
try_input(input_mem, input);
|
||||
try_input(index_mem, index);
|
||||
auto dtype = input_tensor->dtype();
|
||||
try_var(typecode, to_typecode(dtype));
|
||||
try_positive_axis(axis_value, axis, input_tensor);
|
||||
// try_positive_axis(axis_value, axis, input_tensor);
|
||||
auto axis_value = positive_index(axis, input_tensor->shape().size());
|
||||
auto out_shape = gather_infer_shape(input_tensor->shape(),
|
||||
index_tensor->shape(), axis_value);
|
||||
try_output(out_mem, output, dtype, out_shape);
|
||||
|
|
|
@ -54,6 +54,7 @@ else()
|
|||
add_library(simulator OBJECT ${SRCS})
|
||||
target_include_directories(simulator PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
||||
target_link_libraries(simulator PUBLIC gsl::gsl-lite)
|
||||
target_link_libraries(simulator PUBLIC fmt::fmt)
|
||||
target_link_libraries(simulator PRIVATE kernels)
|
||||
target_compile_definitions(simulator PUBLIC -DNNCASE_DLL -DNNCASE_SIMULATOR)
|
||||
if (DEFAULT_BUILTIN_RUNTIMES)
|
||||
|
|
|
@ -246,6 +246,17 @@ result<runtime_module *> interpreter::find_module_by_id(size_t index) noexcept {
|
|||
return ok(modules_[index].get());
|
||||
}
|
||||
|
||||
result<size_t> interpreter::find_id_by_module(runtime_module *module) noexcept {
|
||||
auto it = std::find_if(modules_.begin(), modules_.end(),
|
||||
[&module](const std::unique_ptr<runtime_module> &p) {
|
||||
return p.get() == module;
|
||||
});
|
||||
if (it == modules_.end()) {
|
||||
return err(std::errc::result_out_of_range);
|
||||
}
|
||||
return ok((it - modules_.begin()));
|
||||
}
|
||||
|
||||
options_dict &interpreter::options() noexcept { return options_; }
|
||||
|
||||
result<runtime_function *> interpreter::entry_function() noexcept {
|
||||
|
|
|
@ -189,6 +189,19 @@ runtime_module::find_function_by_id(size_t index) noexcept {
|
|||
return ok(functions_[index].get());
|
||||
}
|
||||
|
||||
result<size_t>
|
||||
runtime_module::find_id_by_function(runtime_function *function) noexcept {
|
||||
auto it =
|
||||
std::find_if(functions_.begin(), functions_.end(),
|
||||
[&function](const std::unique_ptr<runtime_function> &p) {
|
||||
return p.get() == function;
|
||||
});
|
||||
if (it == functions_.end()) {
|
||||
return err(std::errc::result_out_of_range);
|
||||
}
|
||||
return ok((it - functions_.begin()));
|
||||
}
|
||||
|
||||
result<void> runtime_module::initialize_before_functions(
|
||||
NNCASE_UNUSED runtime_module_init_context &context) noexcept {
|
||||
return ok();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
|
||||
* +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
|
||||
* +00:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
|
||||
* +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:08 AM
|
||||
* +00:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -207,9 +207,7 @@ result<void> stackvm_runtime_function::visit(
|
|||
dump_op("concat");
|
||||
try_var(input, pop_value());
|
||||
dump_input(input);
|
||||
try_var(axis, pop_value());
|
||||
dump_input(axis);
|
||||
try_var(output, kernels::stackvm::concat(input, axis, nullptr,
|
||||
try_var(output, kernels::stackvm::concat(op.axis, input, nullptr,
|
||||
module().kernel_context()));
|
||||
dump_output(output);
|
||||
stack_.push(std::move(output));
|
||||
|
@ -491,11 +489,9 @@ result<void> stackvm_runtime_function::visit(
|
|||
dump_op("gather");
|
||||
try_var(input, pop_value());
|
||||
dump_input(input);
|
||||
try_var(axis, pop_value());
|
||||
dump_input(axis);
|
||||
try_var(index, pop_value());
|
||||
dump_input(index);
|
||||
try_var(output, kernels::stackvm::gather(input, axis, index, nullptr,
|
||||
try_var(output, kernels::stackvm::gather(op.axis, input, index, nullptr,
|
||||
module().kernel_context()));
|
||||
dump_output(output);
|
||||
stack_.push(std::move(output));
|
||||
|
@ -683,9 +679,9 @@ result<void> stackvm_runtime_function::visit(
|
|||
dump_input(scale);
|
||||
try_var(bias, pop_value());
|
||||
dump_input(bias);
|
||||
try_var(output, kernels::stackvm::layer_norm(op.axis, op.epsilon, input,
|
||||
scale, bias, nullptr,
|
||||
module().kernel_context()));
|
||||
try_var(output, kernels::stackvm::layer_norm(
|
||||
op.axis, op.epsilon, op.use_mean, input, scale, bias,
|
||||
nullptr, module().kernel_context()));
|
||||
dump_output(output);
|
||||
stack_.push(std::move(output));
|
||||
return ok();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
|
||||
* +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 9/20/2023 10:17:07 AM
|
||||
* +00:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <chrono>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <nncase/io_utils.h>
|
||||
|
@ -19,6 +20,8 @@
|
|||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
// constexpr size_t loop_count = 10;
|
||||
constexpr size_t loop_count = 1;
|
||||
|
||||
#define TRY(x) \
|
||||
if (x) \
|
||||
|
@ -34,8 +37,7 @@ result<void> write_tensor_buffer(value_t value, std::ofstream &of) {
|
|||
}
|
||||
|
||||
result<void> run_core(const std::string &kmodel_path,
|
||||
const std::vector<std::string> &input_bins,
|
||||
const std::string &output_bin) {
|
||||
const std::vector<std::string> &bins) {
|
||||
auto kmodel = read_file(kmodel_path);
|
||||
interpreter *interp = new interpreter();
|
||||
// auto dump_path =
|
||||
|
@ -47,16 +49,16 @@ result<void> run_core(const std::string &kmodel_path,
|
|||
|
||||
try_var(entry, interp->entry_function());
|
||||
|
||||
if (entry->parameters_size() != input_bins.size())
|
||||
if (entry->parameters_size() > bins.size())
|
||||
return err(std::errc::argument_list_too_long);
|
||||
/* create the input parameters tensor
|
||||
note the input tenosr must be contiguous
|
||||
*/
|
||||
std::vector<value_t> parameters;
|
||||
for (int i = 0; i < input_bins.size(); i++) {
|
||||
for (int i = 0; i < entry->parameters_size(); i++) {
|
||||
try_var(type, entry->parameter_type(i));
|
||||
try_var(ts_type, type.as<tensor_type>());
|
||||
auto input_pool = read_file(input_bins[i]);
|
||||
auto input_pool = read_file(bins[i]);
|
||||
gsl::span<gsl::byte> input_pool_span = {
|
||||
reinterpret_cast<gsl::byte *>(input_pool.data()),
|
||||
input_pool.size()};
|
||||
|
@ -66,21 +68,40 @@ result<void> run_core(const std::string &kmodel_path,
|
|||
parameters.push_back(_.impl());
|
||||
}
|
||||
|
||||
try_var(ret, entry->invoke({parameters.data(), parameters.size()}));
|
||||
double total_time = 0.0;
|
||||
for (size_t i = 0; i < loop_count; i++) {
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
try_var(ret, entry->invoke({parameters.data(), parameters.size()}));
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
total_time += (std::chrono::duration_cast<std::chrono::nanoseconds>(
|
||||
end_time - start_time)
|
||||
.count() /
|
||||
1e6);
|
||||
|
||||
std::ofstream output_stream(output_bin, std::ios::binary);
|
||||
|
||||
if (ret.is_a<tensor>()) {
|
||||
try_(write_tensor_buffer(ret, output_stream));
|
||||
} else if (ret.is_a<tuple>()) {
|
||||
try_var(tp, ret.as<tuple>());
|
||||
for (auto &&ret_v : tp->fields()) {
|
||||
try_(write_tensor_buffer(ret_v, output_stream));
|
||||
if (i == (loop_count - 1) && (entry->parameters_size() < bins.size())) {
|
||||
if (ret.is_a<tensor>()) {
|
||||
auto output_bin = bins.back();
|
||||
std::ofstream output_stream(output_bin, std::ios::binary);
|
||||
try_(write_tensor_buffer(ret, output_stream));
|
||||
output_stream.close();
|
||||
} else if (ret.is_a<tuple>()) {
|
||||
try_var(tp, ret.as<tuple>());
|
||||
auto o = 0;
|
||||
for (auto &&ret_v : tp->fields()) {
|
||||
auto output_bin = bins[entry->parameters_size() + (o++)];
|
||||
std::ofstream output_stream(output_bin, std::ios::binary);
|
||||
try_(write_tensor_buffer(ret_v, output_stream));
|
||||
output_stream.close();
|
||||
}
|
||||
} else {
|
||||
return nncase::err(std::errc::bad_message);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nncase::err(std::errc::bad_message);
|
||||
}
|
||||
output_stream.close();
|
||||
|
||||
std::cout << "interp run: " << (total_time / loop_count)
|
||||
<< " ms, fps = " << 1000 / (total_time / loop_count) << std::endl;
|
||||
|
||||
return ok();
|
||||
}
|
||||
|
||||
|
@ -92,13 +113,12 @@ result<void> run_core(const std::string &kmodel_path,
|
|||
* @return int
|
||||
*/
|
||||
int main(NNCASE_UNUSED int argc, char **argv) {
|
||||
assert(argc >= 4);
|
||||
std::vector<std::string> input_bins;
|
||||
for (int i = 2; i < argc - 1; i++) {
|
||||
input_bins.push_back(argv[i]);
|
||||
assert(argc >= 3);
|
||||
std::vector<std::string> bins;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
bins.push_back(argv[i]);
|
||||
}
|
||||
std::string kmodel_bin(argv[1]);
|
||||
std::string output_bin(argv[argc - 1]);
|
||||
run_core(kmodel_bin, input_bins, output_bin).unwrap_or_throw();
|
||||
run_core(kmodel_bin, bins).unwrap_or_throw();
|
||||
return 0;
|
||||
}
|
||||
}
|
|
@ -1,291 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.CommandLine;
|
||||
using System.CommandLine.Invocation;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
using Nncase.CodeGen;
|
||||
using Nncase.Compiler;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes;
|
||||
using Nncase.Quantization;
|
||||
|
||||
namespace Nncase.Cli.Commands;
|
||||
|
||||
internal enum QuantType
|
||||
{
|
||||
UInt8,
|
||||
Int8,
|
||||
Int16,
|
||||
}
|
||||
|
||||
internal enum DatasetFormat
|
||||
{
|
||||
Image,
|
||||
Raw,
|
||||
Pytest,
|
||||
Random,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Compile command.
|
||||
/// </summary>
|
||||
public sealed class Compile : Command
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="Compile"/> class.
|
||||
/// </summary>
|
||||
public Compile()
|
||||
: base("compile")
|
||||
{
|
||||
AddArgument(new Argument("input-file"));
|
||||
AddArgument(new Argument("output-file"));
|
||||
AddOption(new Option<string>(
|
||||
aliases: new string[] { "-t", "--target" },
|
||||
description: "target architecture, e.g. cpu, k210"));
|
||||
AddOption(new Option<string>(
|
||||
aliases: new[] { "-i", "--input-format" },
|
||||
description: "input format, e.g. tflite",
|
||||
getDefaultValue: () => "tflite"));
|
||||
AddOption(new Option<int>(
|
||||
alias: "--dump-level",
|
||||
description: $"dump ir to .il, default is {0}",
|
||||
getDefaultValue: () => 0));
|
||||
AddOption(new Option<string>(
|
||||
alias: "--dump-dir",
|
||||
description: "dump to directory, default is .",
|
||||
getDefaultValue: () => "."));
|
||||
AddOption(new Option<QuantType>(
|
||||
alias: "--quant-type",
|
||||
description: $"quant type, default is {QuantType.UInt8}",
|
||||
getDefaultValue: () => QuantType.UInt8));
|
||||
AddOption(new Option<QuantType>(
|
||||
alias: "--wquant-type",
|
||||
description: $"wquant type, default is {QuantType.UInt8}",
|
||||
getDefaultValue: () => QuantType.UInt8));
|
||||
AddOption(new Option<string>(
|
||||
alias: "--dataset",
|
||||
description: $"calibration dataset, used in post quantization, default is empty",
|
||||
getDefaultValue: () => string.Empty));
|
||||
AddOption(new Option<DatasetFormat>(
|
||||
alias: "--dataset-format",
|
||||
description: $"datset format: e.g. Image|Raw|Pytest",
|
||||
getDefaultValue: () => DatasetFormat.Raw));
|
||||
AddOption(new Option<Quantization.ModelQuantMode>(
|
||||
alias: "--model-quant-mode",
|
||||
description: $"model quant mode, default is {Quantization.ModelQuantMode.NoQuant}",
|
||||
getDefaultValue: () => Quantization.ModelQuantMode.NoQuant));
|
||||
AddOption(new Option<Quantization.CalibMethod>(
|
||||
alias: "--calib-method",
|
||||
description: $"model quant options, default is {Quantization.CalibMethod.Kld}",
|
||||
getDefaultValue: () => Quantization.CalibMethod.Kld));
|
||||
AddOption(new Option<bool>(
|
||||
alias: "--pre-process",
|
||||
description: "whether enable pre process, default is False",
|
||||
getDefaultValue: () => false));
|
||||
AddOption(new Option(
|
||||
alias: "--input-layout",
|
||||
description: "the model input data layout, default is empty. eg. NCHW/NHWC",
|
||||
getDefaultValue: () => string.Empty));
|
||||
AddOption(new Option(
|
||||
alias: "--output-layout",
|
||||
description: "the model output data layout, default is empty. eg. NCHW/NHWC",
|
||||
getDefaultValue: () => string.Empty));
|
||||
AddOption(new Option(
|
||||
alias: "--input-type",
|
||||
description: "the model input data value type, default is Float32",
|
||||
getDefaultValue: () => InputType.Float32));
|
||||
AddOption(new Option<IEnumerable<int>>(
|
||||
alias: "--input-shape",
|
||||
description: "the model input data shape, default is []. eg. `--input-shape 1 2 3 4`",
|
||||
getDefaultValue: () => Array.Empty<int>()));
|
||||
AddOption(new Option<IEnumerable<float>>(
|
||||
alias: "--input-range",
|
||||
description: "the model input data value range, default is []. eg `--input-range -100.3 200.4`",
|
||||
getDefaultValue: () => Array.Empty<float>()));
|
||||
AddOption(new Option<bool>(
|
||||
alias: "--swap-rb",
|
||||
description: "whether swap the model input data channel R and B",
|
||||
getDefaultValue: () => false));
|
||||
AddOption(new Option(
|
||||
alias: "--letter-box-value",
|
||||
description: "letterbox value, default 0.0",
|
||||
getDefaultValue: () => 0.0f));
|
||||
AddOption(new Option<IEnumerable<float>>(
|
||||
alias: "--mean",
|
||||
description: "the model input data mean, default []",
|
||||
getDefaultValue: () => Array.Empty<float>()));
|
||||
AddOption(new Option<IEnumerable<float>>(
|
||||
alias: "--std",
|
||||
description: "the model input data std, default []",
|
||||
getDefaultValue: () => Array.Empty<float>()));
|
||||
AddOption(new Option(
|
||||
alias: "--model-layout",
|
||||
description: "the model's input layout, default is empty. eg. NCHW/NHWC",
|
||||
getDefaultValue: () => string.Empty));
|
||||
AddOption(new Option<bool>(
|
||||
alias: "--benchmark-only",
|
||||
description: $"benchmark only",
|
||||
getDefaultValue: () => false));
|
||||
|
||||
Handler = CommandHandler.Create<CliCompileOptions, IHost>(RunAsync);
|
||||
}
|
||||
|
||||
private static DumpFlags DumpLevelToFlags(int dumpLevel)
|
||||
{
|
||||
return dumpLevel switch
|
||||
{
|
||||
0 => DumpFlags.None,
|
||||
1 => DumpLevelToFlags(0) | DumpFlags.Compile,
|
||||
2 => DumpLevelToFlags(1) | DumpFlags.PassIR,
|
||||
3 => DumpLevelToFlags(2) | DumpFlags.Rewrite,
|
||||
4 => DumpLevelToFlags(3) | DumpFlags.EGraphCost,
|
||||
5 => DumpLevelToFlags(4) | DumpFlags.Evaluator,
|
||||
6 => DumpLevelToFlags(5) | DumpFlags.Calibration,
|
||||
7 => DumpLevelToFlags(6) | DumpFlags.Tiling,
|
||||
8 => DumpLevelToFlags(7) | DumpFlags.Schedule,
|
||||
>= 9 => DumpLevelToFlags(8) | DumpFlags.CodeGen,
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(dumpLevel)),
|
||||
};
|
||||
}
|
||||
|
||||
private async Task RunAsync(CliCompileOptions cliOptions, IHost host)
|
||||
{
|
||||
CompilerServices.Configure(host.Services);
|
||||
|
||||
// 1. setup the options
|
||||
var compileOptions = new CompileOptions
|
||||
{
|
||||
InputFile = cliOptions.InputFile,
|
||||
InputFormat = cliOptions.InputFormat,
|
||||
DumpFlags = DumpLevelToFlags(cliOptions.DumpLevel),
|
||||
DumpDir = cliOptions.DumpDir,
|
||||
QuantizeOptions = new()
|
||||
{
|
||||
CalibrationMethod = cliOptions.CalibMethod,
|
||||
QuantType = cliOptions.QuantType switch
|
||||
{
|
||||
QuantType.UInt8 => DataTypes.UInt8,
|
||||
QuantType.Int8 => DataTypes.Int8,
|
||||
QuantType.Int16 => DataTypes.Int16,
|
||||
_ => throw new ArgumentException("Invalid quant type"),
|
||||
},
|
||||
WQuantType = cliOptions.WQuantType switch
|
||||
{
|
||||
QuantType.UInt8 => DataTypes.UInt8,
|
||||
QuantType.Int8 => DataTypes.Int8,
|
||||
QuantType.Int16 => DataTypes.Int16,
|
||||
_ => throw new ArgumentException("Invalid weights quant type"),
|
||||
},
|
||||
ModelQuantMode = cliOptions.ModelQuantMode,
|
||||
},
|
||||
PreProcess = cliOptions.PreProcess,
|
||||
InputLayout = cliOptions.InputLayout,
|
||||
OutputLayout = cliOptions.OutputLayout,
|
||||
InputType = cliOptions.InputType,
|
||||
InputShape = cliOptions.InputShape.ToArray(),
|
||||
InputRange = cliOptions.InputRange.ToArray(),
|
||||
SwapRB = cliOptions.SwapRB,
|
||||
LetterBoxValue = cliOptions.LetterBoxValue,
|
||||
Mean = cliOptions.Mean.ToArray(),
|
||||
Std = cliOptions.Std.ToArray(),
|
||||
ModelLayout = cliOptions.ModelLayout,
|
||||
IsBenchmarkOnly = cliOptions.BenchmarkOnly,
|
||||
};
|
||||
|
||||
// 2. import the model
|
||||
var target = CompilerServices.GetTarget(cliOptions.Target);
|
||||
using var compileSession = CompileSession.Create(target, compileOptions);
|
||||
var compiler = compileSession.Compiler;
|
||||
var module = await compiler.ImportModuleAsync(compileOptions.InputFormat, compileOptions.InputFile, compileOptions.IsBenchmarkOnly);
|
||||
|
||||
// 3. create the calib dataset
|
||||
if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
|
||||
{
|
||||
if (cliOptions.DatasetFormat == DatasetFormat.Random)
|
||||
{
|
||||
compileOptions.QuantizeOptions.CalibrationDataset = new RandomCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), 5);
|
||||
}
|
||||
else if (cliOptions.DatasetFormat == DatasetFormat.Pytest)
|
||||
{
|
||||
compileOptions.QuantizeOptions.CalibrationDataset = new PytestCalibrationDatasetProvider(((Function)module.Entry!).Parameters.ToArray(), cliOptions.Dataset);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotSupportedException(cliOptions.DatasetFormat.ToString());
|
||||
}
|
||||
}
|
||||
|
||||
// 4. compile
|
||||
await compiler.CompileAsync();
|
||||
|
||||
// 5. code gen
|
||||
using (var os = File.OpenWrite(cliOptions.OutputFile))
|
||||
{
|
||||
compiler.Gencode(os);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate null in command line parser.
|
||||
#pragma warning disable CS8618
|
||||
|
||||
internal sealed class CliCompileOptions
|
||||
{
|
||||
public string InputFile { get; set; }
|
||||
|
||||
public string InputFormat { get; set; }
|
||||
|
||||
public string Target { get; set; }
|
||||
|
||||
public int DumpLevel { get; set; }
|
||||
|
||||
public string DumpDir { get; set; }
|
||||
|
||||
public QuantType QuantType { get; set; }
|
||||
|
||||
public QuantType WQuantType { get; set; }
|
||||
|
||||
public string OutputFile { get; set; }
|
||||
|
||||
public ModelQuantMode ModelQuantMode { get; set; }
|
||||
|
||||
public CalibMethod CalibMethod { get; set; }
|
||||
|
||||
public string Dataset { get; set; }
|
||||
|
||||
public DatasetFormat DatasetFormat { get; set; }
|
||||
|
||||
public bool BenchmarkOnly { get; set; }
|
||||
|
||||
public bool PreProcess { get; set; }
|
||||
|
||||
public string InputLayout { get; set; }
|
||||
|
||||
public string OutputLayout { get; set; }
|
||||
|
||||
public InputType InputType { get; set; }
|
||||
|
||||
public List<int> InputShape { get; set; }
|
||||
|
||||
public List<float> InputRange { get; set; }
|
||||
|
||||
public bool SwapRB { get; set; }
|
||||
|
||||
public float LetterBoxValue { get; set; }
|
||||
|
||||
public List<float> Mean { get; set; }
|
||||
|
||||
public List<float> Std { get; set; }
|
||||
|
||||
public string ModelLayout { get; set; }
|
||||
}
|
||||
|
||||
#pragma warning restore CS8618
|
|
@ -0,0 +1,217 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.CommandLine;
|
||||
using System.Linq;
|
||||
using Nncase.Diagnostics;
|
||||
using Nncase.Quantization;
|
||||
|
||||
namespace Nncase.Cli;
|
||||
|
||||
internal enum QuantType
|
||||
{
|
||||
UInt8,
|
||||
Int8,
|
||||
Int16,
|
||||
}
|
||||
|
||||
internal enum DatasetFormat
|
||||
{
|
||||
Image,
|
||||
Raw,
|
||||
Pytest,
|
||||
Random,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Compile command.
|
||||
/// </summary>
|
||||
internal sealed class CompileCommand : Command
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="CompileCommand"/> class.
|
||||
/// </summary>
|
||||
public CompileCommand()
|
||||
: base("compile")
|
||||
{
|
||||
InputFile = new Argument<string>("input-file");
|
||||
OutputFile = new Argument<string>("output-file");
|
||||
InputFormat = new Option<string>(
|
||||
aliases: new[] { "-i", "--input-format" },
|
||||
description: "input format, e.g. tflite",
|
||||
getDefaultValue: () => "tflite");
|
||||
DumpFlags = new Option<IEnumerable<DumpFlags>>(
|
||||
name: "--dump-flags",
|
||||
description: "dump ir flags. \navailable value: None,ImportOps,PassIR,EGraphCost,Rewrite,Calibration,Evaluator,Compile,Tiling,Schedule,CodeGen.")
|
||||
{
|
||||
AllowMultipleArgumentsPerToken = true,
|
||||
};
|
||||
DumpDir = new Option<string>(
|
||||
name: "--dump-dir",
|
||||
description: "dump to directory.",
|
||||
getDefaultValue: () => ".");
|
||||
QuantType = new Option<QuantType>(
|
||||
name: "--quant-type",
|
||||
description: $"quant type",
|
||||
getDefaultValue: () => Nncase.Cli.QuantType.UInt8);
|
||||
WQuantType = new Option<QuantType>(
|
||||
name: "--wquant-type",
|
||||
description: $"wquant type",
|
||||
getDefaultValue: () => Nncase.Cli.QuantType.UInt8);
|
||||
Dataset = new Option<string>(
|
||||
name: "--dataset",
|
||||
description: $"calibration dataset, used in post quantization",
|
||||
getDefaultValue: () => string.Empty);
|
||||
DatasetFormat = new Option<DatasetFormat>(
|
||||
name: "--dataset-format",
|
||||
description: $"datset format.",
|
||||
getDefaultValue: () => Nncase.Cli.DatasetFormat.Raw);
|
||||
ModelQuantMode = new Option<Quantization.ModelQuantMode>(
|
||||
name: "--model-quant-mode",
|
||||
description: $"model quant mode",
|
||||
getDefaultValue: () => Quantization.ModelQuantMode.NoQuant);
|
||||
CalibMethod = new Option<Quantization.CalibMethod>(
|
||||
name: "--calib-method",
|
||||
description: $"model quant options",
|
||||
getDefaultValue: () => Quantization.CalibMethod.Kld);
|
||||
FixedVars = new Option<IEnumerable<(string, int)>>(
|
||||
name: "--fixed-vars",
|
||||
description: $"dynamic shape fixed vars, default is empty. \nset by `n:123`",
|
||||
parseArgument: result =>
|
||||
{
|
||||
return result.Tokens.
|
||||
Select(tk => tk.Value.Split(":").ToArray()).
|
||||
Select(tp => (tp[0].Trim(), int.Parse(tp[1].Trim())));
|
||||
})
|
||||
{
|
||||
AllowMultipleArgumentsPerToken = true,
|
||||
};
|
||||
PreProcess = new Option<bool>(
|
||||
name: "--pre-process",
|
||||
description: "whether enable pre process",
|
||||
getDefaultValue: () => false);
|
||||
InputLayout = new Option<string>(
|
||||
name: "--input-layout",
|
||||
description: "the model input data layout",
|
||||
getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
|
||||
OutputLayout = new Option<string>(
|
||||
name: "--output-layout",
|
||||
description: "the model output data layout.",
|
||||
getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
|
||||
InputType = new Option<InputType>(
|
||||
name: "--input-type",
|
||||
description: "the model input data value type, default is Float32",
|
||||
getDefaultValue: () => Nncase.InputType.Float32);
|
||||
InputShape = new Option<IEnumerable<int>>(
|
||||
name: "--input-shape",
|
||||
description: "the model input data shape. eg. `--input-shape 1 2 3 4`",
|
||||
getDefaultValue: Array.Empty<int>)
|
||||
{
|
||||
AllowMultipleArgumentsPerToken = true,
|
||||
};
|
||||
InputRange = new Option<IEnumerable<float>>(
|
||||
name: "--input-range",
|
||||
description: "the model input data value range. eg `--input-range -100.3 200.4`",
|
||||
getDefaultValue: Array.Empty<float>)
|
||||
{
|
||||
AllowMultipleArgumentsPerToken = true,
|
||||
};
|
||||
SwapRB = new Option<bool>(
|
||||
name: "--swap-rb",
|
||||
description: "whether swap the model input data channel, like cv2.BGRtoRGB(im)",
|
||||
getDefaultValue: () => false);
|
||||
LetterBoxValue = new Option<float>(
|
||||
name: "--letter-box-value",
|
||||
description: "letterbox fill value",
|
||||
getDefaultValue: () => 0.0f);
|
||||
Mean = new Option<IEnumerable<float>>(
|
||||
name: "--mean",
|
||||
description: "the model input data mean, default []",
|
||||
getDefaultValue: Array.Empty<float>)
|
||||
{
|
||||
AllowMultipleArgumentsPerToken = true,
|
||||
};
|
||||
Std = new Option<IEnumerable<float>>(
|
||||
name: "--std",
|
||||
description: "the model input data std, default []",
|
||||
getDefaultValue: Array.Empty<float>)
|
||||
{
|
||||
AllowMultipleArgumentsPerToken = true,
|
||||
};
|
||||
ModelLayout = new Option<string>(
|
||||
name: "--model-layout",
|
||||
description: "the model's input layout.",
|
||||
getDefaultValue: () => string.Empty).FromAmong("NCHW", "NHWC");
|
||||
AddArgument(InputFile);
|
||||
AddArgument(OutputFile);
|
||||
AddGlobalOption(InputFormat);
|
||||
AddGlobalOption(DumpFlags);
|
||||
AddGlobalOption(DumpDir);
|
||||
AddGlobalOption(QuantType);
|
||||
AddGlobalOption(WQuantType);
|
||||
AddGlobalOption(Dataset);
|
||||
AddGlobalOption(DatasetFormat);
|
||||
AddGlobalOption(ModelQuantMode);
|
||||
AddGlobalOption(CalibMethod);
|
||||
AddGlobalOption(FixedVars);
|
||||
AddGlobalOption(PreProcess);
|
||||
AddGlobalOption(InputLayout);
|
||||
AddGlobalOption(OutputLayout);
|
||||
AddGlobalOption(InputType);
|
||||
AddGlobalOption(InputShape);
|
||||
AddGlobalOption(InputRange);
|
||||
AddGlobalOption(SwapRB);
|
||||
AddGlobalOption(LetterBoxValue);
|
||||
AddGlobalOption(Mean);
|
||||
AddGlobalOption(Std);
|
||||
AddGlobalOption(ModelLayout);
|
||||
}
|
||||
|
||||
public Argument<string> InputFile { get; }
|
||||
|
||||
public Argument<string> OutputFile { get; }
|
||||
|
||||
public Option<string> InputFormat { get; }
|
||||
|
||||
public Option<IEnumerable<DumpFlags>> DumpFlags { get; }
|
||||
|
||||
public Option<string> DumpDir { get; }
|
||||
|
||||
public Option<QuantType> QuantType { get; }
|
||||
|
||||
public Option<QuantType> WQuantType { get; }
|
||||
|
||||
public Option<string> Dataset { get; }
|
||||
|
||||
public Option<DatasetFormat> DatasetFormat { get; }
|
||||
|
||||
public Option<ModelQuantMode> ModelQuantMode { get; }
|
||||
|
||||
public Option<CalibMethod> CalibMethod { get; }
|
||||
|
||||
public Option<IEnumerable<(string Name, int Value)>> FixedVars { get; }
|
||||
|
||||
public Option<bool> PreProcess { get; }
|
||||
|
||||
public Option<string> InputLayout { get; }
|
||||
|
||||
public Option<string> OutputLayout { get; }
|
||||
|
||||
public Option<InputType> InputType { get; }
|
||||
|
||||
public Option<IEnumerable<int>> InputShape { get; }
|
||||
|
||||
public Option<IEnumerable<float>> InputRange { get; }
|
||||
|
||||
public Option<bool> SwapRB { get; }
|
||||
|
||||
public Option<float> LetterBoxValue { get; }
|
||||
|
||||
public Option<IEnumerable<float>> Mean { get; }
|
||||
|
||||
public Option<IEnumerable<float>> Std { get; }
|
||||
|
||||
public Option<string> ModelLayout { get; }
|
||||
}
|
|
@ -26,4 +26,8 @@
|
|||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
</None>
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Folder Include="Properties\" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.CommandLine;
|
||||
using System.CommandLine.Builder;
|
||||
using System.Linq;
|
||||
|
||||
namespace Nncase.Cli;
|
||||
|
||||
internal partial class Program
|
||||
{
|
||||
private static CommandLineBuilder BuildCommandLine()
|
||||
{
|
||||
var commands = from t in typeof(Program).Assembly.ExportedTypes
|
||||
where t.Namespace == "Nncase.Cli.Commands" && t.IsAssignableTo(typeof(Command))
|
||||
select (Command)Activator.CreateInstance(t)!;
|
||||
var root = new RootCommand();
|
||||
foreach (var command in commands)
|
||||
{
|
||||
root.AddCommand(command);
|
||||
}
|
||||
|
||||
return new CommandLineBuilder(root);
|
||||
}
|
||||
}
|
|
@ -1,10 +1,14 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.CommandLine;
|
||||
using System.CommandLine.Builder;
|
||||
using System.CommandLine.Hosting;
|
||||
using System.CommandLine.Parsing;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
using Microsoft.Extensions.Hosting;
|
||||
|
@ -16,12 +20,155 @@ internal partial class Program
|
|||
{
|
||||
public static async Task<int> Main(string[] args)
|
||||
{
|
||||
return await BuildCommandLine()
|
||||
return await ConfigureCommandLine()
|
||||
.UseHost(ConfigureHost)
|
||||
.UseDefaults()
|
||||
.Build().InvokeAsync(args);
|
||||
}
|
||||
|
||||
private static async Task RunAsync(string targetKind, CompileOptions compileOptions, DatasetFormat datasetFormat, string dataset, string outputFile, IHost host)
|
||||
{
|
||||
CompilerServices.Configure(host.Services);
|
||||
|
||||
// 2. import the model
|
||||
var target = CompilerServices.GetTarget(targetKind);
|
||||
using var compileSession = CompileSession.Create(target, compileOptions);
|
||||
var compiler = compileSession.Compiler;
|
||||
IR.IRModule module = await compiler.ImportModuleAsync(Path.GetExtension(compileOptions.InputFile).Trim('.'), compileOptions.InputFile);
|
||||
|
||||
// 3. create the calib dataset
|
||||
if (compileOptions.QuantizeOptions.ModelQuantMode == Quantization.ModelQuantMode.UsePTQ)
|
||||
{
|
||||
if (datasetFormat == DatasetFormat.Random)
|
||||
{
|
||||
compileOptions.QuantizeOptions.CalibrationDataset = new Quantization.RandomCalibrationDatasetProvider(((Nncase.IR.Function)module.Entry!).Parameters.ToArray(), 5);
|
||||
}
|
||||
else if (datasetFormat == DatasetFormat.Pytest)
|
||||
{
|
||||
compileOptions.QuantizeOptions.CalibrationDataset = new Quantization.PytestCalibrationDatasetProvider(((IR.Function)module.Entry!).Parameters.ToArray(), dataset);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotSupportedException(datasetFormat.ToString());
|
||||
}
|
||||
}
|
||||
|
||||
// 4. compile
|
||||
await compiler.CompileAsync();
|
||||
|
||||
// 5. code gen
|
||||
using (var os = File.OpenWrite(outputFile))
|
||||
{
|
||||
compiler.Gencode(os);
|
||||
}
|
||||
}
|
||||
|
||||
private static CommandLineBuilder ConfigureCommandLine()
|
||||
{
|
||||
var compile = new CompileCommand();
|
||||
foreach (var target in LoadTargets())
|
||||
{
|
||||
var (targetCmd, targetParser) = target.RegisterCommandAndParser();
|
||||
Action<System.CommandLine.Invocation.InvocationContext> targetHandler = async (System.CommandLine.Invocation.InvocationContext context) =>
|
||||
{
|
||||
var options = ParseCompileOptions(context, compile);
|
||||
options.TargetCompileOptions = targetParser(context, targetCmd);
|
||||
await RunAsync(targetCmd.Name, options, context.ParseResult.GetValueForOption(compile.DatasetFormat), context.ParseResult.GetValueForOption(compile.Dataset)!, context.ParseResult.GetValueForArgument(compile.OutputFile), context.GetHost());
|
||||
};
|
||||
targetCmd.SetHandler(targetHandler);
|
||||
compile.AddCommand(targetCmd);
|
||||
}
|
||||
|
||||
return new CommandLineBuilder(new RootCommand() { compile });
|
||||
}
|
||||
|
||||
private static CompileOptions ParseCompileOptions(System.CommandLine.Invocation.InvocationContext context, CompileCommand compilecmd)
|
||||
{
|
||||
// 1. setup the options
|
||||
var compileOptions = new CompileOptions
|
||||
{
|
||||
InputFile = context.ParseResult.GetValueForArgument(compilecmd.InputFile),
|
||||
InputFormat = context.ParseResult.GetValueForOption(compilecmd.InputFormat)!,
|
||||
DumpFlags = context.ParseResult.GetValueForOption(compilecmd.DumpFlags)!.Aggregate(Diagnostics.DumpFlags.None, (a, b) => a | b),
|
||||
DumpDir = context.ParseResult.GetValueForOption(compilecmd.DumpDir)!,
|
||||
PreProcess = context.ParseResult.GetValueForOption(compilecmd.PreProcess)!,
|
||||
InputLayout = context.ParseResult.GetValueForOption(compilecmd.InputLayout)!,
|
||||
OutputLayout = context.ParseResult.GetValueForOption(compilecmd.OutputLayout)!,
|
||||
InputType = context.ParseResult.GetValueForOption(compilecmd.InputType)!,
|
||||
InputShape = context.ParseResult.GetValueForOption(compilecmd.InputShape)!.ToArray(),
|
||||
InputRange = context.ParseResult.GetValueForOption(compilecmd.InputRange)!.ToArray(),
|
||||
SwapRB = context.ParseResult.GetValueForOption(compilecmd.SwapRB)!,
|
||||
LetterBoxValue = context.ParseResult.GetValueForOption(compilecmd.LetterBoxValue)!,
|
||||
Mean = context.ParseResult.GetValueForOption(compilecmd.Mean)!.ToArray(),
|
||||
Std = context.ParseResult.GetValueForOption(compilecmd.Std)!.ToArray(),
|
||||
ModelLayout = context.ParseResult.GetValueForOption(compilecmd.ModelLayout)!,
|
||||
QuantizeOptions = new()
|
||||
{
|
||||
CalibrationMethod = context.ParseResult.GetValueForOption(compilecmd.CalibMethod),
|
||||
QuantType = context.ParseResult.GetValueForOption(compilecmd.QuantType) switch
|
||||
{
|
||||
QuantType.UInt8 => DataTypes.UInt8,
|
||||
QuantType.Int8 => DataTypes.Int8,
|
||||
QuantType.Int16 => DataTypes.Int16,
|
||||
_ => throw new ArgumentException("Invalid quant type"),
|
||||
},
|
||||
WQuantType = context.ParseResult.GetValueForOption(compilecmd.WQuantType) switch
|
||||
{
|
||||
QuantType.UInt8 => DataTypes.UInt8,
|
||||
QuantType.Int8 => DataTypes.Int8,
|
||||
QuantType.Int16 => DataTypes.Int16,
|
||||
_ => throw new ArgumentException("Invalid weights quant type"),
|
||||
},
|
||||
ModelQuantMode = context.ParseResult.GetValueForOption(compilecmd.ModelQuantMode),
|
||||
},
|
||||
};
|
||||
|
||||
foreach (var item in context.ParseResult.GetValueForOption(compilecmd.FixedVars)!)
|
||||
{
|
||||
compileOptions.ShapeBucketOptions.FixVarMap.Add(item.Name, item.Value);
|
||||
}
|
||||
|
||||
return compileOptions;
|
||||
}
|
||||
|
||||
private static IReadOnlyList<ITarget> LoadTargets()
|
||||
{
|
||||
var loadContext = System.Runtime.Loader.AssemblyLoadContext.Default;
|
||||
var pluginAsms = PluginLoader.GetPluginsSearchDirectories(PluginLoader.PluginPathEnvName, null).
|
||||
Select(PluginLoader.GetPluginAssemblies).
|
||||
SelectMany(x => x).
|
||||
DistinctBy(Path.GetFileName).
|
||||
Select(x => PluginLoader.LoadPluginAssembly(x, loadContext)).
|
||||
Distinct().
|
||||
ToList();
|
||||
pluginAsms.AddRange(new[] { Path.GetDirectoryName(typeof(Program).Assembly.Location)! }.
|
||||
Select(basePath =>
|
||||
{
|
||||
if (Directory.Exists(basePath))
|
||||
{
|
||||
return (from filePath in Directory.GetFiles(basePath, PluginLoader.ModulesDllPattern, SearchOption.AllDirectories)
|
||||
where PluginLoader.IsLoadableAssembly(filePath)
|
||||
select filePath).Distinct();
|
||||
}
|
||||
else
|
||||
{
|
||||
return Array.Empty<string>();
|
||||
}
|
||||
}).
|
||||
SelectMany(x => x).
|
||||
DistinctBy(Path.GetFileName).
|
||||
Select(x => PluginLoader.LoadPluginAssembly(x, loadContext)).
|
||||
Distinct());
|
||||
var targets = (from asm in pluginAsms
|
||||
from t in asm.ExportedTypes
|
||||
where t.IsClass
|
||||
&& t.IsAssignableTo(typeof(ITarget))
|
||||
let ctor = t.GetConstructor(Type.EmptyTypes)
|
||||
where ctor != null
|
||||
select (ITarget)ctor.Invoke(null)).ToList();
|
||||
return targets;
|
||||
}
|
||||
|
||||
private static void ConfigureHost(IHostBuilder hostBuilder)
|
||||
{
|
||||
hostBuilder.ConfigureAppConfiguration(ConfigureAppConfiguration)
|
||||
|
|
|
@ -33,21 +33,22 @@
|
|||
},
|
||||
"StyleCop.Analyzers": {
|
||||
"type": "Direct",
|
||||
"requested": "[1.2.0-beta.507, )",
|
||||
"resolved": "1.2.0-beta.507",
|
||||
"contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
|
||||
"requested": "[1.2.0-beta.435, )",
|
||||
"resolved": "1.2.0-beta.435",
|
||||
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
|
||||
"dependencies": {
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.507"
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.435"
|
||||
}
|
||||
},
|
||||
"System.CommandLine.Hosting": {
|
||||
"type": "Direct",
|
||||
"requested": "[0.3.0-alpha.21216.1, )",
|
||||
"resolved": "0.3.0-alpha.21216.1",
|
||||
"contentHash": "zP8QEUH8dSUYUHdGk6k71kOJy8uFgEPZG2RfhA0cMjDH3/Jov5AjUNaxOvpSNHh+ewu8eIUCYgV8+fEkCPyNlw==",
|
||||
"requested": "[0.4.0-alpha.22272.1, )",
|
||||
"resolved": "0.4.0-alpha.22272.1",
|
||||
"contentHash": "x9JhHxBLxlKyCIZADFYC8q16L9yGHdTakrLFjHabwR7Tk0761aTexiGgMTIS744HGuhc8pk9MoLUzsr/TlRfMQ==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Hosting": "3.1.5",
|
||||
"System.CommandLine": "2.0.0-beta1.21216.1"
|
||||
"Microsoft.Extensions.Hosting": "6.0.0",
|
||||
"System.CommandLine": "2.0.0-beta4.22272.1",
|
||||
"System.CommandLine.NamingConventionBinder": "2.0.0-beta4.22272.1"
|
||||
}
|
||||
},
|
||||
"Google.OrTools.runtime.linux-arm64": {
|
||||
|
@ -344,8 +345,8 @@
|
|||
},
|
||||
"StyleCop.Analyzers.Unstable": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.2.0.507",
|
||||
"contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
|
||||
"resolved": "1.2.0.435",
|
||||
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
|
||||
},
|
||||
"System.Buffers": {
|
||||
"type": "Transitive",
|
||||
|
@ -362,13 +363,12 @@
|
|||
"System.Runtime": "4.3.0"
|
||||
}
|
||||
},
|
||||
"System.CommandLine": {
|
||||
"System.CommandLine.NamingConventionBinder": {
|
||||
"type": "Transitive",
|
||||
"resolved": "2.0.0-beta1.21216.1",
|
||||
"contentHash": "Nbv/tW8sbOKN5T+4SSVBMdk4ADSIpJpY4UHMsj3VkcNtOckIT4iyzagjF+W5FEh2YBRvmvVQijOTIZbUJ1+1aA==",
|
||||
"resolved": "2.0.0-beta4.22272.1",
|
||||
"contentHash": "ux2eUA/syF+JtlpMDc/Lsd6PBIBuwjH3AvHnestoh5uD0WKT5b+wkQxDWVCqp9qgVjMBTLNhX19ZYFtenunt9A==",
|
||||
"dependencies": {
|
||||
"Microsoft.CSharp": "4.4.1",
|
||||
"system.memory": "4.5.4"
|
||||
"System.CommandLine": "2.0.0-beta4.22272.1"
|
||||
}
|
||||
},
|
||||
"System.Diagnostics.Contracts": {
|
||||
|
@ -696,6 +696,7 @@
|
|||
"Microsoft.Extensions.Options": "[6.0.0, )",
|
||||
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
|
||||
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
|
||||
"System.CommandLine": "[2.0.0-beta4.22272.1, )",
|
||||
"System.Reactive": "[5.0.0, )"
|
||||
}
|
||||
},
|
||||
|
@ -937,6 +938,12 @@
|
|||
"resolved": "1.0.2",
|
||||
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
|
||||
},
|
||||
"System.CommandLine": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[2.0.0-beta4.22272.1, )",
|
||||
"resolved": "2.0.0-beta4.22272.1",
|
||||
"contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
|
||||
},
|
||||
"System.Linq.Async": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[6.0.1, )",
|
||||
|
|
|
@ -15,7 +15,6 @@ public class LinkedFunction : ILinkedFunction
|
|||
public LinkedFunction(uint id, Callable sourceFunction, ulong textBegin, ulong textLength, IReadOnlyList<ILinkedSection> sections)
|
||||
{
|
||||
Id = id;
|
||||
CompilerServices.InferenceType(sourceFunction);
|
||||
ParameterTypes = ((CallableType)sourceFunction.CheckedType).Parameters.ToArray();
|
||||
ReturnType = ((CallableType)sourceFunction.CheckedType).ReturnType;
|
||||
TextBegin = textBegin;
|
||||
|
|
|
@ -10,11 +10,11 @@
|
|||
},
|
||||
"StyleCop.Analyzers": {
|
||||
"type": "Direct",
|
||||
"requested": "[1.2.0-beta.507, )",
|
||||
"resolved": "1.2.0-beta.507",
|
||||
"contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
|
||||
"requested": "[1.2.0-beta.435, )",
|
||||
"resolved": "1.2.0-beta.435",
|
||||
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
|
||||
"dependencies": {
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.507"
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.435"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.Configuration.Abstractions": {
|
||||
|
@ -53,8 +53,8 @@
|
|||
},
|
||||
"StyleCop.Analyzers.Unstable": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.2.0.507",
|
||||
"contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
|
||||
"resolved": "1.2.0.435",
|
||||
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
|
||||
},
|
||||
"System.Buffers": {
|
||||
"type": "Transitive",
|
||||
|
@ -76,6 +76,7 @@
|
|||
"Microsoft.Extensions.Options": "[6.0.0, )",
|
||||
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
|
||||
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
|
||||
"System.CommandLine": "[2.0.0-beta4.22272.1, )",
|
||||
"System.Reactive": "[5.0.0, )"
|
||||
}
|
||||
},
|
||||
|
@ -138,6 +139,12 @@
|
|||
"System.Runtime.CompilerServices.Unsafe": "5.0.0"
|
||||
}
|
||||
},
|
||||
"System.CommandLine": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[2.0.0-beta4.22272.1, )",
|
||||
"resolved": "2.0.0-beta4.22272.1",
|
||||
"contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
|
||||
},
|
||||
"System.Reactive": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[5.0.0, )",
|
||||
|
|
|
@ -88,13 +88,15 @@ internal class Compiler : ICompiler
|
|||
|
||||
public void TargetIndependentPass(IPassManager passManager)
|
||||
{
|
||||
passManager.AddWithName<DataflowPass>("ReshapeMatMul").Configure(p =>
|
||||
passManager.AddWithName<DataflowPass>("NormAxisAndShape").Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
|
||||
});
|
||||
|
||||
passManager.AddWithName<DataflowPass>("SqueezeShape").Configure(p =>
|
||||
{
|
||||
p.Add<Passes.Rules.Neutral.NormAxisGather>();
|
||||
p.Add<Passes.Rules.Neutral.NormAxisConcat>();
|
||||
p.Add<Passes.Rules.Neutral.NormAxisReduce>();
|
||||
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
|
||||
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
|
||||
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
|
||||
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
|
||||
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
|
||||
|
@ -102,6 +104,7 @@ internal class Compiler : ICompiler
|
|||
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern2>();
|
||||
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern3>();
|
||||
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern4>();
|
||||
p.Add<Passes.Rules.Neutral.FoldLayerNormPattern5>();
|
||||
p.Add<Passes.Rules.Neutral.FoldGeluWithScale>();
|
||||
p.Add<Passes.Rules.Neutral.FoldGeneralGelu>();
|
||||
p.Add<Passes.Rules.Neutral.FoldSwishPattern1>();
|
||||
|
@ -124,6 +127,7 @@ internal class Compiler : ICompiler
|
|||
p.Add<Passes.Rules.Neutral.FoldNopCast>();
|
||||
p.Add<Passes.Rules.Neutral.FoldNopReshape>();
|
||||
p.Add<Passes.Rules.Neutral.FoldNopSlice>();
|
||||
p.Add<Passes.Rules.Neutral.FoldPrePostReshapeSoftmax>();
|
||||
p.Add<Passes.Rules.Neutral.FoldSqueezeUnsqueeze>();
|
||||
p.Add<Passes.Rules.Neutral.FoldUnsqueezeSqueeze>();
|
||||
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
|
||||
|
@ -157,6 +161,8 @@ internal class Compiler : ICompiler
|
|||
p.Add<Passes.Rules.Neutral.CombineUnaryReshape>();
|
||||
p.Add<Passes.Rules.Neutral.CombineActivationsReshape>();
|
||||
p.Add<Passes.Rules.Neutral.CombineReshapePad>();
|
||||
p.Add<Passes.Rules.Neutral.CombineReshapeTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.CombineTransposeReshape>();
|
||||
p.Add<Passes.Rules.Neutral.FoldNopPad>();
|
||||
p.Add<Passes.Rules.Neutral.FoldConv2DPads>();
|
||||
p.Add<Passes.Rules.Neutral.FuseClampConv2D>();
|
||||
|
@ -168,6 +174,7 @@ internal class Compiler : ICompiler
|
|||
p.Add<Passes.Rules.Neutral.ReshapeToTranspose>();
|
||||
p.Add<Passes.Rules.Neutral.FoldNopReshape>();
|
||||
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
|
||||
p.Add<Passes.Rules.Neutral.FoldReshapeBinaryConstReshape>();
|
||||
p.Add<Passes.Rules.Neutral.ReluToClamp>();
|
||||
p.Add<Passes.Rules.Neutral.Relu6ToClamp>();
|
||||
p.Add<Passes.Rules.Neutral.FoldNopSlice>();
|
||||
|
|
|
@ -19,12 +19,14 @@ namespace Nncase.Hosting;
|
|||
/// </summary>
|
||||
public sealed class PluginLoader
|
||||
{
|
||||
private const string _modulesDllPattern = "Nncase.Modules.*.dll";
|
||||
private const string _pluginPathEnvName = "NNCASE_PLUGIN_PATH";
|
||||
public const string PluginPathEnvName = "NNCASE_PLUGIN_PATH";
|
||||
|
||||
public const string ModulesDllPattern = "Nncase.Modules.*.dll";
|
||||
|
||||
private static readonly string[] _builtinModules = new[]
|
||||
{
|
||||
"Nncase.Modules.StackVM.dll",
|
||||
"Nncase.Modules.CPU.dll",
|
||||
"Nncase.Modules.K210.dll",
|
||||
};
|
||||
|
||||
|
@ -42,26 +44,60 @@ public sealed class PluginLoader
|
|||
?? AssemblyLoadContext.Default;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load plugins.
|
||||
/// </summary>
|
||||
/// <returns>Plugins.</returns>
|
||||
public IReadOnlyList<IPlugin> LoadPlugins()
|
||||
public static Assembly LoadPluginAssembly(string assemblyFile, AssemblyLoadContext loadContext)
|
||||
{
|
||||
var pluginAsms = GetPluginsSearchDirectories().Select(GetPluginAssemblies).SelectMany(x => x)
|
||||
.DistinctBy(Path.GetFileName).Select(LoadPluginAssembly).Distinct().ToList();
|
||||
var plugins = (from asm in pluginAsms
|
||||
from t in asm.ExportedTypes
|
||||
where t.IsClass
|
||||
&& t.IsAssignableTo(typeof(IPlugin))
|
||||
let ctor = t.GetConstructor(Type.EmptyTypes)
|
||||
where ctor != null
|
||||
select (IPlugin)ctor.Invoke(null)).ToList();
|
||||
|
||||
return plugins;
|
||||
return loadContext.LoadFromAssemblyPath(assemblyFile);
|
||||
}
|
||||
|
||||
private static bool IsLoadableAssembly(string filePath)
|
||||
public static IEnumerable<string> GetPluginAssemblies(string basePath)
|
||||
{
|
||||
if (Directory.Exists(basePath))
|
||||
{
|
||||
return (from filePath in Directory.GetFiles(basePath, ModulesDllPattern, SearchOption.AllDirectories)
|
||||
where !_builtinModules.Contains(Path.GetFileName(filePath))
|
||||
&& IsLoadableAssembly(filePath)
|
||||
select filePath).Distinct();
|
||||
}
|
||||
else
|
||||
{
|
||||
return Array.Empty<string>();
|
||||
}
|
||||
}
|
||||
|
||||
public static IEnumerable<string> GetPluginsSearchDirectories(string pluginPathEnvName, ILogger? logger)
|
||||
{
|
||||
var directories = new List<string>();
|
||||
|
||||
// 1. Environment variable
|
||||
var targetPathEnv = Environment.GetEnvironmentVariable(pluginPathEnvName);
|
||||
if (string.IsNullOrWhiteSpace(targetPathEnv))
|
||||
{
|
||||
if (logger is not null)
|
||||
{
|
||||
logger.LogWarning($"{pluginPathEnvName} is not set.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var targetPaths = from path in targetPathEnv!.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
|
||||
select Environment.ExpandEnvironmentVariables(path);
|
||||
directories.AddRange(targetPaths);
|
||||
}
|
||||
|
||||
// 2. Python nncase modules
|
||||
var rootPath = Path.GetDirectoryName(typeof(PluginLoader).Assembly.Location)!;
|
||||
var modulesPath = Path.Combine(rootPath, "modules");
|
||||
directories.Add(modulesPath);
|
||||
|
||||
if (logger is not null && logger.IsEnabled(LogLevel.Trace))
|
||||
{
|
||||
logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}.");
|
||||
}
|
||||
|
||||
return directories.Distinct();
|
||||
}
|
||||
|
||||
public static bool IsLoadableAssembly(string filePath)
|
||||
{
|
||||
using var fs = File.OpenRead(filePath);
|
||||
using var peReader = new PEReader(fs);
|
||||
|
@ -93,53 +129,22 @@ public sealed class PluginLoader
|
|||
return true;
|
||||
}
|
||||
|
||||
private Assembly LoadPluginAssembly(string assemblyFile)
|
||||
/// <summary>
|
||||
/// Load plugins.
|
||||
/// </summary>
|
||||
/// <returns>Plugins.</returns>
|
||||
public IReadOnlyList<IPlugin> LoadPlugins()
|
||||
{
|
||||
return _loadContext.LoadFromAssemblyPath(assemblyFile);
|
||||
}
|
||||
var pluginAsms = GetPluginsSearchDirectories(PluginPathEnvName, _logger).Select(GetPluginAssemblies).SelectMany(x => x)
|
||||
.DistinctBy(Path.GetFileName).Select(x => LoadPluginAssembly(x, _loadContext)).Distinct().ToList();
|
||||
var plugins = (from asm in pluginAsms
|
||||
from t in asm.ExportedTypes
|
||||
where t.IsClass
|
||||
&& t.IsAssignableTo(typeof(IPlugin))
|
||||
let ctor = t.GetConstructor(Type.EmptyTypes)
|
||||
where ctor != null
|
||||
select (IPlugin)ctor.Invoke(null)).ToList();
|
||||
|
||||
private IEnumerable<string> GetPluginAssemblies(string basePath)
|
||||
{
|
||||
if (Directory.Exists(basePath))
|
||||
{
|
||||
return (from filePath in Directory.GetFiles(basePath, _modulesDllPattern, SearchOption.AllDirectories)
|
||||
where !_builtinModules.Contains(Path.GetFileName(filePath))
|
||||
&& IsLoadableAssembly(filePath)
|
||||
select filePath).Distinct();
|
||||
}
|
||||
else
|
||||
{
|
||||
return Array.Empty<string>();
|
||||
}
|
||||
}
|
||||
|
||||
private IEnumerable<string> GetPluginsSearchDirectories()
|
||||
{
|
||||
var directories = new List<string>();
|
||||
|
||||
// 1. Environment variable
|
||||
var targetPathEnv = Environment.GetEnvironmentVariable(_pluginPathEnvName);
|
||||
if (string.IsNullOrWhiteSpace(targetPathEnv))
|
||||
{
|
||||
_logger.LogWarning($"{_pluginPathEnvName} is not set.");
|
||||
}
|
||||
else
|
||||
{
|
||||
var targetPaths = from path in targetPathEnv.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
|
||||
select Environment.ExpandEnvironmentVariables(path);
|
||||
directories.AddRange(targetPaths);
|
||||
}
|
||||
|
||||
// 2. Python nncase modules
|
||||
var rootPath = Path.GetDirectoryName(typeof(PluginLoader).Assembly.Location)!;
|
||||
var modulesPath = Path.Combine(rootPath, "modules");
|
||||
directories.Add(modulesPath);
|
||||
|
||||
if (_logger.IsEnabled(LogLevel.Trace))
|
||||
{
|
||||
_logger.LogInformation($"Loading plugins from {string.Join(", ", directories)}.");
|
||||
}
|
||||
|
||||
return directories.Distinct();
|
||||
return plugins;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,11 +49,11 @@
|
|||
},
|
||||
"StyleCop.Analyzers": {
|
||||
"type": "Direct",
|
||||
"requested": "[1.2.0-beta.507, )",
|
||||
"resolved": "1.2.0-beta.507",
|
||||
"contentHash": "/FtugDT66cKJJ+GGH7rNpG6UDrT4iIWz45M6lrXXHobDUFDHw+q5VgkbiR+6ffTO564ge7w6fQh/eoQhVdJO8Q==",
|
||||
"requested": "[1.2.0-beta.435, )",
|
||||
"resolved": "1.2.0-beta.435",
|
||||
"contentHash": "TADk7vdGXtfTnYCV7GyleaaRTQjfoSfZXprQrVMm7cSJtJbFc1QIbWPyLvrgrfGdfHbGmUPvaN4ODKNxg2jgPQ==",
|
||||
"dependencies": {
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.507"
|
||||
"StyleCop.Analyzers.Unstable": "1.2.0.435"
|
||||
}
|
||||
},
|
||||
"Google.OrTools.runtime.linux-arm64": {
|
||||
|
@ -350,8 +350,8 @@
|
|||
},
|
||||
"StyleCop.Analyzers.Unstable": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.2.0.507",
|
||||
"contentHash": "gTY3IQdRqDJ4hbhSA3e/R48oE8b/OiKfvwkt1QdNVfrJK2gMHBV8ldaHJ885jxWZfllK66soa/sdcjh9bX49Tw=="
|
||||
"resolved": "1.2.0.435",
|
||||
"contentHash": "ouwPWZxbOV3SmCZxIRqHvljkSzkCyi1tDoMzQtDb/bRP8ctASV/iRJr+A2Gdj0QLaLmWnqTWDrH82/iP+X80Lg=="
|
||||
},
|
||||
"System.Buffers": {
|
||||
"type": "Transitive",
|
||||
|
@ -674,6 +674,7 @@
|
|||
"Microsoft.Extensions.Options": "[6.0.0, )",
|
||||
"Microsoft.Toolkit.HighPerformance": "[7.1.1, )",
|
||||
"NetFabric.Hyperlinq": "[3.0.0-beta48, )",
|
||||
"System.CommandLine": "[2.0.0-beta4.22272.1, )",
|
||||
"System.Reactive": "[5.0.0, )"
|
||||
}
|
||||
},
|
||||
|
@ -885,6 +886,12 @@
|
|||
"resolved": "1.0.2",
|
||||
"contentHash": "giLAHrjJe0Bh7yhNexR6pmcv02+Fi+lEPxQVdB8zvkuJCmy6rnqu8CZLIpxrUfLcWDuTCSiK0IfGmMhig3UDhA=="
|
||||
},
|
||||
"System.CommandLine": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[2.0.0-beta4.22272.1, )",
|
||||
"resolved": "2.0.0-beta4.22272.1",
|
||||
"contentHash": "1uqED/q2H0kKoLJ4+hI2iPSBSEdTuhfCYADeJrAqERmiGQ2NNacYKRNEQ+gFbU4glgVyK8rxI+ZOe1onEtr/Pg=="
|
||||
},
|
||||
"System.Linq.Async": {
|
||||
"type": "CentralTransitive",
|
||||
"requested": "[6.0.1, )",
|
||||
|
|
|
@ -119,4 +119,9 @@ public sealed record CompileOptions
|
|||
/// Gets or sets a value indicating whether is benchmark only.
|
||||
/// </summary>
|
||||
public bool IsBenchmarkOnly { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the target compile options.
|
||||
/// </summary>
|
||||
public ITargetCompileOptions TargetCompileOptions { get; set; } = null!;
|
||||
}
|
||||
|
|
|
@ -73,6 +73,14 @@ public interface ICompilerServicesProvider
|
|||
/// <param name="randConst">false for save const into bin.</param>
|
||||
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst);
|
||||
|
||||
/// <summary>
|
||||
/// dump the expr as csharp code.
|
||||
/// </summary>
|
||||
/// <param name="expr">expression.</param>
|
||||
/// <param name="prefix">file prefix.</param>
|
||||
/// <param name="dumpDir">file dump ir.</param>
|
||||
public void DumpPatternIR(Expr expr, string prefix, string dumpDir);
|
||||
|
||||
/// <summary>
|
||||
/// print ir type.
|
||||
/// </summary>
|
||||
|
@ -468,6 +476,15 @@ public static class CompilerServices
|
|||
public static void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst = true) =>
|
||||
Provider.DumpCSharpIR(expr, prefix, dumpDir, randConst);
|
||||
|
||||
/// <summary>
|
||||
/// dump the expr as csharp code.
|
||||
/// </summary>
|
||||
/// <param name="expr">expression.</param>
|
||||
/// <param name="prefix">file prefix.</param>
|
||||
/// <param name="dumpDir">file dump ir.</param>
|
||||
public static void DumpPatternIR(Expr expr, string prefix, string dumpDir) =>
|
||||
Provider.DumpPatternIR(expr, prefix, dumpDir);
|
||||
|
||||
public static string Print(IRType type) => Provider.Print(type);
|
||||
|
||||
public static string Print(Expr expr, bool useScript = false) => Provider.Print(expr, useScript);
|
||||
|
@ -583,6 +600,10 @@ internal class CompilerServicesProvider : ICompilerServicesProvider, ICompilerSe
|
|||
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst) =>
|
||||
_irprinterProvider.DumpCSharpIR(expr, prefix, dumpDir, randConst);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void DumpPatternIR(Expr expr, string prefix, string dumpDir) =>
|
||||
_irprinterProvider.DumpPatternIR(expr, prefix, dumpDir);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public string Print(IRType type) => _irprinterProvider.Print(type);
|
||||
|
||||
|
|
|
@ -28,5 +28,6 @@ internal class ConvertersModule : IApplicationPart
|
|||
registrator.RegisterManyInterface<UInt64Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<UInt8Converters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<PointerConverters>(reuse: Reuse.Singleton);
|
||||
registrator.RegisterManyInterface<PointerIntConverters>(reuse: Reuse.Singleton);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,3 +30,25 @@ internal class PointerConverters : IPointerSpanConverter<ulong>
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class PointerIntConverters : IPointerSpanConverter<int>
|
||||
{
|
||||
public void ConvertTo<T>(ReadOnlySpan<Pointer<T>> source, Span<int> dest, CastMode castMode)
|
||||
where T : unmanaged, IEquatable<T>
|
||||
{
|
||||
if (castMode != CastMode.KDefault)
|
||||
{
|
||||
throw new InvalidCastException();
|
||||
}
|
||||
|
||||
if (dest.Length < source.Length)
|
||||
{
|
||||
throw new ArgumentException("Dest buffer is not sufficient.");
|
||||
}
|
||||
|
||||
for (int i = 0; i < source.Length; i++)
|
||||
{
|
||||
dest[i] = checked((int)source[i].Value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -204,6 +204,7 @@ public static class CostUtility
|
|||
{
|
||||
TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * t.DType.SizeInBytes),
|
||||
TupleType t => t.Fields.Sum(GetMemoryAccess),
|
||||
DistributedType t => GetMemoryAccess(Utilities.DistributedUtility.GetDividedTensorType(t)),
|
||||
_ => 0,
|
||||
};
|
||||
}
|
||||
|
@ -229,6 +230,7 @@ public static class CostUtility
|
|||
{
|
||||
TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * cyclesPerElement),
|
||||
TupleType t => t.Fields.Sum(GetMemoryAccess),
|
||||
DistributedType t => GetCPUCycles(Utilities.DistributedUtility.GetDividedTensorType(t)),
|
||||
_ => 0,
|
||||
};
|
||||
}
|
||||
|
@ -328,7 +330,7 @@ public static class CostUtility
|
|||
}
|
||||
|
||||
// cost for op similar to broadcast
|
||||
public static Cost GetBroadcastCost(TensorType input, TensorType ret)
|
||||
public static Cost GetBroadcastCost(IRType input, IRType ret)
|
||||
{
|
||||
return new()
|
||||
{
|
||||
|
|
|
@ -114,7 +114,7 @@ public static class DataTypes
|
|||
/// <returns> datatype name.</returns>
|
||||
public static string GetDisplayName(this DataType dataType) => dataType switch
|
||||
{
|
||||
PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)}*)",
|
||||
PointerType pointerType => $"({GetDisplayName(pointerType.ElemType)} *)",
|
||||
PrimType primType => primType.ShortName,
|
||||
ValueType => dataType.ToString(),
|
||||
_ => throw new ArgumentOutOfRangeException(dataType.GetType().Name),
|
||||
|
|
|
@ -42,6 +42,8 @@ public interface IDumpper
|
|||
|
||||
void DumpCSharpIR(Expr expr, string prefix, string? reletivePath = null);
|
||||
|
||||
void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null);
|
||||
|
||||
void DumpModule(IRModule module, string? reletivePath = null);
|
||||
|
||||
Stream OpenFile(string reletivePath, FileMode fileMode = FileMode.Create);
|
||||
|
|
|
@ -46,6 +46,11 @@ public sealed class NullDumpper : IDumpper
|
|||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public void DumpPatternIR(Expr expr, string prefix, string? reletivePath = null)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public bool IsEnabled(DumpFlags dumpFlags) => false;
|
||||
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using DryIoc.ImTools;
|
||||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public abstract record SBP
|
||||
{
|
||||
public static SBPPartialSum P => SBPPartialSum.Instance;
|
||||
|
||||
public static SBPBroadCast B => SBPBroadCast.Instance;
|
||||
|
||||
public static SBPSplit S(int axis) => new SBPSplit(axis);
|
||||
}
|
||||
|
||||
public sealed record SBPSplit(int Axis) : SBP
|
||||
{
|
||||
public override string ToString() => $"S({Axis})";
|
||||
}
|
||||
|
||||
public sealed record SBPPartialSum : SBP
|
||||
{
|
||||
public static readonly SBPPartialSum Instance = new SBPPartialSum();
|
||||
|
||||
private SBPPartialSum()
|
||||
{
|
||||
}
|
||||
|
||||
public override string ToString() => "P";
|
||||
}
|
||||
|
||||
public sealed record SBPBroadCast : SBP
|
||||
{
|
||||
public static readonly SBPBroadCast Instance = new SBPBroadCast();
|
||||
|
||||
private SBPBroadCast()
|
||||
{
|
||||
}
|
||||
|
||||
public override string ToString() => "B";
|
||||
}
|
||||
|
||||
// public sealed record Placement(Placement.DeviceKind Kind, IRArray<int> Hierarchy, string Name)
|
||||
public sealed record Placement(IRArray<int> Hierarchy, string Name)
|
||||
{
|
||||
// public enum DeviceKind : uint
|
||||
// {
|
||||
// CPU = 0,
|
||||
// }
|
||||
public int Rank => Hierarchy.Count;
|
||||
|
||||
// public override string ToString() => $"@{Kind} [{string.Join(',', Hierarchy.Zip(Name).Select(t => t.First.ToString() + '@' + t.Second.ToString()))}]";
|
||||
public override string ToString() => $"@ [{string.Join(',', Hierarchy.Zip(Name).Select(t => t.First.ToString() + '@' + t.Second.ToString()))}]";
|
||||
}
|
||||
|
||||
public sealed record DistributedType(TensorType TensorType, IRArray<SBP> NdSBP, Placement Placement) : IRType
|
||||
{
|
||||
public override string ToString() => $"{TensorType}, ({string.Join(',', NdSBP)}), {Placement}";
|
||||
}
|
|
@ -17,7 +17,7 @@ namespace Nncase
|
|||
|
||||
public HashSet<Function> Functions => _functions;
|
||||
|
||||
protected override int VisitLeafFunction(Function expr, Unit context)
|
||||
protected override int VisitLeafFunction(Function expr)
|
||||
{
|
||||
_functions.Add(expr);
|
||||
return 0;
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Nncase.IR.Tensors;
|
||||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// BufferLoad expression.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed partial class BufferLoad : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Get the input parameter.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(BufferLoad), 0, "input", IsTensor());
|
||||
|
||||
/// <summary>
|
||||
/// Get the indices.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Indices = new(typeof(BufferLoad), 1, "indices", IsTuple());
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool CanFoldConstCall => false;
|
||||
}
|
|
@ -16,7 +16,7 @@ public sealed partial class BufferOf : Op
|
|||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(BufferOf), 0, "input", IsTensor());
|
||||
|
||||
public Schedule.MemoryLocation MemoryLocation { get; }
|
||||
public TIR.MemoryLocation MemoryLocation { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"Schedule.MemoryLocation.{MemoryLocation}";
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Nncase.IR.Tensors;
|
||||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// BufferStore op.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed partial class BufferStore : Op
|
||||
{
|
||||
/// <summary>
|
||||
/// Get the input parameter.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(BufferStore), 0, "input", IsTensor());
|
||||
|
||||
/// <summary>
|
||||
/// Get the indices parameter.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Indices = new(typeof(BufferStore), 1, "indices", IsTuple());
|
||||
|
||||
/// <summary>
|
||||
/// Get the value parameter.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Value = new(typeof(BufferStore), 2, "value", IsScalar());
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool CanFoldConstCall => false;
|
||||
}
|
|
@ -17,4 +17,7 @@ public sealed partial class DDrOf : Op
|
|||
/// Get the input parameter.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(DDrOf), 0, "input", IsTensor());
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool CanFoldConstCall => false;
|
||||
}
|
||||
|
|
|
@ -41,5 +41,5 @@ public static class Buffer
|
|||
/// <summary>
|
||||
/// create the uninitialized buffer.
|
||||
/// </summary>
|
||||
public static Call Uninitialized(DataType dataType, Schedule.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);
|
||||
public static Call Uninitialized(DataType dataType, TIR.MemoryLocation memoryLocation, Expr shape) => new Call(new Uninitialized(dataType, memoryLocation), shape);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using Nncase.IR.Tensors;
|
||||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Buffers;
|
||||
|
||||
/// <summary>
|
||||
/// MatchBuffer op.
|
||||
/// todo maybe need united matchbuffer and allocatebuffer.
|
||||
/// </summary>
|
||||
[PatternFunctionalGenerator]
|
||||
public sealed partial class MatchBuffer : Op
|
||||
{
|
||||
public static readonly ParameterInfo Input = new(typeof(MatchBuffer), 0, "input", IsTensor());
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool CanFoldConstCall => false;
|
||||
}
|
|
@ -19,11 +19,11 @@ public sealed partial class Uninitialized : Op
|
|||
|
||||
public DataType DType { get; }
|
||||
|
||||
public Schedule.MemoryLocation MemoryLocation { get; }
|
||||
public TIR.MemoryLocation MemoryLocation { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool CanFoldConstCall => false;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"{DType.GetCSharpName()}, Schedule.MemoryLocation.{MemoryLocation}";
|
||||
public override string DisplayProperty() => $"{DType.GetCSharpName()}, MemoryLocation.{MemoryLocation}";
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ public abstract class Callable : Expr
|
|||
/// <summary>
|
||||
/// StackVM module kind.
|
||||
/// </summary>
|
||||
public static readonly string StackVMModuleKind = "stackvm";
|
||||
public const string StackVMModuleKind = "stackvm";
|
||||
|
||||
public Callable(string name, string moduleKind, Expr[] operands)
|
||||
: base(operands)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
|
@ -57,8 +56,7 @@ public partial class ExprCloner<TContext>
|
|||
return expr.With(
|
||||
condition: Clone(expr.Condition, context),
|
||||
then: Clone(expr.Then, context),
|
||||
@else: Clone(expr.Else, context),
|
||||
paramList: expr.ParamList.Select(p => Clone(p, context)).ToArray()
|
||||
@else: Clone(expr.Else, context)
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -141,31 +139,6 @@ public partial class ExprCloner<TContext>
|
|||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
dimensions: CloneArray(expr.Dimensions, context),
|
||||
strides: CloneArray(expr.Strides, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
buffer: Clone(expr.Buffer, context),
|
||||
indices: CloneArray(expr.Indices, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
|
||||
{
|
||||
|
@ -175,16 +148,6 @@ public partial class ExprCloner<TContext>
|
|||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
|
||||
{
|
||||
return expr.With(
|
||||
buffer: Clone(expr.Buffer, context),
|
||||
indices: CloneArray(expr.Indices, context),
|
||||
value: Clone(expr.Value, context)
|
||||
);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override Expr VisitLeafFor(TIR.For expr, TContext context)
|
||||
{
|
||||
|
|
|
@ -102,6 +102,13 @@ public partial class ExprFunctor<TExprResult, TTypeResult> : ExprFunctor<TExprRe
|
|||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(TensorType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit point type.
|
||||
/// </summary>
|
||||
/// <param name="type">pointer type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(PointerType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tuple type.
|
||||
/// </summary>
|
||||
|
@ -116,6 +123,13 @@ public partial class ExprFunctor<TExprResult, TTypeResult> : ExprFunctor<TExprRe
|
|||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(CallableType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Visit callable type.
|
||||
/// </summary>
|
||||
/// <param name="type">Callable type.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TTypeResult VisitType(DistributedType type) => base.VisitType(type, default);
|
||||
|
||||
/// <summary>
|
||||
/// Default visit routine.
|
||||
/// </summary>
|
||||
|
@ -135,12 +149,18 @@ public partial class ExprFunctor<TExprResult, TTypeResult> : ExprFunctor<TExprRe
|
|||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(TensorType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(PointerType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(TupleType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(CallableType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult VisitType(DistributedType type, Unit context) => VisitType(type);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public sealed override TTypeResult DefaultVisitType(IRType type, Unit context) => DefaultVisitType(type);
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
|
@ -79,6 +78,11 @@ public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
|
|||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTupleConst(TupleConst expr, TContext context) => VisitConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.MemSpan"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="Var"/>.
|
||||
/// </summary>
|
||||
|
@ -94,31 +98,11 @@ public partial class ExprFunctor<TExprResult, TTypeResult, TContext>
|
|||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisit(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
|
@ -250,6 +234,13 @@ public partial class ExprFunctor<TExprResult, TTypeResult>
|
|||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.MemSpan"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Var"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
|
||||
|
@ -271,27 +262,6 @@ public partial class ExprFunctor<TExprResult, TTypeResult>
|
|||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
|
||||
|
@ -299,13 +269,6 @@ public partial class ExprFunctor<TExprResult, TTypeResult>
|
|||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
|
@ -92,6 +91,12 @@ public partial class ExprRewriter<TContext>
|
|||
return RewriteLeafTupleConst(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafMemSpan(TIR.MemSpan expr, TContext context)
|
||||
{
|
||||
return RewriteLeafMemSpan(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafVar(Var expr, TContext context)
|
||||
{
|
||||
|
@ -110,36 +115,12 @@ public partial class ExprRewriter<TContext>
|
|||
return RewriteLeafBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
|
||||
{
|
||||
return RewriteLeafLogicalBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
|
||||
{
|
||||
return RewriteLeafPhysicalBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBufferLoad(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBufferRegion(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafBufferStore(TIR.BufferStore expr, TContext context)
|
||||
{
|
||||
return RewriteLeafBufferStore(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override Expr VisitLeafFor(TIR.For expr, TContext context)
|
||||
{
|
||||
|
@ -247,6 +228,11 @@ public partial class ExprRewriter<TContext>
|
|||
/// </summary>
|
||||
protected virtual Expr RewriteLeafTupleConst(TupleConst expr, TContext context) => RewriteLeafConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.MemSpan"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Var"/>.
|
||||
/// </summary>
|
||||
|
@ -262,31 +248,11 @@ public partial class ExprRewriter<TContext>
|
|||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBuffer(TIR.Buffer expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => RewriteLeafBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultRewriteLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
|
@ -430,6 +396,14 @@ public partial class ExprRewriter
|
|||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafTupleConst(TupleConst expr, Unit context) => RewriteLeafTupleConst(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.MemSpan"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafMemSpan(TIR.MemSpan expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafMemSpan(TIR.MemSpan expr, Unit context) => RewriteLeafMemSpan(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="Var"/>.
|
||||
/// </summary>
|
||||
|
@ -454,30 +428,6 @@ public partial class ExprRewriter
|
|||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBuffer(TIR.Buffer expr, Unit context) => RewriteLeafBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr) => RewriteLeafBuffer(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => RewriteLeafLogicalBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => RewriteLeafBuffer(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => RewriteLeafPhysicalBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferLoad(TIR.BufferLoad expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBufferLoad(TIR.BufferLoad expr, Unit context) => RewriteLeafBufferLoad(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
|
@ -486,14 +436,6 @@ public partial class ExprRewriter
|
|||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBufferRegion(TIR.BufferRegion expr, Unit context) => RewriteLeafBufferRegion(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
protected virtual Expr RewriteLeafBufferStore(TIR.BufferStore expr) => DefaultRewriteLeaf(expr);
|
||||
|
||||
/// <inheritdoc />
|
||||
protected sealed override Expr RewriteLeafBufferStore(TIR.BufferStore expr, Unit context) => RewriteLeafBufferStore(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite leaf <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
//---------------------------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by T4 template.
|
||||
|
@ -103,6 +102,13 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
|
|||
return VisitLeafTupleConst(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafMemSpan(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitVar(Var expr, TContext context)
|
||||
{
|
||||
|
@ -118,24 +124,10 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
|
|||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, TContext context)
|
||||
protected internal override TExprResult VisitBuffer(TIR.Buffer expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafLogicalBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafPhysicalBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitBufferLoad(TIR.BufferLoad expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafBufferLoad(expr, context);
|
||||
return VisitLeafBuffer(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
|
@ -145,13 +137,6 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
|
|||
return VisitLeafBufferRegion(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitBufferStore(TIR.BufferStore expr, TContext context)
|
||||
{
|
||||
VisitOperands(expr, context);
|
||||
return VisitLeafBufferStore(expr, context);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected internal override TExprResult VisitFor(TIR.For expr, TContext context)
|
||||
{
|
||||
|
@ -270,6 +255,11 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
|
|||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr, TContext context) => VisitLeafConst(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.MemSpan"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Var"/>.
|
||||
/// </summary>
|
||||
|
@ -285,31 +275,11 @@ public partial class ExprVisitor<TExprResult, TTypeResult, TContext>
|
|||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, TContext context) => VisitLeafBuffer(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr, TContext context) => DefaultVisitLeaf(expr, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
|
@ -353,182 +323,168 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit <see cref="Call"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitCall(Call expr) => base.VisitCall(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitCall(Call expr, Unit context) => VisitCall(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Function"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFunction(Function expr) => base.VisitFunction(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFunction(Function expr, Unit context) => VisitFunction(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFusion(Fusion expr) => base.VisitFusion(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFusion(Fusion expr, Unit context) => VisitFusion(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="If"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIf(If expr) => base.VisitIf(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIf(If expr, Unit context) => VisitIf(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitMarker(Marker expr) => base.VisitMarker(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitMarker(Marker expr, Unit context) => VisitMarker(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="None"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitNone(None expr) => base.VisitNone(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitNone(None expr, Unit context) => VisitNone(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Op"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitOp(Op expr) => base.VisitOp(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitOp(Op expr, Unit context) => VisitOp(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitPrimFunctionWrapper(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTuple(IR.Tuple expr) => base.VisitTuple(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTuple(IR.Tuple expr, Unit context) => VisitTuple(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitTupleConst(TupleConst expr) => base.VisitTupleConst(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitTupleConst(TupleConst expr, Unit context) => VisitTupleConst(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.MemSpan"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitMemSpan(TIR.MemSpan expr) => base.VisitMemSpan(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitMemSpan(TIR.MemSpan expr, Unit context) => VisitMemSpan(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="Var"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitVar(Var expr) => base.VisitVar(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitVar(Var expr, Unit context) => VisitVar(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBlock(TIR.Block expr) => base.VisitBlock(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBlock(TIR.Block expr, Unit context) => VisitBlock(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.LogicalBuffer"/>.
|
||||
/// Visit <see cref="TIR.Buffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLogicalBuffer(expr, default);
|
||||
|
||||
internal protected virtual TExprResult VisitBuffer(TIR.Buffer expr) => base.VisitBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLogicalBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitPhysicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitPhysicalBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferLoad(TIR.BufferLoad expr) => base.VisitBufferLoad(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferLoad(TIR.BufferLoad expr, Unit context) => VisitBufferLoad(expr);
|
||||
internal protected sealed override TExprResult VisitBuffer(TIR.Buffer expr, Unit context) => VisitBuffer(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferRegion(TIR.BufferRegion expr) => base.VisitBufferRegion(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferRegion(TIR.BufferRegion expr, Unit context) => VisitBufferRegion(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitBufferStore(TIR.BufferStore expr) => base.VisitBufferStore(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitBufferStore(TIR.BufferStore expr, Unit context) => VisitBufferStore(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitFor(TIR.For expr) => base.VisitFor(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitFor(TIR.For expr, Unit context) => VisitFor(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIfThenElse(TIR.IfThenElse expr) => base.VisitIfThenElse(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIfThenElse(TIR.IfThenElse expr, Unit context) => VisitIfThenElse(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitLet(TIR.Let expr) => base.VisitLet(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitLet(TIR.Let expr, Unit context) => VisitLet(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitPrimFunction(TIR.PrimFunction expr) => base.VisitPrimFunction(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitPrimFunction(TIR.PrimFunction expr, Unit context) => VisitPrimFunction(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitSequential(TIR.Sequential expr) => base.VisitSequential(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitSequential(TIR.Sequential expr, Unit context) => VisitSequential(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitRange(TIR.Range expr) => base.VisitRange(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitRange(TIR.Range expr, Unit context) => VisitRange(expr);
|
||||
/// <summary>
|
||||
/// Visit <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
internal protected virtual TExprResult VisitIterVar(TIR.IterVar expr) => base.VisitIterVar(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal protected sealed override TExprResult VisitIterVar(TIR.IterVar expr, Unit context) => VisitIterVar(expr);
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="BaseFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBaseFunction(BaseFunction expr) => base.VisitLeafBaseFunction(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBaseFunction(BaseFunction expr, Unit context) => VisitLeafBaseFunction(expr);
|
||||
|
||||
|
@ -536,7 +492,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="Call"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafCall(Call expr) => base.VisitLeafCall(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafCall(Call expr, Unit context) => VisitLeafCall(expr);
|
||||
|
||||
|
@ -544,7 +500,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="Const"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafConst(Const expr) => base.VisitLeafConst(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafConst(Const expr, Unit context) => VisitLeafConst(expr);
|
||||
|
||||
|
@ -552,15 +508,15 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="Function"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFunction(Function expr) => base.VisitLeafFunction(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr);
|
||||
protected sealed override TExprResult VisitLeafFunction(Function expr, Unit context) => VisitLeafFunction(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Fusion"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFusion(Fusion expr) => base.VisitLeafFusion(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafFusion(Fusion expr, Unit context) => VisitLeafFusion(expr);
|
||||
|
||||
|
@ -568,7 +524,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="If"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIf(If expr) => base.VisitLeafIf(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafIf(If expr, Unit context) => VisitLeafIf(expr);
|
||||
|
||||
|
@ -576,7 +532,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="Marker"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafMarker(Marker expr) => base.VisitLeafMarker(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafMarker(Marker expr, Unit context) => VisitLeafMarker(expr);
|
||||
|
||||
|
@ -584,7 +540,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="None"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafNone(None expr) => base.VisitLeafNone(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafNone(None expr, Unit context) => VisitLeafNone(expr);
|
||||
|
||||
|
@ -592,7 +548,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="Op"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafOp(Op expr) => base.VisitLeafOp(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafOp(Op expr, Unit context) => VisitLeafOp(expr);
|
||||
|
||||
|
@ -600,7 +556,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="PrimFunctionWrapper"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr) => base.VisitLeafPrimFunctionWrapper(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitLeafPrimFunctionWrapper(expr);
|
||||
|
||||
|
@ -608,7 +564,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TensorConst"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr);
|
||||
|
||||
|
@ -616,7 +572,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="IR.Tuple"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTuple(IR.Tuple expr) => base.VisitLeafTuple(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafTuple(IR.Tuple expr, Unit context) => VisitLeafTuple(expr);
|
||||
|
||||
|
@ -624,15 +580,23 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TupleConst"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafTupleConst(TupleConst expr) => base.VisitLeafTupleConst(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafTupleConst(TupleConst expr, Unit context) => VisitLeafTupleConst(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.MemSpan"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafMemSpan(TIR.MemSpan expr) => base.VisitLeafMemSpan(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafMemSpan(TIR.MemSpan expr, Unit context) => VisitLeafMemSpan(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="Var"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafVar(Var expr) => base.VisitLeafVar(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafVar(Var expr, Unit context) => VisitLeafVar(expr);
|
||||
|
||||
|
@ -640,7 +604,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TIR.Block"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBlock(TIR.Block expr) => base.VisitLeafBlock(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBlock(TIR.Block expr, Unit context) => VisitLeafBlock(expr);
|
||||
|
||||
|
@ -648,55 +612,23 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TIR.Buffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBuffer(TIR.Buffer expr) => base.VisitLeafBuffer(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBuffer(TIR.Buffer expr, Unit context) => VisitLeafBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr) => base.VisitLeafLogicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafLogicalBuffer(TIR.LogicalBuffer expr, Unit context) => VisitLeafLogicalBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr) => base.VisitLeafPhysicalBuffer(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafPhysicalBuffer(TIR.PhysicalBuffer expr, Unit context) => VisitLeafPhysicalBuffer(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferLoad"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr) => base.VisitLeafBufferLoad(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBufferLoad(TIR.BufferLoad expr, Unit context) => VisitLeafBufferLoad(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferRegion"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr) => base.VisitLeafBufferRegion(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBufferRegion(TIR.BufferRegion expr, Unit context) => VisitLeafBufferRegion(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.BufferStore"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafBufferStore(TIR.BufferStore expr) => base.VisitLeafBufferStore(expr, default);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafBufferStore(TIR.BufferStore expr, Unit context) => VisitLeafBufferStore(expr);
|
||||
|
||||
/// <summary>
|
||||
/// Visit leaf <see cref="TIR.For"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafFor(TIR.For expr) => base.VisitLeafFor(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafFor(TIR.For expr, Unit context) => VisitLeafFor(expr);
|
||||
|
||||
|
@ -704,7 +636,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TIR.IfThenElse"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr) => base.VisitLeafIfThenElse(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafIfThenElse(TIR.IfThenElse expr, Unit context) => VisitLeafIfThenElse(expr);
|
||||
|
||||
|
@ -712,7 +644,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TIR.Let"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafLet(TIR.Let expr) => base.VisitLeafLet(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafLet(TIR.Let expr, Unit context) => VisitLeafLet(expr);
|
||||
|
||||
|
@ -720,7 +652,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TIR.PrimFunction"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr) => base.VisitLeafPrimFunction(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafPrimFunction(TIR.PrimFunction expr, Unit context) => VisitLeafPrimFunction(expr);
|
||||
|
||||
|
@ -728,7 +660,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TIR.Sequential"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafSequential(TIR.Sequential expr) => base.VisitLeafSequential(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafSequential(TIR.Sequential expr, Unit context) => VisitLeafSequential(expr);
|
||||
|
||||
|
@ -736,7 +668,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TIR.Range"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafRange(TIR.Range expr) => base.VisitLeafRange(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafRange(TIR.Range expr, Unit context) => VisitLeafRange(expr);
|
||||
|
||||
|
@ -744,7 +676,7 @@ public partial class ExprVisitor<TExprResult, TTypeResult>
|
|||
/// Visit leaf <see cref="TIR.IterVar"/>.
|
||||
/// </summary>
|
||||
protected virtual TExprResult VisitLeafIterVar(TIR.IterVar expr) => base.VisitLeafIterVar(expr, default);
|
||||
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected sealed override TExprResult VisitLeafIterVar(TIR.IterVar expr, Unit context) => VisitLeafIterVar(expr);
|
||||
|
||||
|
|
|
@ -73,6 +73,14 @@ public interface IIRPrinterProvider
|
|||
/// <param name="randConst">randConst = false will save the const into bin.</param>
|
||||
public void DumpCSharpIR(Expr expr, string prefix, string dumpDir, bool randConst);
|
||||
|
||||
/// <summary>
|
||||
/// dump the expr as csharp code.
|
||||
/// </summary>
|
||||
/// <param name="expr">expression.</param>
|
||||
/// <param name="prefix">file prefix.</param>
|
||||
/// <param name="dumpDir">file dump ir.</param>
|
||||
public void DumpPatternIR(Expr expr, string prefix, string dumpDir);
|
||||
|
||||
/// <summary>
|
||||
/// print ir type.
|
||||
/// </summary>
|
||||
|
|
|
@ -11,14 +11,11 @@ PrimFunctionWrapper,true,true,BaseFunction,,Target
|
|||
TensorConst,true,false,Const,,
|
||||
Tuple,true,false,Default,IR.,@Fields
|
||||
TupleConst,true,false,Const,,
|
||||
MemSpan,true,false,Default,TIR.,Start;Size;
|
||||
Var,true,false,Default,,
|
||||
Block,true,false,Default,TIR.,Body;InitBody;@IterVars;@Reads;@Writes;@AllocBuffers;Predicate
|
||||
Buffer,false,false,Default,TIR.,
|
||||
LogicalBuffer,true,false,Buffer,TIR.,@Dimensions;@Strides
|
||||
PhysicalBuffer,true,false,Buffer,TIR.,
|
||||
BufferLoad,true,false,Default,TIR.,Buffer;@Indices
|
||||
Buffer,true,false,Default,TIR.,MemSpan;@Dimensions;@Strides;
|
||||
BufferRegion,true,false,Default,TIR.,Buffer;@Region
|
||||
BufferStore,true,false,Default,TIR.,Buffer;@Indices;Value
|
||||
For,true,false,Default,TIR.,LoopVar;Domain;Body
|
||||
IfThenElse,true,false,Default,TIR.,Condition;Then;Else
|
||||
Let,true,false,Default,TIR.,Var;Expression;Body
|
||||
|
|
|
|
@ -139,6 +139,15 @@ public sealed record TensorType(DataType DType, Shape Shape) : IRType
|
|||
/// <param name="elemType"> the Pointed Element Type.</param>
|
||||
/// <returns>the pointer tensor type.</returns>
|
||||
public static TensorType Pointer(DataType elemType) => new(new PointerType(elemType), Shape.Scalar);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string ToString() => DType switch
|
||||
{
|
||||
PrimType ptype => ptype.GetDisplayName() + (Shape.IsScalar ? string.Empty : Shape.ToString()),
|
||||
PointerType { ElemType: PrimType etype } => $"*{etype.GetDisplayName()}",
|
||||
ValueType => $"{DType}",
|
||||
_ => throw new NotSupportedException(DType.GetType().Name),
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -21,7 +21,7 @@ public sealed partial class ResizeImage : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"));
|
||||
public static readonly ParameterInfo Input = new(typeof(ResizeImage), 0, "input", HasRank(r => r >= 2, "RanK >= 2"), ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets roi.
|
||||
|
|
|
@ -20,12 +20,12 @@ public sealed partial class Binary : Op
|
|||
/// <summary>
|
||||
/// Gets lhs.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs");
|
||||
public static readonly ParameterInfo Lhs = new(typeof(Binary), 0, "lhs", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets rhs.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs");
|
||||
public static readonly ParameterInfo Rhs = new(typeof(Binary), 1, "rhs", ParameterKind.Input);
|
||||
|
||||
public BinaryOp BinaryOp { get; }
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ public sealed partial class Clamp : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Clamp), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Clamp), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets min.
|
||||
|
|
|
@ -20,10 +20,10 @@ public sealed partial class MatMul : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Lhs = new(typeof(MatMul), 0, "lhs");
|
||||
public static readonly ParameterInfo Lhs = new(typeof(MatMul), 0, "lhs", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets Other.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Rhs = new(typeof(MatMul), 1, "rhs");
|
||||
public static readonly ParameterInfo Rhs = new(typeof(MatMul), 1, "rhs", ParameterKind.Input);
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ public sealed partial class ReduceArg : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(ReduceArg), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(ReduceArg), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets Axis.
|
||||
|
@ -42,8 +42,8 @@ public sealed partial class ReduceArg : Op
|
|||
|
||||
public ReduceArgOp ReduceArgOp { get; }
|
||||
|
||||
public DataType DestType { get; }
|
||||
public PrimType DestType { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}";
|
||||
public override string DisplayProperty() => $"ReduceArgOp.{ReduceArgOp}, DestType: {DestType}";
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ public sealed partial class Unary : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Unary), 0, "input", ParameterKind.Input);
|
||||
|
||||
public UnaryOp UnaryOp { get; }
|
||||
|
||||
|
|
|
@ -154,7 +154,12 @@ public sealed partial class Swish : ActivationOp
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Swish), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets beta.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Beta = new(typeof(Swish), 1, "beta", IsFloatScalar());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -21,17 +21,17 @@ public sealed partial class Conv2D : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Conv2D), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets Weights.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4));
|
||||
public static readonly ParameterInfo Weights = new(typeof(Conv2D), 1, "weights", HasRank(4), ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets Bias.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1));
|
||||
public static readonly ParameterInfo Bias = new(typeof(Conv2D), 2, "bias", HasRank(1), ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets Stride.
|
||||
|
|
|
@ -34,7 +34,7 @@ public static class NN
|
|||
|
||||
public static Call BatchNormalization(Expr input, Expr scale, Expr bias, Expr input_mean, Expr input_var, Expr epsilon, Expr momentum) => new Call(new BatchNormalization(), input, scale, bias, input_mean, input_var, epsilon, momentum);
|
||||
|
||||
public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias) => new Call(new LayerNorm(axis, epsilon), input, scale, bias);
|
||||
public static Call LayerNorm(int axis, float epsilon, Expr input, Expr scale, Expr bias, bool hasMean = true) => new Call(new LayerNorm(axis, epsilon, hasMean), input, scale, bias);
|
||||
|
||||
public static Call BatchToSpace(Expr input, Expr blockShape, Expr crops) => new Call(new BatchToSpace(), input, blockShape, crops);
|
||||
|
||||
|
@ -103,5 +103,10 @@ public static class NN
|
|||
/// <summary>
|
||||
/// create Swish call.
|
||||
/// </summary>
|
||||
public static Call Swish(Expr input) => new Call(new Swish(), input);
|
||||
public static Call Swish(Expr input) => new Call(new Swish(), input, 1f);
|
||||
|
||||
/// <summary>
|
||||
/// create Swish call.
|
||||
/// </summary>
|
||||
public static Call Swish(Expr input, Expr beta) => new Call(new Swish(), input, beta);
|
||||
}
|
||||
|
|
|
@ -21,19 +21,23 @@ public sealed partial class LayerNorm : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(LayerNorm), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(LayerNorm), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets scale.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Scale = new(typeof(LayerNorm), 1, "scale");
|
||||
public static readonly ParameterInfo Scale = new(typeof(LayerNorm), 1, "scale", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets bias.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias");
|
||||
public static readonly ParameterInfo Bias = new(typeof(LayerNorm), 2, "bias", ParameterKind.Input);
|
||||
|
||||
public int Axis { get; }
|
||||
|
||||
public float Epsilon { get; }
|
||||
|
||||
public bool UseMean { get; }
|
||||
|
||||
public override string DisplayProperty() => $"Axis: {Axis}, Epsilon: {Epsilon}, UseMean: {UseMean}";
|
||||
}
|
||||
|
|
|
@ -61,17 +61,17 @@ public sealed partial class InstanceNormalization : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(InstanceNormalization), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(InstanceNormalization), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Scale = new(typeof(InstanceNormalization), 1, "scale");
|
||||
public static readonly ParameterInfo Scale = new(typeof(InstanceNormalization), 1, "scale", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Bias = new(typeof(InstanceNormalization), 2, "bias");
|
||||
public static readonly ParameterInfo Bias = new(typeof(InstanceNormalization), 2, "bias", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets Epsilon.
|
||||
|
|
|
@ -33,7 +33,7 @@ public sealed partial class Softmax : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Softmax), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Softmax), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets axis.
|
||||
|
|
|
@ -12,6 +12,12 @@ using static Nncase.IR.TypePatternUtility;
|
|||
|
||||
namespace Nncase.IR;
|
||||
|
||||
public enum ParameterKind : int
|
||||
{
|
||||
Input,
|
||||
Attribute,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Parameter information.
|
||||
/// </summary>
|
||||
|
@ -24,11 +30,13 @@ public sealed class ParameterInfo
|
|||
/// <param name="ownerType">this op type.</param>
|
||||
/// <param name="index">param index.</param>
|
||||
/// <param name="name">param name.</param>
|
||||
public ParameterInfo(Type ownerType, int index, string name)
|
||||
/// <param name="parameterKind">kind.</param>
|
||||
public ParameterInfo(Type ownerType, int index, string name, ParameterKind parameterKind = ParameterKind.Attribute)
|
||||
{
|
||||
OwnerType = ownerType;
|
||||
Index = index;
|
||||
Name = name;
|
||||
ParameterKind = parameterKind;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -39,8 +47,9 @@ public sealed class ParameterInfo
|
|||
/// <param name="index">param index.</param>
|
||||
/// <param name="name">param name.</param>
|
||||
/// <param name="pattern">the param condition.</param>
|
||||
public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern)
|
||||
: this(ownerType, index, name)
|
||||
/// <param name="parameterKind">kind.</param>
|
||||
public ParameterInfo(Type ownerType, int index, string name, TypePattern pattern, ParameterKind parameterKind = ParameterKind.Attribute)
|
||||
: this(ownerType, index, name, parameterKind)
|
||||
{
|
||||
Pattern = pattern;
|
||||
}
|
||||
|
@ -60,6 +69,11 @@ public sealed class ParameterInfo
|
|||
/// </summary>
|
||||
public string Name { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets parameter kind.
|
||||
/// </summary>
|
||||
public ParameterKind ParameterKind { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets this paramter's type condition.
|
||||
/// </summary>
|
||||
|
@ -90,7 +104,7 @@ public abstract class Op : Expr
|
|||
/// <summary>
|
||||
/// Gets get the parameters.
|
||||
/// </summary>
|
||||
public IEnumerable<ParameterInfo> Parameters =>
|
||||
public virtual IEnumerable<ParameterInfo> Parameters =>
|
||||
_parameters ??= (from p in GetType().GetFields(BindingFlags.Public | BindingFlags.Static)
|
||||
where p.FieldType == typeof(ParameterInfo)
|
||||
let param = (ParameterInfo)(p.GetValue(null) ?? throw new InvalidOperationException())
|
||||
|
|
|
@ -19,5 +19,5 @@ namespace Nncase.IR.F;
|
|||
public static class RNN
|
||||
{
|
||||
public static Call LSTM(LSTMDirection direction, LSTMLayout layout, string[] acts, Expr x, Expr w, Expr r, Expr b, Expr seqLens, Expr initH, Expr initC, Expr p, Expr actAlpha, Expr actBeta, Expr clip, Expr hiddenSize, Expr inputForget, Expr outputSize) =>
|
||||
new Call(new IR.Tensors.LSTM(direction, layout, acts), x, w, r, b, seqLens, initH, initC, p, actAlpha, actBeta, clip, hiddenSize, inputForget, outputSize);
|
||||
new Call(new IR.RNN.LSTM(direction, layout, acts), x, w, r, b, seqLens, initH, initC, p, actAlpha, actBeta, clip, hiddenSize, inputForget, outputSize);
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ using System.Collections.Immutable;
|
|||
using Nncase.PatternMatch;
|
||||
using static Nncase.IR.TypePatternUtility;
|
||||
|
||||
namespace Nncase.IR.Tensors;
|
||||
namespace Nncase.IR.RNN;
|
||||
|
||||
/// <summary>
|
||||
/// LSTM expression.
|
||||
|
|
|
@ -146,7 +146,7 @@ public sealed class TensorConst : Const, IEquatable<TensorConst?>
|
|||
public override bool Equals(object? obj) => Equals(obj as TensorConst);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public bool Equals(TensorConst? other) => other is not null && base.Equals(other) && EqualityComparer<Tensor>.Default.Equals(Value, other.Value);
|
||||
public bool Equals(TensorConst? other) => other is not null && (ReferenceEquals(this, other) || GetHashCode() == other.GetHashCode()) && EqualityComparer<Tensor>.Default.Equals(Value, other.Value);
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override int GetHashCodeCore() => HashCode.Combine(Value);
|
||||
|
|
|
@ -20,7 +20,7 @@ public sealed partial class Cast : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Cast), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Cast), 0, "input", ParameterKind.Input);
|
||||
|
||||
public DataType NewType { get; }
|
||||
|
||||
|
|
|
@ -20,10 +20,13 @@ public sealed partial class Concat : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Concat), 0, "inputs");
|
||||
public static readonly ParameterInfo Input = new(typeof(Concat), 0, "inputs", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets axis.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Axis = new(typeof(Concat), 1, "axis");
|
||||
public int Axis { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"Axis: {Axis}";
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ public sealed partial class Expand : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Expand), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Expand), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets shape.
|
||||
|
|
|
@ -70,7 +70,7 @@ public static class Tensors
|
|||
public static Call Cast(Expr input, DataType newType, CastMode castMode = CastMode.KDefault) =>
|
||||
new Call(new Cast(newType, castMode), input);
|
||||
|
||||
public static Call Concat(Expr input, Expr axis) => new Call(new Concat(), input, axis);
|
||||
public static Call Concat(Expr input, int axis) => new Call(new Concat(axis), input);
|
||||
|
||||
public static Call ConstantOfShape(Expr shape, Expr value) => new Call(new ConstantOfShape(), shape, value);
|
||||
|
||||
|
@ -89,7 +89,7 @@ public static class Tensors
|
|||
|
||||
public static Call Flatten(Expr input, Expr axis) => new Call(new Flatten(), input, axis);
|
||||
|
||||
public static Call Gather(Expr input, Expr axis, Expr index) => new Call(new Gather(), input, axis, index);
|
||||
public static Call Gather(Expr input, int axis, Expr index) => new Call(new Gather(axis), input, index);
|
||||
|
||||
public static Call GatherElements(Expr input, Expr axis, Expr indices) =>
|
||||
new Call(new GatherElements(), input, axis, indices);
|
||||
|
|
|
@ -22,15 +22,18 @@ public sealed partial class Gather : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Gather), 0, "input");
|
||||
|
||||
/// <summary>
|
||||
/// Gets axis.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Axis = new(typeof(Gather), 1, "axis", IsIntegralScalar());
|
||||
public static readonly ParameterInfo Input = new(typeof(Gather), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets index.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Index = new(typeof(Gather), 2, "index", IsIntegral());
|
||||
public static readonly ParameterInfo Index = new(typeof(Gather), 1, "index", IsIntegral(), ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets axis.
|
||||
/// </summary>
|
||||
public int Axis { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string DisplayProperty() => $"Axis: {Axis}";
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ public sealed partial class Reshape : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Reshape), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Reshape), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets shape.
|
||||
|
|
|
@ -21,7 +21,7 @@ public sealed partial class Slice : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Slice), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Slice), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets begins.
|
||||
|
|
|
@ -15,7 +15,7 @@ public sealed partial class Transpose : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Transpose), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Transpose), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets perm.
|
||||
|
|
|
@ -23,7 +23,7 @@ public sealed partial class Unsqueeze : Op
|
|||
/// <summary>
|
||||
/// Gets input.
|
||||
/// </summary>
|
||||
public static readonly ParameterInfo Input = new(typeof(Unsqueeze), 0, "input");
|
||||
public static readonly ParameterInfo Input = new(typeof(Unsqueeze), 0, "input", ParameterKind.Input);
|
||||
|
||||
/// <summary>
|
||||
/// Gets dimension.
|
||||
|
|
|
@ -32,6 +32,7 @@ public abstract class TypeFunctor<TResult, TContext>
|
|||
TensorType t => VisitType(t, context),
|
||||
TupleType t => VisitType(t, context),
|
||||
CallableType t => VisitType(t, context),
|
||||
DistributedType t => VisitType(t, context),
|
||||
_ => DefaultVisitType(type, context),
|
||||
};
|
||||
}
|
||||
|
@ -68,6 +69,14 @@ public abstract class TypeFunctor<TResult, TContext>
|
|||
/// <returns>Result.</returns>
|
||||
public virtual TResult VisitType(TensorType type, TContext context) => DefaultVisitType(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit pointer type.
|
||||
/// </summary>
|
||||
/// <param name="type">Pointer type.</param>
|
||||
/// <param name="context">Context.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TResult VisitType(PointerType type, TContext context) => DefaultVisitType(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit tuple type.
|
||||
/// </summary>
|
||||
|
@ -84,6 +93,14 @@ public abstract class TypeFunctor<TResult, TContext>
|
|||
/// <returns>Result.</returns>
|
||||
public virtual TResult VisitType(CallableType type, TContext context) => DefaultVisitType(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Visit dist tensor type.
|
||||
/// </summary>
|
||||
/// <param name="type">dist tensor type.</param>
|
||||
/// <param name="context">Context.</param>
|
||||
/// <returns>Result.</returns>
|
||||
public virtual TResult VisitType(DistributedType type, TContext context) => DefaultVisitType(type, context);
|
||||
|
||||
/// <summary>
|
||||
/// Default visit routine.
|
||||
/// </summary>
|
||||
|
|
|
@ -57,12 +57,12 @@ public record TypePattern(Func<IRType, bool> Cond, string Reason)
|
|||
public T Check<T>(T valueType, string fieldName)
|
||||
where T : IRType
|
||||
{
|
||||
if (valueType is TensorType tensorValueType && tensorValueType.Shape.IsUnranked)
|
||||
if (valueType is TensorType { Shape: { IsUnranked: true } } || valueType is DistributedType { TensorType: { Shape: { IsUnranked: true } } })
|
||||
{
|
||||
return valueType;
|
||||
}
|
||||
|
||||
if (valueType == null || !MatchLeaf(valueType))
|
||||
if (valueType == null || (valueType is TensorType t && !MatchLeaf(t)) || (valueType is DistributedType d && !MatchLeaf(d.TensorType)))
|
||||
{
|
||||
var cur = valueType is null ? "None" : CompilerServices.Print(valueType);
|
||||
throw new InvalidOperationException($"{fieldName} Requrie <{Reason}>, But {cur}!");
|
||||
|
@ -187,6 +187,7 @@ public static partial class TypePatternUtility
|
|||
x => x switch
|
||||
{
|
||||
TensorType ttype => DataTypes.IsIntegral(ttype.DType),
|
||||
DistributedType distributedType => DataTypes.IsIntegral(distributedType.TensorType.DType),
|
||||
_ => false,
|
||||
},
|
||||
"IsIntegral");
|
||||
|
|
|
@ -13,6 +13,13 @@ using Nncase.Quantization;
|
|||
|
||||
namespace Nncase;
|
||||
|
||||
/// <summary>
|
||||
/// The targets own compile options.
|
||||
/// </summary>
|
||||
public interface ITargetCompileOptions
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Target.
|
||||
/// </summary>
|
||||
|
@ -23,6 +30,12 @@ public interface ITarget
|
|||
/// </summary>
|
||||
string Kind { get; }
|
||||
|
||||
/// <summary>
|
||||
/// create the current target's command and parser.
|
||||
/// </summary>
|
||||
/// <returns>command.</returns>
|
||||
(System.CommandLine.Command Command, Func<System.CommandLine.Invocation.InvocationContext, System.CommandLine.Command, ITargetCompileOptions> Parser) RegisterCommandAndParser();
|
||||
|
||||
/// <summary>
|
||||
/// Bind Quant Method And Quant Cosine With IR.
|
||||
/// </summary>
|
||||
|
@ -91,3 +104,12 @@ public interface ITarget
|
|||
/// <returns>Module builder.</returns>
|
||||
IModuleBuilder CreateModuleBuilder(string moduleKind, CompileOptions options);
|
||||
}
|
||||
|
||||
public sealed class DefaultTargetCompileOptions : ITargetCompileOptions
|
||||
{
|
||||
public static readonly DefaultTargetCompileOptions Instance = new();
|
||||
|
||||
private DefaultTargetCompileOptions()
|
||||
{
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,21 @@ namespace Nncase;
|
|||
/// </summary>
|
||||
public static class LinqExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Get the ranges from range desc.
|
||||
/// </summary>
|
||||
/// <param name="stride">stride.</param>
|
||||
/// <param name="start">start.</param>
|
||||
/// <param name="stop">stop.</param>
|
||||
/// <returns>Ranges.</returns>
|
||||
public static IEnumerable<Range> Ranges(this int stride, int start, int stop)
|
||||
{
|
||||
for (int i = start; i < stop; i += stride)
|
||||
{
|
||||
yield return new Range(i, Math.Min(stop, i + stride));
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get cartesian product.
|
||||
/// </summary>
|
||||
|
@ -31,6 +46,23 @@ public static class LinqExtensions
|
|||
select accseq.Concat(new[] { item }));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the permutation of the source.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">Element type.</typeparam>
|
||||
/// <param name="source">Source sequences.</param>
|
||||
/// <returns>Permutated sequences.</returns>
|
||||
public static IEnumerable<T[]> Permutate<T>(this IEnumerable<T> source)
|
||||
{
|
||||
return Permutation(source, Enumerable.Empty<T>());
|
||||
|
||||
IEnumerable<T[]> Permutation(IEnumerable<T> reminder, IEnumerable<T> prefix) =>
|
||||
!reminder.Any() ? new[] { prefix.ToArray() } :
|
||||
reminder.SelectMany((c, i) => Permutation(
|
||||
reminder.Take(i).Concat(reminder.Skip(i + 1)).ToArray(),
|
||||
prefix.Append(c)));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// take or default.
|
||||
/// </summary>
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
<Nullable>enable</Nullable>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<GenerateDocumentationFile>true</GenerateDocumentationFile>
|
||||
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
|
||||
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
|
||||
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
|
@ -21,6 +21,7 @@
|
|||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" />
|
||||
<PackageReference Include="Microsoft.Extensions.Options" />
|
||||
<PackageReference Include="Microsoft.Toolkit.HighPerformance" />
|
||||
<PackageReference Include="System.CommandLine" />
|
||||
<PackageReference Include="NetFabric.Hyperlinq" />
|
||||
<PackageReference Include="System.Reactive" />
|
||||
<PackageReference Include="GiGraph.Dot" />
|
||||
|
|
|
@ -12,35 +12,43 @@ using Nncase.TIR;
|
|||
namespace Nncase.Passes.Mutators;
|
||||
|
||||
/// <summary>
|
||||
/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store. Also remove Block to ensure that the flattened TIR can not be scheduled again.
|
||||
/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store.
|
||||
/// </summary>
|
||||
public sealed class FlattenBuffer : ExprRewriter
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override Expr RewriteLeafBlock(Block expr)
|
||||
{
|
||||
if (!expr.IterVars.IsEmpty)
|
||||
{
|
||||
throw new InvalidOperationException("Non-opaque blocks are not allowed in FlattenBuffer. Please call pass ConvertBlocksToOpaque before.");
|
||||
}
|
||||
|
||||
// 1. Visit the body
|
||||
var predicate = expr.Predicate;
|
||||
if (predicate is TensorConst { Value: { Length: 1 } t }
|
||||
&& t.ToScalar<bool>())
|
||||
// TODO: put the unfold block into this.
|
||||
if (expr.Predicate is TensorConst tc && tc.Value.ToScalar<bool>() == true)
|
||||
{
|
||||
return expr.Body;
|
||||
}
|
||||
else
|
||||
|
||||
return T.Nop();
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
protected override Expr RewriteLeafCall(Call expr)
|
||||
{
|
||||
if (expr.Target is IR.Buffers.BufferLoad)
|
||||
{
|
||||
return new IfThenElse(predicate, expr.Body);
|
||||
var indices = (IR.Tuple)expr[IR.Buffers.BufferLoad.Indices];
|
||||
var input = (TIR.Buffer)expr[IR.Buffers.BufferLoad.Input];
|
||||
return T.Load(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])));
|
||||
}
|
||||
else if (expr.Target is IR.Buffers.BufferStore)
|
||||
{
|
||||
var indices = (IR.Tuple)expr[IR.Buffers.BufferStore.Indices];
|
||||
var input = (TIR.Buffer)expr[IR.Buffers.BufferStore.Input];
|
||||
return T.Store(input.MemSpan, Enumerable.Range(0, indices.Count).Aggregate((Expr)0, (acc, i) => acc + (input.Strides[i] * indices[i])), expr[IR.Buffers.BufferStore.Value]);
|
||||
}
|
||||
else if (expr.Target is IR.Buffers.MatchBuffer && expr.Arguments[0] is TIR.Buffer { MemSpan: { Start: Const or Var } })
|
||||
{
|
||||
// remove the all fixed match operation.
|
||||
return T.Nop();
|
||||
}
|
||||
|
||||
// Step 3. Handle allocations in reverse order
|
||||
// TODO add the alloc buffers.
|
||||
// for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
|
||||
// const Buffer& buffer = new_block->alloc_buffers[i - 1];
|
||||
// body = MakeAllocStmt(buffer, std::move(body));
|
||||
// }
|
||||
return expr;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System.Reactive;
|
||||
using Nncase.Evaluator;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes;
|
||||
using Nncase.TIR;
|
||||
|
||||
namespace Nncase.Passes.Mutators;
|
||||
|
||||
/// <summary>
|
||||
/// remove buffer BaseMentOf/DDrOf/MmuOF.
|
||||
/// </summary>
|
||||
public sealed class FoldBufferSlot : ExprRewriter
|
||||
{
|
||||
protected internal override Expr VisitPrimFunction(TIR.PrimFunction expr, Unit context)
|
||||
{
|
||||
if (expr.SchedResult.IsScheduled == true)
|
||||
{
|
||||
return base.VisitPrimFunction(expr, context);
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
|
||||
protected override Expr RewriteLeafCall(Call expr)
|
||||
{
|
||||
if (expr.Target is IR.Buffers.BaseMentOf)
|
||||
{
|
||||
var locate = ((TIR.MemSpan)expr.Arguments[0]).Location;
|
||||
return locate switch
|
||||
{
|
||||
MemoryLocation.Input => 0,
|
||||
MemoryLocation.Output => 1,
|
||||
MemoryLocation.Rdata => 2,
|
||||
MemoryLocation.Data => 3,
|
||||
_ => throw new ArgumentOutOfRangeException($"You Can't Assgin The BaseMent For {locate}!"),
|
||||
};
|
||||
}
|
||||
else if (expr.Target is IR.Buffers.DDrOf)
|
||||
{
|
||||
if (expr.Arguments[0] is TIR.MemSpan buf)
|
||||
{
|
||||
return buf.Start;
|
||||
}
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System.Reactive;
|
||||
using NetFabric.Hyperlinq;
|
||||
using Nncase.Evaluator;
|
||||
using Nncase.IR;
|
||||
using Nncase.Passes;
|
||||
|
||||
namespace Nncase.Passes.Mutators;
|
||||
|
||||
/// <summary>
|
||||
/// fold math calc operator.
|
||||
/// </summary>
|
||||
public sealed class FoldMathCall : ExprRewriter
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
protected override Expr RewriteLeafCall(Call expr)
|
||||
{
|
||||
if (expr.Target is Op op && op.GetType().Namespace is string @namespace
|
||||
&& @namespace.StartsWith("Nncase.IR.Math"))
|
||||
{
|
||||
return expr.Arguments.AsValueEnumerable().All(x => x is Const)
|
||||
? Const.FromValue(CompilerServices.Evaluate(expr))
|
||||
: expr;
|
||||
}
|
||||
|
||||
return expr;
|
||||
}
|
||||
}
|
|
@ -50,10 +50,4 @@ public static class Mutator
|
|||
/// </summary>
|
||||
/// <returns>RemoveNop.</returns>
|
||||
public static Func<ExprRewriter> RemoveNop() => () => new Mutators.RemoveNop();
|
||||
|
||||
/// <summary>
|
||||
/// fold math calc operator.
|
||||
/// </summary>
|
||||
/// <returns>FoldMathCall.</returns>
|
||||
public static Func<ExprRewriter> FoldMathCall() => () => new Mutators.FoldMathCall();
|
||||
}
|
||||
|
|
|
@ -144,7 +144,10 @@ public sealed class UnRollLoopSequential : ExprRewriter
|
|||
}
|
||||
}
|
||||
|
||||
protected override Expr VisitLeafPhysicalBuffer(PhysicalBuffer expr, Unit context) => expr;
|
||||
protected override Expr VisitLeafMemSpan(MemSpan expr, Unit context)
|
||||
{
|
||||
return expr.With(Clone(expr.Start, context), Clone(expr.Size, context));
|
||||
}
|
||||
|
||||
protected override Expr VisitLeafVar(Var expr, Unit context)
|
||||
{
|
||||
|
@ -189,9 +192,10 @@ public sealed class UnRollLoopSequential : ExprRewriter
|
|||
return CSE(expr.With(start: Clone(expr.Start, context), stop: Clone(expr.Stop, context), step: Clone(expr.Step, context)));
|
||||
}
|
||||
|
||||
protected override Expr VisitLeafLogicalBuffer(LogicalBuffer expr, Unit context)
|
||||
protected override Expr VisitLeafBuffer(TIR.Buffer expr, Unit context)
|
||||
{
|
||||
return expr.With(
|
||||
memSpan: Clone<MemSpan>(expr.MemSpan, context),
|
||||
dimensions: CloneArray(expr.Dimensions, context).Select(e => CSE(e)).ToArray(),
|
||||
strides: CloneArray(expr.Strides, context));
|
||||
}
|
||||
|
|
|
@ -10,52 +10,6 @@ using Nncase.TIR;
|
|||
|
||||
namespace Nncase.Schedule;
|
||||
|
||||
/// <summary>
|
||||
/// the memory type.
|
||||
/// </summary>
|
||||
public enum MemoryLocation : byte
|
||||
{
|
||||
/// <summary>
|
||||
/// input.
|
||||
/// </summary>
|
||||
Input = 0,
|
||||
|
||||
/// <summary>
|
||||
/// output.
|
||||
/// </summary>
|
||||
Output = 1,
|
||||
|
||||
/// <summary>
|
||||
/// constant data.
|
||||
/// </summary>
|
||||
Rdata = 2,
|
||||
|
||||
/// <summary>
|
||||
/// compute temp data.
|
||||
/// </summary>
|
||||
Data = 3,
|
||||
|
||||
/// <summary>
|
||||
/// shared data.
|
||||
/// </summary>
|
||||
SharedData = 4,
|
||||
|
||||
/// <summary>
|
||||
/// l2 data.
|
||||
/// </summary>
|
||||
L2Data = 5,
|
||||
|
||||
/// <summary>
|
||||
/// L1 data.
|
||||
/// </summary>
|
||||
L1Data = 6,
|
||||
|
||||
/// <summary>
|
||||
/// base addr.
|
||||
/// </summary>
|
||||
PrivateBase = 64,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the scheduler interface.
|
||||
/// </summary>
|
||||
|
@ -261,12 +215,12 @@ public sealed class SchedFunctionResult
|
|||
/// <summary>
|
||||
/// Gets the buffer allocation.
|
||||
/// </summary>
|
||||
public HashSet<TIR.PhysicalBuffer> Rdatas { get; }
|
||||
public Dictionary<IR.Const, ValueRange<long>> Rdatas { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the data section length.
|
||||
/// </summary>
|
||||
public int DataUsage { get; set; }
|
||||
public long DataUsage { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a value indicating whether the Scheduled status.
|
||||
|
@ -296,8 +250,8 @@ public sealed class SchedFunctionResult
|
|||
return true;
|
||||
}
|
||||
|
||||
return EqualityComparer<HashSet<TIR.PhysicalBuffer>>.Default.Equals(Rdatas, result.Rdatas) &&
|
||||
EqualityComparer<int>.Default.Equals(DataUsage, result.DataUsage);
|
||||
return EqualityComparer<Dictionary<IR.Const, ValueRange<long>>>.Default.Equals(Rdatas, result.Rdatas) &&
|
||||
EqualityComparer<long>.Default.Equals(DataUsage, result.DataUsage);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
|
|
@ -267,31 +267,34 @@ public record SelectedRange(int Start, int End, Padding Padding)
|
|||
/// <summary>
|
||||
/// buffer.
|
||||
/// </summary>
|
||||
public abstract class Buffer : Expr
|
||||
public sealed class Buffer : Expr
|
||||
{
|
||||
public Buffer(string name, DataType elemType, Schedule.MemoryLocation memoryLocation, Expr[] operands)
|
||||
: base(operands.ToArray())
|
||||
public Buffer(string name, DataType elemType, MemSpan memSpan, Expr[] dimensions, Expr[] strides)
|
||||
: base(new[] { memSpan }.Concat(dimensions).Concat(strides))
|
||||
{
|
||||
Name = name;
|
||||
ElemType = elemType;
|
||||
MemLocation = memoryLocation;
|
||||
Rank = dimensions.Length;
|
||||
}
|
||||
|
||||
public string Name { get; }
|
||||
|
||||
public DataType ElemType { get; }
|
||||
|
||||
public Schedule.MemoryLocation MemLocation { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets if this buffer from the constant !.
|
||||
/// </summary>
|
||||
public TensorConst? Const { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets rank of the tensor: number of dimensions.
|
||||
/// </summary>
|
||||
public abstract int Rank { get; }
|
||||
public int Rank { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the shape.
|
||||
/// </summary>
|
||||
public MemSpan MemSpan => (MemSpan)Operands[0];
|
||||
|
||||
/// <summary>
|
||||
/// Gets the shape.
|
||||
/// </summary>
|
||||
public ReadOnlySpan<Expr> Dimensions => Operands[1..(1 + Rank)];
|
||||
|
||||
/// <summary>
|
||||
/// Gets the strides.
|
||||
|
@ -299,201 +302,23 @@ public abstract class Buffer : Expr
|
|||
/// This Strides is by elements not by bytes!
|
||||
/// </remarks>
|
||||
/// </summary>
|
||||
public abstract ReadOnlySpan<Expr> Strides { get; }
|
||||
public ReadOnlySpan<Expr> Strides => Operands[(1 + Rank)..(1 + Rank + Rank)];
|
||||
|
||||
/// <summary>
|
||||
/// Gets the shape.
|
||||
/// </summary>
|
||||
public abstract ReadOnlySpan<Expr> Dimensions { get; }
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context) => functor.VisitBuffer(this, context);
|
||||
|
||||
public Buffer With(MemSpan? memSpan = null, Expr[]? dimensions = null, Expr[]? strides = null)
|
||||
=> new Buffer(Name, ElemType, memSpan ?? MemSpan, dimensions ?? Dimensions.ToArray(), strides ?? Strides.ToArray());
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool Equals(object? obj)
|
||||
{
|
||||
if (obj is not Buffer other)
|
||||
if (ReferenceEquals(this, obj))
|
||||
{
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (Const is not null && !Const.Equals(other.Const))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return string.Equals(Name, other.Name, StringComparison.Ordinal) &&
|
||||
ElemType.Equals(other.ElemType) &&
|
||||
MemLocation.Equals(other.MemLocation) &&
|
||||
Rank.Equals(other.Rank) &&
|
||||
base.Equals(obj);
|
||||
return obj is TIR.Buffer other && GetHashCode() == other.GetHashCode() && Name == other.Name && ElemType == other.ElemType && Rank == other.Rank && Operands.SequenceEqual(other.Operands);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the logical buffer.
|
||||
/// </summary>
|
||||
public sealed class LogicalBuffer : Buffer
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="LogicalBuffer"/> class.
|
||||
/// create from the IRType.
|
||||
/// </summary>
|
||||
/// <param name="name">the name.</param>
|
||||
/// <param name="location">the location.</param>
|
||||
/// <param name="elemType">prim type.</param>
|
||||
/// <param name="dimensions">the shape.</param>
|
||||
/// <param name="strides">the strides.</param>
|
||||
public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan<Expr> dimensions, ReadOnlySpan<Expr> strides)
|
||||
: base(name, elemType, location, ArrayUtility.Concat(dimensions, strides))
|
||||
{
|
||||
Rank = dimensions.Length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="LogicalBuffer"/> class.
|
||||
/// <see cref="LogicalBuffer"/>.
|
||||
/// </summary>
|
||||
public LogicalBuffer(string name, Schedule.MemoryLocation location, TensorConst tensor)
|
||||
: this(name, tensor.Value.ElementType, location, ArrayUtility.ToExprArray(tensor.Value.Dimensions), ArrayUtility.ToExprArray(tensor.Value.Strides))
|
||||
{
|
||||
Const = tensor;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="LogicalBuffer"/> class.
|
||||
/// <seealso cref="LogicalBuffer"/>
|
||||
/// </summary>
|
||||
public LogicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan<Expr> dimensions)
|
||||
: this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets get the total length.
|
||||
/// </summary>
|
||||
public Expr Length => TensorUtilities.GetProduct(Dimensions);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the shape.
|
||||
/// </summary>
|
||||
public override ReadOnlySpan<Expr> Dimensions => Operands[0..Rank];
|
||||
|
||||
/// <summary>
|
||||
/// Gets the strides.
|
||||
/// </summary>
|
||||
public override ReadOnlySpan<Expr> Strides => Operands[Rank..];
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override int Rank { get; }
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string ToString()
|
||||
{
|
||||
return $"LogicalBuffer({Name}, {ElemType}, {nameof(MemLocation)})";
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitLogicalBuffer(this, context);
|
||||
|
||||
public LogicalBuffer With(string? name = null, DataType? elemType = null, Schedule.MemoryLocation? location = null, Expr[]? dimensions = null, Expr[]? strides = null)
|
||||
=> new LogicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? Dimensions, strides ?? Strides) { Const = Const };
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the physical buffer.
|
||||
/// </summary>
|
||||
public sealed class PhysicalBuffer : Buffer
|
||||
{
|
||||
private readonly int[] _fixedDimensions;
|
||||
private readonly int[] _fixedStrides;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PhysicalBuffer"/> class.
|
||||
/// ctor for physical buffer.
|
||||
/// </summary>
|
||||
public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan<int> dimensions, ReadOnlySpan<int> strides, int start, int size)
|
||||
: base(name, elemType, location, Array.Empty<Expr>())
|
||||
{
|
||||
Start = start;
|
||||
Size = size;
|
||||
_fixedDimensions = dimensions.ToArray();
|
||||
_fixedStrides = strides.ToArray();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PhysicalBuffer"/> class.
|
||||
/// <see cref="PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
public PhysicalBuffer(string name, DataType elemType, Schedule.MemoryLocation location, ReadOnlySpan<int> dimensions, int start, int size)
|
||||
: this(name, elemType, location, dimensions, TensorUtilities.GetStrides(dimensions), start, size)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="PhysicalBuffer"/> class.
|
||||
/// <see cref="PhysicalBuffer"/>.
|
||||
/// </summary>
|
||||
public PhysicalBuffer(string name, Schedule.MemoryLocation location, TensorConst tensor, int start, int size)
|
||||
: this(name, tensor.Value.ElementType, location, tensor.Value.Dimensions, tensor.Value.Strides, start, size)
|
||||
{
|
||||
Const = tensor;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets fixed dimensions.
|
||||
/// </summary>
|
||||
public ReadOnlySpan<int> FixedDimensions => _fixedDimensions;
|
||||
|
||||
/// <summary>
|
||||
/// Gets fixed strides.
|
||||
/// </summary>
|
||||
public ReadOnlySpan<int> FixedStrides => _fixedStrides;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets start.
|
||||
/// </summary>
|
||||
public int Start { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets total size in bytes.
|
||||
/// </summary>
|
||||
public int Size { get; init; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets dimensions.
|
||||
/// </summary>
|
||||
public override ReadOnlySpan<Expr> Dimensions => ArrayUtility.ToExprArray(FixedDimensions);
|
||||
|
||||
/// <summary>
|
||||
/// Gets strides.
|
||||
/// </summary>
|
||||
public override ReadOnlySpan<Expr> Strides => ArrayUtility.ToExprArray(FixedStrides);
|
||||
|
||||
/// <summary>
|
||||
/// Gets shape.
|
||||
/// </summary>
|
||||
public Shape Shape => new Shape(FixedDimensions);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override int Rank => FixedDimensions.Length;
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override string ToString()
|
||||
{
|
||||
return $"PhysicalBuffer({Name}, {ElemType}, {nameof(MemLocation)})";
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override bool Equals(object? obj)
|
||||
{
|
||||
return base.Equals(obj) && obj is PhysicalBuffer other &&
|
||||
FixedDimensions.SequenceEqual(other.FixedDimensions) &&
|
||||
FixedStrides.SequenceEqual(other.FixedStrides);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitPhysicalBuffer(this, context);
|
||||
|
||||
public PhysicalBuffer With(string? name = null, DataType? elemType = null, Schedule.MemoryLocation? location = null, int[]? dimensions = null, int[]? strides = null, int? start = null, int? size = null)
|
||||
=> new PhysicalBuffer(name ?? Name, elemType ?? ElemType, location ?? MemLocation, dimensions ?? FixedDimensions, strides ?? FixedStrides, start ?? Start, size ?? Size) { Const = Const };
|
||||
|
||||
protected override int GetHashCodeCore() => HashCode.Combine(Name, ElemType, Rank, base.GetHashCodeCore());
|
||||
}
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
using Nncase.Utilities;
|
||||
|
||||
namespace Nncase.TIR;
|
||||
|
||||
/// <summary>
|
||||
/// Buffer load node.
|
||||
/// </summary>
|
||||
public sealed class BufferLoad : Expr
|
||||
{
|
||||
public BufferLoad(PhysicalBuffer buffer, ReadOnlySpan<Expr> indices)
|
||||
: base(ArrayUtility.Concat(buffer, indices))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the buffer to be loaded.
|
||||
/// </summary>
|
||||
public PhysicalBuffer Buffer => (PhysicalBuffer)Operands[0];
|
||||
|
||||
/// <summary>
|
||||
/// Gets the buffer indices.
|
||||
/// </summary>
|
||||
public ReadOnlySpan<Expr> Indices => Operands.Slice(1);
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitBufferLoad(this, context);
|
||||
|
||||
public BufferLoad With(PhysicalBuffer? buffer = null, Expr[]? indices = null)
|
||||
=> new BufferLoad(buffer ?? Buffer, indices ?? Indices);
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
// Copyright (c) Canaan Inc. All rights reserved.
|
||||
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Nncase.IR;
|
||||
|
||||
namespace Nncase.TIR;
|
||||
|
||||
/// <summary>
|
||||
/// Buffer store node.
|
||||
/// </summary>
|
||||
public sealed class BufferStore : Expr
|
||||
{
|
||||
private readonly int _indicesCount;
|
||||
|
||||
public BufferStore(PhysicalBuffer buffer, ReadOnlySpan<Expr> indices, Expr value)
|
||||
: base(new Expr[] { buffer }.Concat(indices.ToArray()).Append(value).ToArray())
|
||||
{
|
||||
_indicesCount = indices.Length;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the buffer.
|
||||
/// </summary>
|
||||
public PhysicalBuffer Buffer => (PhysicalBuffer)Operands[0];
|
||||
|
||||
/// <summary>
|
||||
/// Gets the value we to be stored.
|
||||
/// </summary>
|
||||
public ReadOnlySpan<Expr> Indices => Operands[1.._indicesCount];
|
||||
|
||||
/// <summary>
|
||||
/// Gets the indices location to be stored.
|
||||
/// </summary>
|
||||
public Expr Value => Operands[_indicesCount + 1];
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override TExprResult Accept<TExprResult, TTypeResult, TContext>(ExprFunctor<TExprResult, TTypeResult, TContext> functor, TContext context)
|
||||
=> functor.VisitBufferStore(this, context);
|
||||
|
||||
public BufferStore With(PhysicalBuffer? buffer = null, Expr[]? indices = null, Expr? value = null)
|
||||
=> new BufferStore(buffer ?? Buffer, indices ?? Indices, value ?? Value);
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue