mirror of https://github.com/kendryte/nncase.git
add csharp runtime export
parent
a736e2df03
commit
7fcd4efc4e
|
@ -32,6 +32,8 @@ project(nncase
|
|||
option(ENABLE_OPENMP "OpenMP support" ON)
|
||||
option(ENABLE_HALIDE "halide kernels support" ON)
|
||||
option(BUILD_PYTHON_BINDING "Build python binding" ON)
|
||||
option(BUILD_CSHARP_BINDING "Build csharp binding" ON)
|
||||
option(BUILD_BENCHMARK "Build benchmark programs" ON)
|
||||
option(BUILD_TESTING "Build test programs" OFF)
|
||||
|
||||
if (BUILDING_RUNTIME)
|
||||
|
@ -57,7 +59,7 @@ else() # in user space
|
|||
message(STATUS "Auto Cmake Conan Installation")
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/conan.cmake)
|
||||
conan_cmake_run(CONANFILE conanfile.py
|
||||
BASIC_SETUP TARGETS
|
||||
BASIC_SETUP
|
||||
OPTIONS ${CONAN_OPTS}
|
||||
SETTINGS ${CONAN_SETTINGS}
|
||||
BUILD missing)
|
||||
|
@ -74,13 +76,15 @@ if (BUILDING_RUNTIME)
|
|||
|
||||
if (MSVC)
|
||||
add_definitions(/D_CRT_SECURE_NO_WARNINGS /DNOMINMAX)
|
||||
add_compile_options(/wd4267 /wd4251 /FC /utf-8 /W3 /WX)
|
||||
add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX)
|
||||
else()
|
||||
add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits)
|
||||
if (APPLE)
|
||||
add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized)
|
||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
add_compile_options(-Wno-uninitialized -Wno-unused-private-field)
|
||||
else ()
|
||||
add_compile_options(-Wno-maybe-uninitialized)
|
||||
add_compile_options(-Wno-maybe-uninitialized -Wno-unused-private-field)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
@ -90,12 +94,20 @@ if (BUILDING_RUNTIME)
|
|||
add_subdirectory(include/nncase)
|
||||
add_subdirectory(src/kernels)
|
||||
add_subdirectory(src/runtime)
|
||||
add_subdirectory(benchmark)
|
||||
add_subdirectory(src/functional)
|
||||
if(BUILD_BENCHMARK)
|
||||
add_subdirectory(benchmark)
|
||||
endif()
|
||||
|
||||
# Python binding
|
||||
if(BUILD_PYTHON_BINDING)
|
||||
add_subdirectory(python/nncaseruntime/native)
|
||||
endif()
|
||||
|
||||
# Csharp binding
|
||||
if(BUILD_CSHARP_BINDING)
|
||||
add_subdirectory(csharp)
|
||||
endif()
|
||||
|
||||
install(DIRECTORY ${NNCASE_INCLUDE_DIR}/nncase
|
||||
DESTINATION include
|
||||
|
@ -145,12 +157,15 @@ else()
|
|||
|
||||
if (MSVC)
|
||||
add_definitions(/D_SILENCE_ALL_CXX17_DEPRECATION_WARNINGS /D_CRT_SECURE_NO_WARNINGS /DNOMINMAX)
|
||||
add_compile_options(/wd4267 /wd4251 /FC /utf-8 /W3 /WX)
|
||||
add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX)
|
||||
set(PYBIND11_CPP_STANDARD "/std:c++latest")
|
||||
else()
|
||||
add_compile_options(-fvisibility=hidden)
|
||||
add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits -Wno-unused-local-typedefs -Wno-sign-compare)
|
||||
if (APPLE)
|
||||
add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized -Wno-deprecated)
|
||||
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
add_compile_options(-Wno-uninitialized)
|
||||
else ()
|
||||
add_compile_options(-Wno-maybe-uninitialized -Wno-deprecated-copy)
|
||||
add_link_options(-Wl,--exclude-libs,ALL)
|
||||
|
@ -165,23 +180,24 @@ else()
|
|||
add_subdirectory(src/data)
|
||||
add_subdirectory(src/ir)
|
||||
add_subdirectory(src/importer)
|
||||
#add_subdirectory(src/schedule)
|
||||
#add_subdirectory(src/evaluator)
|
||||
add_subdirectory(src/schedule)
|
||||
add_subdirectory(src/evaluator)
|
||||
add_subdirectory(src/functional)
|
||||
add_subdirectory(src/transforms)
|
||||
#add_subdirectory(src/codegen)
|
||||
#add_subdirectory(src/kernels)
|
||||
#add_subdirectory(src/runtime)
|
||||
add_subdirectory(src/codegen)
|
||||
add_subdirectory(src/kernels)
|
||||
add_subdirectory(src/runtime)
|
||||
add_subdirectory(src/targets)
|
||||
#add_subdirectory(src/plugin)
|
||||
add_subdirectory(src/plugin)
|
||||
add_subdirectory(src/cli)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
# add_subdirectory(tests/kernels)
|
||||
add_subdirectory(tests/kernels)
|
||||
endif()
|
||||
|
||||
# Python binding
|
||||
if(BUILD_PYTHON_BINDING)
|
||||
# add_subdirectory(python/nncase/native)
|
||||
add_subdirectory(python/nncase/native)
|
||||
endif()
|
||||
|
||||
# Thrid party
|
||||
|
@ -220,11 +236,11 @@ else()
|
|||
)
|
||||
|
||||
# Targets
|
||||
#add_subdirectory(targets/cpu)
|
||||
#add_subdirectory(targets/k210)
|
||||
#add_subdirectory(targets/vulkan)
|
||||
add_subdirectory(targets/cpu)
|
||||
add_subdirectory(targets/k210)
|
||||
add_subdirectory(targets/vulkan)
|
||||
endif()
|
||||
|
||||
# Modules
|
||||
#add_subdirectory(modules/k210)
|
||||
#add_subdirectory(modules/vulkan)
|
||||
add_subdirectory(modules/k210)
|
||||
add_subdirectory(modules/vulkan)
|
||||
|
|
|
@ -33,7 +33,7 @@
|
|||
# but it is only necessary on the end-user side. It is not necessary to create conan
|
||||
# packages, in fact it shouldn't be use for that. Check the project documentation.
|
||||
|
||||
# version: 0.16.1
|
||||
# version: 0.18.0-dev
|
||||
|
||||
include(CMakeParseArguments)
|
||||
|
||||
|
@ -55,6 +55,8 @@ function(_get_msvc_ide_version result)
|
|||
set(${result} 15 PARENT_SCOPE)
|
||||
elseif(NOT MSVC_VERSION VERSION_LESS 1920 AND MSVC_VERSION VERSION_LESS 1930)
|
||||
set(${result} 16 PARENT_SCOPE)
|
||||
elseif(NOT MSVC_VERSION VERSION_LESS 1930 AND MSVC_VERSION VERSION_LESS 1940)
|
||||
set(${result} 17 PARENT_SCOPE)
|
||||
else()
|
||||
message(FATAL_ERROR "Conan: Unknown MSVC compiler version [${MSVC_VERSION}]")
|
||||
endif()
|
||||
|
@ -126,6 +128,10 @@ macro(_conan_detect_compiler)
|
|||
set(_CONAN_SETTING_ARCH ${ARGUMENTS_ARCH})
|
||||
endif()
|
||||
|
||||
if(USING_CXX)
|
||||
set(_CONAN_SETTING_COMPILER_CPPSTD ${CMAKE_CXX_STANDARD})
|
||||
endif()
|
||||
|
||||
if (${CMAKE_${LANGUAGE}_COMPILER_ID} STREQUAL GNU)
|
||||
# using GCC
|
||||
# TODO: Handle other params
|
||||
|
@ -415,7 +421,8 @@ endfunction()
|
|||
|
||||
function(_collect_settings result)
|
||||
set(ARGUMENTS_PROFILE_AUTO arch build_type compiler compiler.version
|
||||
compiler.runtime compiler.libcxx compiler.toolset)
|
||||
compiler.runtime compiler.libcxx compiler.toolset
|
||||
compiler.cppstd)
|
||||
foreach(ARG ${ARGUMENTS_PROFILE_AUTO})
|
||||
string(TOUPPER ${ARG} _arg_name)
|
||||
string(REPLACE "." "_" _arg_name ${_arg_name})
|
||||
|
@ -427,10 +434,10 @@ function(_collect_settings result)
|
|||
endfunction()
|
||||
|
||||
function(conan_cmake_autodetect detected_settings)
|
||||
_conan_detect_build_type()
|
||||
_conan_detect_build_type(${ARGV})
|
||||
_conan_check_system_name()
|
||||
_conan_check_language()
|
||||
_conan_detect_compiler()
|
||||
_conan_detect_compiler(${ARGV})
|
||||
_collect_settings(collected_settings)
|
||||
set(${detected_settings} ${collected_settings} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
@ -794,7 +801,6 @@ macro(conan_check)
|
|||
if(NOT CONAN_CMD AND CONAN_REQUIRED)
|
||||
message(FATAL_ERROR "Conan executable not found! Please install conan.")
|
||||
endif()
|
||||
|
||||
if(NOT CONAN_DETECT_QUIET)
|
||||
message(STATUS "Conan: Found program ${CONAN_CMD}")
|
||||
endif()
|
||||
|
@ -803,13 +809,13 @@ macro(conan_check)
|
|||
OUTPUT_VARIABLE CONAN_VERSION_OUTPUT
|
||||
ERROR_VARIABLE CONAN_VERSION_OUTPUT)
|
||||
|
||||
message("a $ENV{PATH}")
|
||||
if(NOT "${return_code}" STREQUAL "0")
|
||||
message(FATAL_ERROR "Conan --version failed='${return_code}'")
|
||||
endif()
|
||||
|
||||
if(NOT CONAN_DETECT_QUIET)
|
||||
message(STATUS "Conan: Version found ${CONAN_VERSION_OUTPUT}")
|
||||
string(STRIP "${CONAN_VERSION_OUTPUT}" _CONAN_VERSION_OUTPUT)
|
||||
message(STATUS "Conan: Version found ${_CONAN_VERSION_OUTPUT}")
|
||||
endif()
|
||||
|
||||
if(DEFINED CONAN_VERSION)
|
||||
|
@ -839,7 +845,7 @@ function(conan_add_remote)
|
|||
if(DEFINED CONAN_COMMAND)
|
||||
set(CONAN_CMD ${CONAN_COMMAND})
|
||||
else()
|
||||
conan_check(REQUIRED)
|
||||
conan_check(REQUIRED DETECT_QUIET)
|
||||
endif()
|
||||
set(CONAN_VERIFY_SSL_ARG "True")
|
||||
if(DEFINED CONAN_VERIFY_SSL)
|
||||
|
|
|
@ -20,7 +20,7 @@ _SET_CONANOPT(CONAN_OPTS "halide" ENABLE_HALIDE)
|
|||
|
||||
if (NOT DEFINED CMAKE_CXX_STANDARD)
|
||||
if (BUILDING_RUNTIME)
|
||||
set (CMAKE_CXX_STANDARD 14)
|
||||
set (CMAKE_CXX_STANDARD 17)
|
||||
else ()
|
||||
set (CMAKE_CXX_STANDARD 20)
|
||||
endif ()
|
||||
|
|
|
@ -26,7 +26,6 @@ if (NOT BUILDING_RUNTIME)
|
|||
find_package(libzippp REQUIRED)
|
||||
find_package(inja REQUIRED)
|
||||
find_package(shaderc REQUIRED)
|
||||
find_package(range-v3 REQUIRED)
|
||||
endif ()
|
||||
|
||||
if (BUILD_TESTING)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
include(${CMAKE_CURRENT_LIST_DIR}/nncasefunctionalTargets.cmake)
|
|
@ -0,0 +1,9 @@
|
|||
cmake_minimum_required (VERSION 3.18)
|
||||
|
||||
set(SRCS interpreter.cpp)
|
||||
|
||||
add_library(nncaseruntime_csharp MODULE ${SRCS})
|
||||
target_link_libraries(nncaseruntime_csharp PRIVATE nncaseruntime)
|
||||
target_include_directories(nncaseruntime_csharp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
# set_target_properties(nncaseruntime_csharp PROPERTIES OUTPUT_NAME _nncaseruntime_csharp)
|
||||
install(TARGETS nncaseruntime_csharp DESTINATION lib)
|
|
@ -0,0 +1,230 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
#include <nncase/runtime/runtime_tensor.h>
|
||||
#include <nncase/runtime/runtime_tensor_impl.h>
|
||||
#include <nncase/version.h>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "stdprefix.h"
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
|
||||
template <typename T>
|
||||
datatype_t from_dtype()
|
||||
{
|
||||
if (std::is_same_v<T, float>)
|
||||
return dt_uint8;
|
||||
else if (std::is_same_v<T, uint16_t>)
|
||||
return dt_uint16;
|
||||
else if (std::is_same_v<T, uint32_t>)
|
||||
return dt_uint32;
|
||||
else if (std::is_same_v<T, uint64_t>)
|
||||
return dt_uint64;
|
||||
else if (std::is_same_v<T, int8_t>)
|
||||
return dt_int8;
|
||||
else if (std::is_same_v<T, int16_t>)
|
||||
return dt_int16;
|
||||
else if (std::is_same_v<T, int32_t>)
|
||||
return dt_int32;
|
||||
else if (std::is_same_v<T, int64_t>)
|
||||
return dt_int64;
|
||||
else if (std::is_same_v<T, std::bfloat16>)
|
||||
throw std::runtime_error("Unsupported float16");
|
||||
else if (std::is_same_v<T, float>)
|
||||
return dt_float32;
|
||||
else if (std::is_same_v<T, double>)
|
||||
return dt_float64;
|
||||
throw std::runtime_error("Unsupported dtype");
|
||||
}
|
||||
|
||||
inline runtime_shape_t to_rt_shape(const int *shape_ptr, int shape_size)
|
||||
{
|
||||
runtime_shape_t shape(shape_size);
|
||||
for (size_t i = 0; i < shape.size(); i++)
|
||||
shape[i] = (size_t)shape_ptr[i];
|
||||
return shape;
|
||||
}
|
||||
|
||||
inline runtime_shape_t to_rt_strides(size_t elemsize, const int *stride_ptr, int stride_size)
|
||||
{
|
||||
runtime_shape_t strides(stride_size);
|
||||
for (size_t i = 0; i < strides.size(); i++)
|
||||
strides[i] = (size_t)stride_ptr[i] / elemsize;
|
||||
|
||||
return strides;
|
||||
}
|
||||
|
||||
inline std::vector<size_t> to_py_strides(size_t elemsize, const runtime_shape_t &value)
|
||||
{
|
||||
std::vector<size_t> strides(value.size());
|
||||
for (size_t i = 0; i < strides.size(); i++)
|
||||
strides[i] = (size_t)value[i] * elemsize;
|
||||
return strides;
|
||||
}
|
||||
extern "C"
|
||||
{
|
||||
struct RuntimeTensor
|
||||
{
|
||||
void *impl;
|
||||
};
|
||||
}
|
||||
|
||||
inline std::shared_ptr<nncase::runtime::detail::runtime_tensor_impl> to_rt_impl(RuntimeTensor rt)
|
||||
{
|
||||
return std::shared_ptr<nncase::runtime::detail::runtime_tensor_impl>(static_cast<nncase::runtime::detail::runtime_tensor_impl *>(rt.impl));
|
||||
}
|
||||
|
||||
inline runtime_tensor to_rt(RuntimeTensor rt)
|
||||
{
|
||||
return runtime_tensor(to_rt_impl(rt));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief create tensor form buffer
|
||||
*
|
||||
* @param buffer_ptr the buffer ptr type is dtype
|
||||
* @param datatype the datatype enum value.
|
||||
* @param shape_ptr the shape pointer
|
||||
* @param shape_size the shape array size
|
||||
* @param total_items the total elements counts
|
||||
* @param item_size the each element total bytes
|
||||
* @param stride_ptr the stide pointer( by element)
|
||||
* @return void*
|
||||
*/
|
||||
EXPORT_API(RuntimeTensor) RuntimeTensor_from_buffer(
|
||||
const uint8_t *buffer_ptr, uint8_t datatype, int *shape_ptr,
|
||||
int shape_size, size_t total_items,
|
||||
size_t item_size, int *stride_ptr)
|
||||
{
|
||||
auto tensor = host_runtime_tensor::create(
|
||||
(datatype_t)datatype,
|
||||
to_rt_shape(shape_ptr, shape_size),
|
||||
to_rt_strides(item_size, stride_ptr, shape_size),
|
||||
gsl::make_span((gsl::byte *)(buffer_ptr),
|
||||
total_items * item_size),
|
||||
[=](gsl::byte *) {})
|
||||
.unwrap_or_throw();
|
||||
return RuntimeTensor { .impl = tensor.impl() };
|
||||
}
|
||||
|
||||
EXPORT_API(void) RuntimeTensor_to_buffer(RuntimeTensor rt, uint8_t *buffer_ptr, uint8_t *datatype_ptr)
|
||||
{
|
||||
auto rtensor = to_rt(rt);
|
||||
if (!rtensor.is_contiguous())
|
||||
{
|
||||
throw std::errc::not_supported;
|
||||
}
|
||||
*datatype_ptr = rtensor.datatype();
|
||||
|
||||
auto host = rtensor.as_host().unwrap_or_throw();
|
||||
auto src_map = std::move(hrt::map(host, hrt::map_read).unwrap_or_throw());
|
||||
auto src_buffer = src_map.buffer();
|
||||
memcpy(buffer_ptr, src_buffer.data(), src_buffer.size_bytes());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief copy the runtime Tensor
|
||||
*
|
||||
* @param from
|
||||
* @param to
|
||||
*/
|
||||
EXPORT_API(void) RuntimeTensor_copy_to(RuntimeTensor from, RuntimeTensor to)
|
||||
{
|
||||
auto rt = to_rt(to);
|
||||
to_rt_impl(from)->copy_to(rt).unwrap_or_throw();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief get the runtime tensor datatype enum value.
|
||||
*
|
||||
* @return uint8_t
|
||||
*/
|
||||
EXPORT_API(uint8_t) RuntimeTensor_dtype(RuntimeTensor rt)
|
||||
{
|
||||
uint8_t res;
|
||||
switch (to_rt_impl(rt)->datatype())
|
||||
{
|
||||
case dt_int8:
|
||||
res = 0;
|
||||
break;
|
||||
case dt_int16:
|
||||
res = 1;
|
||||
break;
|
||||
case dt_int32:
|
||||
res = 2;
|
||||
break;
|
||||
case dt_int64:
|
||||
res = 3;
|
||||
break;
|
||||
case dt_uint8:
|
||||
res = 4;
|
||||
break;
|
||||
case dt_uint16:
|
||||
res = 5;
|
||||
break;
|
||||
case dt_uint32:
|
||||
res = 6;
|
||||
break;
|
||||
case dt_uint64:
|
||||
res = 7;
|
||||
break;
|
||||
case dt_float16:
|
||||
res = 8;
|
||||
break;
|
||||
case dt_float32:
|
||||
res = 9;
|
||||
break;
|
||||
case dt_float64:
|
||||
res = 10;
|
||||
break;
|
||||
case dt_bfloat16:
|
||||
res = 11;
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Not Support The DataType");
|
||||
break;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief get the shape, if shape_ptr is null, just return
|
||||
*
|
||||
* @param rt
|
||||
* @param shape_ptr
|
||||
* @return int
|
||||
*/
|
||||
EXPORT_API(int) RuntimeTensor_shape(RuntimeTensor rt, size_t *shape_ptr)
|
||||
{
|
||||
auto rt_shape = to_rt_impl(rt)->shape();
|
||||
if (shape_ptr != nullptr)
|
||||
{
|
||||
for (size_t i = 0; i < rt_shape.size(); i++)
|
||||
{
|
||||
shape_ptr[i] = rt_shape[i];
|
||||
}
|
||||
}
|
||||
return rt_shape.size();
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
#! /bin/bash
|
||||
mkdir -p ../build/runtime
|
||||
cmake -S .. -B ../build/runtime \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_C_COMPILER=/usr/bin/clang \
|
||||
-DCMAKE_CXX_COMPILER=/usr/bin/clang++ \
|
||||
-DENABLE_HALIDE=false \
|
||||
-DENABLE_OPENMP=false \
|
||||
-DCMAKE_EXPORT_COMPILE_COMMANDS=true \
|
||||
-DENABLE_VULKAN_RUNTIME=false \
|
||||
-DBUILDING_RUNTIME=true \
|
||||
-G "Unix Makefiles" \
|
||||
-DCMAKE_INSTALL_PREFIX:PATH=../runtime_install
|
||||
cmake --build ../build/runtime --target install
|
|
@ -0,0 +1,128 @@
|
|||
#include "RuntimeTensor.h"
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
#include <nncase/version.h>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
|
||||
static std::unique_ptr<interpreter> _interp;
|
||||
|
||||
|
||||
/**
|
||||
* @brief init the interpreter
|
||||
*
|
||||
*/
|
||||
EXPORT_API(void) interpreter_init()
|
||||
{
|
||||
if (!_interp)
|
||||
{
|
||||
_interp = std::make_unique<nncase::runtime::interpreter>();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief load model
|
||||
*
|
||||
* @param buffer_ptr the buffer array
|
||||
* @param size buffer length
|
||||
*/
|
||||
EXPORT_API(void) interpreter_load_model(const uint8_t *buffer_ptr, int size)
|
||||
{
|
||||
auto buffer = gsl::span<const gsl::byte>((const gsl::byte *)(buffer_ptr), size);
|
||||
_interp->load_model(buffer).unwrap_or_throw();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief get the model inputs size
|
||||
*
|
||||
* @return size_t
|
||||
*/
|
||||
EXPORT_API(size_t) interpreter_inputs_size()
|
||||
{
|
||||
return _interp->inputs_size();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief get the model outputs size
|
||||
*
|
||||
* @return size_t
|
||||
*/
|
||||
EXPORT_API(size_t) interpreter_outputs_size()
|
||||
{
|
||||
return _interp->outputs_size();
|
||||
}
|
||||
/**
|
||||
* @brief get the input memory range desc
|
||||
*
|
||||
* @param index input number
|
||||
* @return memory_range
|
||||
*/
|
||||
EXPORT_API(memory_range) interpreter_get_input_desc(size_t index)
|
||||
{
|
||||
return _interp->input_desc(index);
|
||||
}
|
||||
/**
|
||||
* @brief get the output memory range desc
|
||||
*
|
||||
* @param index output number
|
||||
* @return memory_range
|
||||
*/
|
||||
EXPORT_API(memory_range) interpreter_get_output_desc(size_t index)
|
||||
{
|
||||
return _interp->output_desc(index);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief get the input tensor impl pointer
|
||||
*
|
||||
* @param index input number
|
||||
* @return void* the runtime_tensor impl pointer
|
||||
*/
|
||||
EXPORT_API(RuntimeTensor) interpreter_get_input_tensor(size_t index)
|
||||
{
|
||||
return RuntimeTensor { .impl = _interp->input_tensor(index).unwrap_or_throw().impl() };
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief set the input tensor
|
||||
*
|
||||
* @param index
|
||||
* @param rt
|
||||
*/
|
||||
EXPORT_API(void) interpreter_set_input_tensor(size_t index, RuntimeTensor rt)
|
||||
{
|
||||
_interp->input_tensor(index, to_rt(rt)).unwrap_or_throw();
|
||||
}
|
||||
|
||||
EXPORT_API(RuntimeTensor) interpreter_get_output_tensor(size_t index)
|
||||
{
|
||||
return RuntimeTensor { .impl = _interp->output_tensor(index).unwrap_or_throw().impl() };
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief set the output tensor
|
||||
*
|
||||
* @param index
|
||||
* @param rt
|
||||
*/
|
||||
EXPORT_API(void) interpreter_set_output_tensor(size_t index, RuntimeTensor rt)
|
||||
{
|
||||
_interp->input_tensor(index, to_rt(rt)).unwrap_or_throw();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief call the interpreter
|
||||
*
|
||||
*/
|
||||
EXPORT_API(void) interpreter_run()
|
||||
{
|
||||
_interp->run().unwrap_or_throw();
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
#pragma once
|
||||
#include <limits>
|
||||
#include <assert.h>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define DEBUG_ONLY(x) (void)(x)
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <intrin.h>
|
||||
|
||||
#define EXPORT_API(ret) extern "C" __declspec(dllexport) ret
|
||||
#else
|
||||
#include "unixprefix.h"
|
||||
|
||||
#define EXPORT_API(ret) extern "C" __attribute__((visibility("default"))) ret
|
||||
|
||||
#define __forceinline __attribute__((always_inline)) inline
|
||||
#endif
|
|
@ -0,0 +1,117 @@
|
|||
// Ignore SAL annotations for non-Windows build.
|
||||
#define _In_
|
||||
#define _Out_
|
||||
#define _Inout_
|
||||
#define _In_z_
|
||||
#define _Inout_z_
|
||||
#define _In_reads_(s)
|
||||
#define _In_reads_bytes_(s)
|
||||
#define _In_reads_z_(s)
|
||||
#define _In_reads_or_z_(s)
|
||||
#define _Out_writes_(s)
|
||||
#define _Out_writes_bytes_(s)
|
||||
#define _Out_writes_z_(s)
|
||||
#define _Inout_updates_(s)
|
||||
#define _Inout_updates_bytes_(s)
|
||||
#define _Inout_updates_z_(s)
|
||||
#define _Out_writes_to_(s,c)
|
||||
#define _Out_writes_bytes_to_(s,c)
|
||||
#define _Out_writes_all_(s)
|
||||
#define _Out_writes_bytes_all_(s)
|
||||
#define _Inout_updates_to_(s,c)
|
||||
#define _Inout_updates_bytes_to_(s,c)
|
||||
#define _Inout_updates_all_(s)
|
||||
#define _Inout_updates_bytes_all_(s)
|
||||
#define _In_reads_to_ptr_(p)
|
||||
#define _In_reads_to_ptr_z_(p)
|
||||
#define _Out_writes_to_ptr_(p)
|
||||
#define _Out_writes_to_ptr_z_(p)
|
||||
#define _In_opt_
|
||||
#define _Out_opt_
|
||||
#define _Inout_opt_
|
||||
#define _In_opt_z_
|
||||
#define _Inout_opt_z_
|
||||
#define _In_reads_opt_(s)
|
||||
#define _In_reads_bytes_opt_(s)
|
||||
#define _In_reads_opt_z_(s)
|
||||
#define _Out_writes_opt_(s)
|
||||
#define _Out_writes_opt_z_(s)
|
||||
#define _Inout_updates_opt_(s)
|
||||
#define _Inout_updates_bytes_opt_(s)
|
||||
#define _Inout_updates_opt_z_(s)
|
||||
#define _Out_writes_to_opt_(s,c)
|
||||
#define _Out_writes_bytes_to_opt_(s,c)
|
||||
#define _Out_writes_all_opt_(s)
|
||||
#define _Out_writes_bytes_all_opt_(s)
|
||||
#define _Inout_updates_to_opt_(s,c)
|
||||
#define _Inout_updates_bytes_to_opt_(s,c)
|
||||
#define _Inout_updates_all_opt_(s)
|
||||
#define _Inout_updates_bytes_all_opt_(s)
|
||||
#define _In_reads_to_ptr_opt_(p)
|
||||
#define _In_reads_to_ptr_opt_z_(p)
|
||||
#define _Out_writes_to_ptr_opt_(p)
|
||||
#define _Out_writes_to_ptr_opt_z_(p)
|
||||
#define _Outptr_
|
||||
#define _Outptr_opt_
|
||||
#define _Outptr_result_maybenull_
|
||||
#define _Outptr_opt_result_maybenull_
|
||||
#define _Outptr_result_z_
|
||||
#define _Outptr_opt_result_z_
|
||||
#define _Outptr_result_maybenull_z_
|
||||
#define _Ouptr_opt_result_maybenull_z_
|
||||
#define _COM_Outptr_
|
||||
#define _COM_Outptr_opt_
|
||||
#define _COM_Outptr_result_maybenull_
|
||||
#define _COM_Outptr_opt_result_maybenull_
|
||||
#define _Outptr_result_buffer_(s)
|
||||
#define _Outptr_result_bytebuffer_(s)
|
||||
#define _Outptr_opt_result_buffer_(s)
|
||||
#define _Outptr_opt_result_bytebuffer_(s)
|
||||
#define _Outptr_result_buffer_to_(s, c)
|
||||
#define _Outptr_result_bytebuffer_to_(s, c)
|
||||
#define _Outptr_opt_result_buffer_to_(s,c)
|
||||
#define _Outptr_opt_result_bytebuffer_to_(s,c)
|
||||
#define _Result_nullonfailure_
|
||||
#define _Result_zeroonfailure_
|
||||
#define _Outptr_result_nullonfailure_
|
||||
#define _Outptr_opt_result_nullonfailure_
|
||||
#define _Outref_result_nullonfailure_
|
||||
#define _Outref_
|
||||
#define _Outref_result_maybenull_
|
||||
#define _Outref_result_buffer_(s)
|
||||
#define _Outref_result_bytebuffer_(s)
|
||||
#define _Outref_result_buffer_to_(s, c)
|
||||
#define _Outref_result_bytebuffer_to_(s, c)
|
||||
#define _Outref_result_buffer_all_(s)
|
||||
#define _Outref_result_bytebuffer_all_(s)
|
||||
#define _Outref_result_buffer_maybenull_(s)
|
||||
#define _Outref_result_bytebuffer_maybenull_(s)
|
||||
#define _Outref_result_buffer_to_maybenull_(s, c)
|
||||
#define _Outref_result_bytebuffer_to_maybenull_(s,c)
|
||||
#define _Outref_result_buffer_all_maybenull_(s)
|
||||
#define _Outref_result_bytebuffer_all_maybenull_(s)
|
||||
#define _Ret_z_
|
||||
#define _Ret_writes_(s)
|
||||
#define _Ret_writes_bytes_(s)
|
||||
#define _Ret_writes_z_(s)
|
||||
#define _Ret_writes_to_(s,c)
|
||||
#define _Ret_writes_maybenull_(s)
|
||||
#define _Ret_writes_to_maybenull_(s)
|
||||
#define _Ret_writes_maybenull_z_(s)
|
||||
#define _Ret_maybenull_
|
||||
#define _Ret_maybenull_z_
|
||||
#define _Ret_null_
|
||||
#define _Ret_notnull_
|
||||
#define _Ret_writes_bytes_to_
|
||||
#define _Ret_writes_bytes_maybenull_
|
||||
#define _Ret_writes_bytes_to_maybenull_
|
||||
#define _In_range_(low, hi)
|
||||
#define _Out_range_(low, hi)
|
||||
#define _Ret_range_(low, hi)
|
||||
#define _Deref_in_range_(low, hi)
|
||||
#define _Deref_out_range_(low, hi)
|
||||
#define _Deref_inout_range_(low, hi)
|
||||
#define _Field_range_(low, hi)
|
||||
#define _Pre_equal_to_(expr)
|
||||
#define _Post_equal_to_(expr)
|
||||
#define _Struct_size_bytes_(size)
|
|
@ -1,24 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "runtime/datatypes.h"
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace nncase {
|
||||
class NNCASE_API attribute_map {};
|
||||
} // namespace nncase
|
|
@ -90,6 +90,7 @@ public:
|
|||
virtual std::unique_ptr<section_decompiler> create_decompiler(std::string_view section_name);
|
||||
|
||||
protected:
|
||||
section *find_section(std::string_view section_name);
|
||||
void merge_to_rdata_section(std::string_view from);
|
||||
function_call_id function_id(ir::graph *graph);
|
||||
void set_current_entry_point(std::streampos pos);
|
||||
|
|
|
@ -39,6 +39,8 @@ public:
|
|||
}
|
||||
|
||||
void emit_abs() { emit_opcode(runtime::nnil_abs); }
|
||||
void emit_acos() { emit_opcode(runtime::nnil_acos); }
|
||||
void emit_asin() { emit_opcode(runtime::nnil_asin); }
|
||||
void emit_ceil() { emit_opcode(runtime::nnil_ceil); }
|
||||
void emit_cos() { emit_opcode(runtime::nnil_cos); }
|
||||
void emit_exp() { emit_opcode(runtime::nnil_exp); }
|
||||
|
@ -46,6 +48,7 @@ public:
|
|||
void emit_log() { emit_opcode(runtime::nnil_log); }
|
||||
void emit_neg() { emit_opcode(runtime::nnil_neg); }
|
||||
void emit_rsqrt() { emit_opcode(runtime::nnil_rsqrt); }
|
||||
void emit_sign() { emit_opcode(runtime::nnil_sign); }
|
||||
void emit_sin() { emit_opcode(runtime::nnil_sin); }
|
||||
void emit_sqrt() { emit_opcode(runtime::nnil_sqrt); }
|
||||
void emit_square() { emit_opcode(runtime::nnil_square); }
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/9/15 16:49:15 +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 1/14/2022 2:00:39 PM +08:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -1004,6 +1004,21 @@ struct op_writer<nncase::runtime::stackvm::tensor_convert_op_t>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_cumsum_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_cumsum_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype));
|
||||
writer.write(op.rshape_src);
|
||||
writer.write(op.axis);
|
||||
writer.write(op.exclusive);
|
||||
writer.write(op.reverse);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_dequantize_op_t>
|
||||
{
|
||||
|
@ -1019,6 +1034,22 @@ struct op_writer<nncase::runtime::stackvm::tensor_dequantize_op_t>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_equal_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_equal_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype));
|
||||
writer.write(op.rshape_src1);
|
||||
writer.write(op.rstride_src1);
|
||||
writer.write(op.rshape_src2);
|
||||
writer.write(op.rstride_src2);
|
||||
writer.write(op.rstride_dest);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_gather_op_t>
|
||||
{
|
||||
|
@ -1053,6 +1084,20 @@ struct op_writer<nncase::runtime::stackvm::tensor_gather_nd_op_t>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_hardmax_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_hardmax_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype));
|
||||
writer.write(op.rshape_src);
|
||||
writer.write(op.rstride_src);
|
||||
writer.write(op.axis);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_lut1d_op_t>
|
||||
{
|
||||
|
@ -1068,6 +1113,20 @@ struct op_writer<nncase::runtime::stackvm::tensor_lut1d_op_t>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_matmul_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_matmul_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(op.rshape_src1);
|
||||
writer.write(op.rshape_src2);
|
||||
writer.write(op.fused_clamp_low);
|
||||
writer.write(op.fused_clamp_high);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_onehot_op_t>
|
||||
{
|
||||
|
@ -1115,6 +1174,36 @@ struct op_writer<nncase::runtime::stackvm::tensor_quantize_op_t>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_random_normal_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_random_normal_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype_dest));
|
||||
writer.write(op.rshape_dest);
|
||||
writer.write(op.mean);
|
||||
writer.write(op.std);
|
||||
writer.write(op.seed);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_random_uniform_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_random_uniform_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype_dest));
|
||||
writer.write(op.rshape_dest);
|
||||
writer.write(op.low);
|
||||
writer.write(op.high);
|
||||
writer.write(op.seed);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_reduce_op_t>
|
||||
{
|
||||
|
@ -1132,6 +1221,40 @@ struct op_writer<nncase::runtime::stackvm::tensor_reduce_op_t>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_reduce_arg_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_reduce_arg_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype_src));
|
||||
writer.write(op.rshape_src);
|
||||
writer.write(op.rstride_src);
|
||||
writer.write(static_cast<uint8_t>(op.datatype_dest));
|
||||
writer.write(op.rstride_dest);
|
||||
writer.write(static_cast<uint8_t>(op.reduce_arg_op));
|
||||
writer.write(op.rshape_axis);
|
||||
writer.write(op.keep_dims);
|
||||
writer.write(op.select_last_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_reduce_prod_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_reduce_prod_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(op.rshape_src);
|
||||
writer.write(op.rstride_src);
|
||||
writer.write(op.rstride_dest);
|
||||
writer.write(op.rshape_axes);
|
||||
writer.write(op.keep_dims);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_reduce_window2d_op_t>
|
||||
{
|
||||
|
@ -1172,6 +1295,35 @@ struct op_writer<nncase::runtime::stackvm::tensor_resize_image_op_t>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_roi_align_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_roi_align_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype));
|
||||
writer.write(op.rshape_src);
|
||||
writer.write(op.rshape_dest);
|
||||
writer.write(static_cast<uint8_t>(op.mode));
|
||||
writer.write(op.spatial_scale);
|
||||
writer.write(op.sampling_ratio);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_sigmoid_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_sigmoid_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype));
|
||||
writer.write(op.rshape_src);
|
||||
writer.write(op.rstride_src);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_slice_op_t>
|
||||
{
|
||||
|
@ -1189,6 +1341,59 @@ struct op_writer<nncase::runtime::stackvm::tensor_slice_op_t>
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_ternary_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_ternary_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype));
|
||||
writer.write(op.rshape_src1);
|
||||
writer.write(op.rstride_src1);
|
||||
writer.write(op.rshape_src2);
|
||||
writer.write(op.rstride_src2);
|
||||
writer.write(op.rshape_src3);
|
||||
writer.write(op.rstride_src3);
|
||||
writer.write(op.rstride_dest);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_topk_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_topk_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype));
|
||||
writer.write(op.rshape_src);
|
||||
writer.write(op.rstride_src);
|
||||
writer.write(op.rshape_dest1);
|
||||
writer.write(op.rstride_dest1);
|
||||
writer.write(op.rshape_dest2);
|
||||
writer.write(op.rstride_dest2);
|
||||
writer.write(op.k);
|
||||
writer.write(op.axis);
|
||||
writer.write(op.largest);
|
||||
writer.write(op.sorted);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_trilu_op_t>
|
||||
{
|
||||
void operator()(const nncase::runtime::stackvm::tensor_trilu_op_t &op, binary_writer &writer) const
|
||||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(static_cast<uint8_t>(op.datatype));
|
||||
writer.write(op.rshape_src);
|
||||
writer.write(op.upper);
|
||||
writer.write(op.k);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct op_writer<nncase::runtime::stackvm::tensor_unary_op_t>
|
||||
{
|
||||
|
@ -1218,7 +1423,7 @@ struct op_writer<nncase::runtime::stackvm::tensor_transpose_op_t>
|
|||
writer.write(op.rshape_perm);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class NNCASE_API op_builder
|
||||
{
|
||||
public:
|
||||
|
@ -1327,17 +1532,30 @@ public:
|
|||
void tensor_conv2d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_kernel, uint8_t rstride_kernel, uint8_t rstride_bias, uint8_t rstride_dest, uint16_t groups, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
|
||||
void tensor_copy_(datatype_t datatype, uint8_t rshape, uint8_t rstride_src, uint8_t rstride_dest);
|
||||
void tensor_convert_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
|
||||
void tensor_cumsum_(datatype_t datatype, uint8_t rshape_src, int32_t axis, bool exclusive, bool reverse);
|
||||
void tensor_dequantize_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
|
||||
void tensor_equal_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rstride_dest);
|
||||
void tensor_gather_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t axis);
|
||||
void tensor_gather_nd_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t batch_dims);
|
||||
void tensor_hardmax_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, int32_t axis);
|
||||
void tensor_lut1d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint16_t table_len);
|
||||
void tensor_matmul_(uint8_t rshape_src1, uint8_t rshape_src2, float fused_clamp_low, float fused_clamp_high);
|
||||
void tensor_onehot_(datatype_t datatype, uint8_t rshape_indices, uint8_t rshape_dest, uint8_t rstride_dest, uint8_t axis, onehot_mode_t onehot_mode);
|
||||
void tensor_pad_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rpaddings, pad_mode_t pad_mode);
|
||||
void tensor_quantize_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
|
||||
void tensor_random_normal_(datatype_t datatype_dest, uint8_t rshape_dest, float mean, float std, float seed);
|
||||
void tensor_random_uniform_(datatype_t datatype_dest, uint8_t rshape_dest, float low, float high, float seed);
|
||||
void tensor_reduce_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, reduce_op_t reduce_op, uint8_t rshape_axis, bool keep_dims);
|
||||
void tensor_reduce_arg_(datatype_t datatype_src, uint8_t rshape_src, uint8_t rstride_src, datatype_t datatype_dest, uint8_t rstride_dest, reduce_arg_op_t reduce_arg_op, uint8_t rshape_axis, bool keep_dims, bool select_last_idx);
|
||||
void tensor_reduce_prod_(uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_axes, bool keep_dims);
|
||||
void tensor_reduce_window2d_(datatype_t datatype, reduce_op_t reduce_op, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint16_t filter_h, uint16_t filter_w, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
|
||||
void tensor_resize_image_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, bool align_corners, bool half_pixel_centers, image_resize_mode_t image_resize_mode);
|
||||
void tensor_roi_align_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, roi_align_mode_t mode, float spatial_scale, int64_t sampling_ratio);
|
||||
void tensor_sigmoid_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src);
|
||||
void tensor_slice_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rbegins, uint8_t rends, uint8_t rstrides);
|
||||
void tensor_ternary_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_src3, uint8_t rstride_src3, uint8_t rstride_dest);
|
||||
void tensor_topk_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_dest1, uint8_t rstride_dest1, uint8_t rshape_dest2, uint8_t rstride_dest2, int64_t k, int32_t axis, bool largest, bool sorted);
|
||||
void tensor_trilu_(datatype_t datatype, uint8_t rshape_src, bool upper, int64_t k);
|
||||
void tensor_unary_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, unary_op_t unary_op);
|
||||
void tensor_transpose_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_perm);
|
||||
|
||||
|
|
|
@ -36,21 +36,31 @@ struct compile_options
|
|||
{
|
||||
bool dump_ir;
|
||||
bool dump_asm;
|
||||
bool dump_quant_error;
|
||||
bool dump_import_op_range;
|
||||
bool is_fpga;
|
||||
bool use_dataset_as_input_stat = false;
|
||||
bool benchmark_only = false;
|
||||
bool preprocess = false;
|
||||
bool swapRB = false;
|
||||
std::string target;
|
||||
std::filesystem::path dump_dir;
|
||||
std::string input_type = "default";
|
||||
std::string output_type = "float32";
|
||||
std::string quant_type = "uint8";
|
||||
std::vector<float> mean { 0.f, 0.f, 0.f };
|
||||
std::vector<float> std { 1.f, 1.f, 1.f };
|
||||
std::vector<float> input_range { 0.f, 1.f };
|
||||
float letterbox_value = 0.f;
|
||||
std::vector<int32_t> input_shape {};
|
||||
std::string w_quant_type = "uint8";
|
||||
bool use_mse_quant_w = false;
|
||||
std::string input_layout = "NCHW";
|
||||
std::string output_layout = "NCHW";
|
||||
};
|
||||
|
||||
struct import_options
|
||||
{
|
||||
std::string input_layout = "NCHW";
|
||||
std::string output_layout = "NCHW";
|
||||
std::span<const std::string> output_arrays;
|
||||
};
|
||||
|
||||
|
@ -58,9 +68,6 @@ struct ptq_options_base
|
|||
{
|
||||
std::string calibrate_method = "no_clip";
|
||||
std::function<void(size_t cnt, size_t total)> progress;
|
||||
|
||||
float input_mean = 0.f;
|
||||
float input_std = 1.f;
|
||||
};
|
||||
|
||||
struct ptq_dataset_options : ptq_options_base
|
||||
|
@ -75,6 +82,23 @@ struct ptq_tensor_options : ptq_options_base
|
|||
size_t samples_count;
|
||||
};
|
||||
|
||||
struct dump_range_options_base
|
||||
{
|
||||
std::string calibrate_method = "no_clip";
|
||||
std::function<void(size_t cnt, size_t total)> progress;
|
||||
};
|
||||
|
||||
struct dump_range_dataset_options : dump_range_options_base
|
||||
{
|
||||
std::filesystem::path dataset;
|
||||
std::string dataset_format;
|
||||
};
|
||||
struct dump_range_tensor_options : dump_range_options_base
|
||||
{
|
||||
std::vector<uint8_t> tensor_data;
|
||||
size_t samples_count;
|
||||
};
|
||||
|
||||
class NNCASE_API compiler
|
||||
{
|
||||
public:
|
||||
|
@ -83,8 +107,11 @@ public:
|
|||
virtual ~compiler();
|
||||
virtual void import_tflite(std::span<const uint8_t> model, const import_options &options) = 0;
|
||||
virtual void import_onnx(std::span<const uint8_t> model, const import_options &options) = 0;
|
||||
virtual void import_caffe(std::span<const uint8_t> model, std::span<const uint8_t> prototxt) = 0;
|
||||
virtual void use_ptq(ptq_dataset_options options) = 0;
|
||||
virtual void use_ptq(ptq_tensor_options options) = 0;
|
||||
virtual void dump_range_options(dump_range_dataset_options options) = 0;
|
||||
virtual void dump_range_options(dump_range_tensor_options options) = 0;
|
||||
virtual ir::graph &graph(uint32_t stage) = 0;
|
||||
virtual nncase::target &target() = 0;
|
||||
virtual void compile() = 0;
|
||||
|
|
|
@ -116,6 +116,7 @@ protected:
|
|||
virtual void process(const std::vector<uint8_t> &src, float *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
|
||||
virtual void process(const std::vector<uint8_t> &src, uint8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
|
||||
virtual void process(const std::vector<uint8_t> &src, int8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
|
||||
virtual bool do_normalize() const noexcept { return true; }
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
|
@ -130,6 +131,7 @@ private:
|
|||
process(file, batch.data(), batch.shape(), input_layout_);
|
||||
|
||||
std::span<const std::filesystem::path> filenames(filenames_.data() + start, filenames_.data() + from);
|
||||
|
||||
return data_batch<T> { std::move(batch), filenames };
|
||||
}
|
||||
|
||||
|
@ -162,5 +164,6 @@ protected:
|
|||
void process(const std::vector<uint8_t> &src, float *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
|
||||
void process(const std::vector<uint8_t> &src, uint8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
|
||||
void process(const std::vector<uint8_t> &src, int8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
|
||||
bool do_normalize() const noexcept override { return false; }
|
||||
};
|
||||
}
|
||||
|
|
|
@ -0,0 +1,293 @@
|
|||
/**
|
||||
* @file ops.h
|
||||
* @date 2021-09-27
|
||||
*
|
||||
* @copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
#pragma once
|
||||
#include <nncase/runtime/runtime_tensor.h>
|
||||
|
||||
#ifndef NNCASE_FUNCTIONAL_IMPL_PLATFORM_HEADER
|
||||
#include <nncase/functional/ops.platform.h>
|
||||
#else
|
||||
#include NNCASE_FUNCTIONAL_IMPL_PLATFORM_HEADER
|
||||
#endif
|
||||
|
||||
namespace nncase::F
|
||||
{
|
||||
|
||||
/**
|
||||
* @brief unary_square
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> square(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_square);
|
||||
}
|
||||
/**
|
||||
* @brief unary_sqrt
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> sqrt(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_sqrt);
|
||||
}
|
||||
/**
|
||||
* @brief unary_log
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> log(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_log);
|
||||
}
|
||||
/**
|
||||
* @brief unary_exp
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> exp(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_exp);
|
||||
}
|
||||
/**
|
||||
* @brief unary_sin
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> sin(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_sin);
|
||||
}
|
||||
/**
|
||||
* @brief unary_cos
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> cos(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_cos);
|
||||
}
|
||||
/**
|
||||
* @brief unary_round
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> round(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_round);
|
||||
}
|
||||
/**
|
||||
* @brief unary_floor
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> floor(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_floor);
|
||||
}
|
||||
/**
|
||||
* @brief unary_ceil
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> ceil(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_ceil);
|
||||
}
|
||||
/**
|
||||
* @brief unary_abs
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> abs(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_abs);
|
||||
}
|
||||
/**
|
||||
* @brief unary_neg
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype output tensor datatype
|
||||
* @return result<runtime::runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> neg(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::unary(input, dtype, unary_neg);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief binary add
|
||||
* temporary not support
|
||||
* @param input_a runtime_tensor
|
||||
* @param input_b runtime_tensor
|
||||
* @param dtype datatype, output tensor datatype
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> add(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::binary(input_a, input_b, dtype, binary_add);
|
||||
}
|
||||
/**
|
||||
* @brief binary sub
|
||||
* temporary not support
|
||||
* @param input_a runtime_tensor
|
||||
* @param input_b runtime_tensor
|
||||
* @param dtype datatype, output tensor datatype
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> sub(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::binary(input_a, input_b, dtype, binary_sub);
|
||||
}
|
||||
/**
|
||||
* @brief binary mul
|
||||
* temporary not support
|
||||
* @param input_a runtime_tensor
|
||||
* @param input_b runtime_tensor
|
||||
* @param dtype datatype, output tensor datatype
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> mul(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::binary(input_a, input_b, dtype, binary_mul);
|
||||
}
|
||||
/**
|
||||
* @brief binary div
|
||||
* temporary not support
|
||||
* @param input_a runtime_tensor
|
||||
* @param input_b runtime_tensor
|
||||
* @param dtype datatype, output tensor datatype
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> div(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::binary(input_a, input_b, dtype, binary_div);
|
||||
}
|
||||
/**
|
||||
* @brief binary min
|
||||
* temporary not support
|
||||
* @param input_a runtime_tensor
|
||||
* @param input_b runtime_tensor
|
||||
* @param dtype datatype, output tensor datatype
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> min(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::binary(input_a, input_b, dtype, binary_min);
|
||||
}
|
||||
/**
|
||||
* @brief binary max
|
||||
* temporary not support
|
||||
* @param input_a runtime_tensor
|
||||
* @param input_b runtime_tensor
|
||||
* @param dtype datatype, output tensor datatype
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> max(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::binary(input_a, input_b, dtype, binary_max);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief quantize float or bfloat tensor to uint8 or int8
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype datatype, output tensor datatype
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> quantize(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::quantize(input, dtype);
|
||||
}
|
||||
/**
|
||||
* @brief dequantize uint8 or int8 tensor to float or bfloat
|
||||
*
|
||||
* @param input runtime_tensor
|
||||
* @param dtype datatype, output tensor datatype
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> dequantize(runtime::runtime_tensor &input, datatype_t dtype) noexcept
|
||||
{
|
||||
return impl::dequantize(input, dtype);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief give bboxs, crop new tensor from current tensor.
|
||||
*
|
||||
* @param input
|
||||
* @param bbox runtime tensor, shape should be [1,1,roi_amounts,4], layout should be [y0, x0, y1, x1]
|
||||
* @param out_h output tensor height
|
||||
* @param out_w output tensor width
|
||||
* @param resize_mode resize mode
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> crop(runtime::runtime_tensor &input, runtime::runtime_tensor &bbox, size_t out_h, size_t out_w, image_resize_mode_t resize_mode, bool align_corners, bool half_pixel_centers) noexcept
|
||||
{
|
||||
return impl::crop(input, bbox, out_h, out_w, resize_mode, align_corners, half_pixel_centers);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief resize tensor to new height or width
|
||||
*
|
||||
* @param input
|
||||
* @param out_h
|
||||
* @param out_w
|
||||
* @param resize_mode
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> resize(runtime::runtime_tensor &input, size_t out_h, size_t out_w, image_resize_mode_t resize_mode, bool align_corners, bool half_pixel_centers) noexcept
|
||||
{
|
||||
return impl::resize(input, out_h, out_w, resize_mode, align_corners, half_pixel_centers);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief padding value on the input tensor
|
||||
* temporary not support
|
||||
* @param input
|
||||
* @param padding vector for padding param, from last to frist. eg. vector [ {2,3}, {1,3} ] mean pad {2,3} in last dim, pad {1,3} in last second dim
|
||||
* @param pad_mode
|
||||
* @param fill_v const fill value
|
||||
* @return result<runtime_tensor>
|
||||
*/
|
||||
NNCASE_API inline result<runtime::runtime_tensor> pad(runtime::runtime_tensor &input, runtime_paddings_t &paddings, pad_mode_t pad_mode, float fill_v) noexcept
|
||||
{
|
||||
return impl::pad(input, paddings, pad_mode, fill_v);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* @file ops.platform.h
|
||||
* @copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
#pragma once
|
||||
#include <nncase/runtime/runtime_tensor.h>
|
||||
|
||||
namespace nncase::F::impl
|
||||
{
|
||||
|
||||
result<runtime::runtime_tensor> unary(runtime::runtime_tensor &input, datatype_t dtype, unary_op_t op_type) noexcept;
|
||||
|
||||
result<runtime::runtime_tensor> binary(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype, binary_op_t op_type) noexcept;
|
||||
|
||||
result<runtime::runtime_tensor> quantize(runtime::runtime_tensor &input, datatype_t dtype) noexcept;
|
||||
|
||||
result<runtime::runtime_tensor> dequantize(runtime::runtime_tensor &input, datatype_t dtype) noexcept;
|
||||
|
||||
result<runtime::runtime_tensor> crop(runtime::runtime_tensor &input, runtime::runtime_tensor &bbox, size_t out_h, size_t out_w, image_resize_mode_t resize_mode, bool align_corners, bool half_pixel_centers) noexcept;
|
||||
|
||||
result<runtime::runtime_tensor> resize(runtime::runtime_tensor &input, size_t out_h, size_t out_w, image_resize_mode_t resize_mode, bool align_corners, bool half_pixel_centers) noexcept;
|
||||
|
||||
result<runtime::runtime_tensor> pad(runtime::runtime_tensor &input, runtime_paddings_t &paddings, pad_mode_t pad_mode, float fill_v) noexcept;
|
||||
}
|
|
@ -14,7 +14,7 @@
|
|||
*/
|
||||
#pragma once
|
||||
#include <filesystem>
|
||||
#include <nncase/ir/module.h>
|
||||
#include <nncase/ir/graph.h>
|
||||
#include <span>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
@ -23,11 +23,10 @@ namespace nncase::importer
|
|||
{
|
||||
struct import_options
|
||||
{
|
||||
std::string input_layout = "NCHW";
|
||||
std::string output_layout = "NCHW";
|
||||
std::span<const std::string> output_arrays;
|
||||
};
|
||||
|
||||
NNCASE_API ir::module_t import_tflite(std::span<const uint8_t> model, const import_options &options);
|
||||
//ir::module import_onnx(std::span<const uint8_t> model, const import_options &options);
|
||||
void import_tflite(ir::graph &graph, std::span<const uint8_t> model, const import_options &options, std::string &real_inlayout, std::string &real_outlayout);
|
||||
void import_onnx(ir::graph &graph, std::span<const uint8_t> model, const import_options &options, std::string &real_inlayout, std::string &real_outlayout);
|
||||
void import_caffe(ir::graph &graph, std::span<const uint8_t> model, std::span<const uint8_t> prototxt, std::string &real_inlayout, std::string &real_outlayout);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include <nncase/ir/debug.h>
|
||||
#include <nncase/ir/ir_types.h>
|
||||
#include <nncase/ir/ops/convert.h>
|
||||
#include <nncase/runtime/debug.h>
|
||||
|
||||
namespace nncase::importer
|
||||
{
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "function.h"
|
||||
#include "op.h"
|
||||
#include <variant>
|
||||
|
||||
namespace nncase::ir {
|
||||
/** @brief Call node */
|
||||
class NNCASE_API call_node : public expr_node {
|
||||
DEFINE_OBJECT_KIND(expr_node, object_call)
|
||||
public:
|
||||
call_node(expr target, std::vector<expr> arguments);
|
||||
|
||||
/** @brief Get the arguments of the call expression */
|
||||
std::span<const expr> arguments() const noexcept { return arguments_; }
|
||||
/** @brief Get the mutable arguments of the call expression */
|
||||
std::span<expr> arguments() noexcept { return arguments_; }
|
||||
|
||||
/** @brief Get the target of the call expression */
|
||||
const expr &target() const noexcept { return target_; }
|
||||
/** @brief Set the target of the function expression */
|
||||
void target(expr value) noexcept { target_ = std::move(value); }
|
||||
|
||||
private:
|
||||
expr target_;
|
||||
std::vector<expr> arguments_;
|
||||
};
|
||||
|
||||
class call : public object_t<call_node> {
|
||||
public:
|
||||
using object_t::object_t;
|
||||
|
||||
NNCASE_API call(expr target, std::vector<expr> arguments);
|
||||
};
|
||||
} // namespace nncase::ir
|
|
@ -0,0 +1,91 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "ir_types.h"
|
||||
#include <optional>
|
||||
#include <span>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <xtensor/xshape.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class node;
|
||||
class output_connector;
|
||||
|
||||
class NNCASE_API base_connector
|
||||
{
|
||||
public:
|
||||
template <class TName, class TShape>
|
||||
base_connector(node &owner, TName &&name, datatype_t type, TShape &&shape)
|
||||
: owner_(owner), name_(std::forward<TName>(name)), type_(type), shape_(std::forward<TShape>(shape))
|
||||
{
|
||||
}
|
||||
|
||||
base_connector(base_connector &) = delete;
|
||||
base_connector(base_connector &&) = default;
|
||||
|
||||
node &owner() const noexcept { return owner_; }
|
||||
const std::string &name() const noexcept { return name_; }
|
||||
datatype_t type() const noexcept { return type_; }
|
||||
const shape_t &shape() const noexcept { return shape_; }
|
||||
connector_attributes attributes() const noexcept { return attributes_; }
|
||||
void attributes(connector_attributes value) noexcept { attributes_ = value; }
|
||||
|
||||
private:
|
||||
node &owner_;
|
||||
std::string name_;
|
||||
datatype_t type_;
|
||||
shape_t shape_;
|
||||
connector_attributes attributes_ = cnctr_attr_none;
|
||||
};
|
||||
|
||||
class NNCASE_API input_connector : public base_connector
|
||||
{
|
||||
public:
|
||||
using base_connector::base_connector;
|
||||
|
||||
output_connector *connection() const noexcept { return connection_; }
|
||||
void connect(output_connector &connector);
|
||||
void clear_connection();
|
||||
|
||||
private:
|
||||
output_connector *connection_ = nullptr;
|
||||
};
|
||||
|
||||
class NNCASE_API output_connector : public base_connector
|
||||
{
|
||||
public:
|
||||
template <class TName, class TShape>
|
||||
output_connector(node &owner, TName &&name, datatype_t type, TShape &&shape, memory_location_t memory_location = mem_data)
|
||||
: base_connector(owner, std::forward<TName>(name), type, std::forward<TShape>(shape)), memory_location_(memory_location)
|
||||
{
|
||||
}
|
||||
|
||||
std::span<input_connector *const> connections() const noexcept { return connections_; }
|
||||
void connect(input_connector &connector);
|
||||
void disconnect(input_connector &connector);
|
||||
void clear_connections();
|
||||
// connector_attributes attributes() const noexcept { return attributes_; }
|
||||
// void attributes(connector_attributes value) noexcept { attributes_ = value; }
|
||||
memory_location_t memory_location() const noexcept { return memory_location_; }
|
||||
void memory_location(memory_location_t value) noexcept { memory_location_ = value; }
|
||||
|
||||
private:
|
||||
std::vector<input_connector *> connections_;
|
||||
// connector_attributes attributes_ = cnctr_attr_none;
|
||||
memory_location_t memory_location_;
|
||||
};
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "expr.h"
|
||||
#include "type.h"
|
||||
|
||||
namespace nncase::ir {
|
||||
/** @brief Constant node */
|
||||
class NNCASE_API constant_node : public expr_node {
|
||||
DEFINE_OBJECT_KIND(expr_node, object_constant);
|
||||
|
||||
public:
|
||||
constant_node(type value_type, std::vector<std::byte> data);
|
||||
|
||||
/** @brief Get the type of the constant expression */
|
||||
const type &value_type() const noexcept { return value_type_; }
|
||||
/** @brief Get the mutable type of the constant expression */
|
||||
type &value_type() noexcept { return value_type_; }
|
||||
|
||||
/** @brief Get the data of the constant expression */
|
||||
std::span<const std::byte> data() const noexcept { return data_; }
|
||||
/** @brief Get the mutable data of the constant expression */
|
||||
std::span<std::byte> data() noexcept { return data_; }
|
||||
|
||||
private:
|
||||
type value_type_;
|
||||
std::vector<std::byte> data_;
|
||||
};
|
||||
|
||||
class constant : public object_t<constant_node> {
|
||||
public:
|
||||
using object_t::object_t;
|
||||
|
||||
NNCASE_API constant(type value_type, std::vector<std::byte> data);
|
||||
|
||||
constant(type value_type, std::span<const std::byte> data)
|
||||
: constant(std::move(value_type),
|
||||
std::vector<std::byte>(data.begin(), data.end())) {}
|
||||
|
||||
template <class T>
|
||||
constant(type value_type, std::span<const T> data)
|
||||
: constant(std::move(value_type), std::as_bytes(data)) {}
|
||||
|
||||
template <class TScalar>
|
||||
constant(TScalar scalar)
|
||||
: constant(prim_type(to_datatype<TScalar>()),
|
||||
std::span<const TScalar>(&scalar, 1)) {}
|
||||
};
|
||||
} // namespace nncase::ir
|
|
@ -13,37 +13,28 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "graph.h"
|
||||
#include "ir_types.h"
|
||||
#include "module.h"
|
||||
#include <filesystem>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
namespace nncase {
|
||||
constexpr std::string_view datatype_names(datatype_t dt) {
|
||||
switch (dt) {
|
||||
#define DEFINE_DATATYPE(id, t, name, value) \
|
||||
case dt_##id: \
|
||||
return #name;
|
||||
#include <nncase/runtime/datatypes.def>
|
||||
#undef DEFINE_DATATYPE
|
||||
default:
|
||||
throw std::invalid_argument("invalid datatype");
|
||||
}
|
||||
namespace nncase
|
||||
{
|
||||
inline std::string to_string(const padding &value)
|
||||
{
|
||||
return "{" + std::to_string(value.before) + ", " + std::to_string(value.after) + "}";
|
||||
}
|
||||
|
||||
inline std::string to_string(const padding &value) {
|
||||
return "{" + std::to_string(value.before) + ", " +
|
||||
std::to_string(value.after) + "}";
|
||||
inline std::string to_string(const quant_param_t &value)
|
||||
{
|
||||
return "(q - " + std::to_string(value.zero_point) + ") * " + std::to_string(value.scale);
|
||||
}
|
||||
|
||||
inline std::string to_string(const quant_param_t &value) {
|
||||
return "(q - " + std::to_string(value.zero_point) + ") * " +
|
||||
std::to_string(value.scale);
|
||||
}
|
||||
|
||||
inline std::string to_string(memory_location_t location) {
|
||||
switch (location) {
|
||||
inline std::string to_string(memory_location_t location)
|
||||
{
|
||||
switch (location)
|
||||
{
|
||||
case mem_input:
|
||||
return "input";
|
||||
case mem_output:
|
||||
|
@ -58,53 +49,52 @@ inline std::string to_string(memory_location_t location) {
|
|||
}
|
||||
|
||||
template <typename Tv, typename T>
|
||||
static size_t index_of(const Tv &v, const T &e) {
|
||||
for (size_t i = 0; i < v.size(); i++) {
|
||||
if (&v[i] == &e) {
|
||||
static size_t index_of(const Tv &v, const T &e)
|
||||
{
|
||||
for (size_t i = 0; i < v.size(); i++)
|
||||
{
|
||||
if (&v[i] == &e)
|
||||
{
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return SIZE_MAX;
|
||||
}
|
||||
|
||||
namespace ir {
|
||||
inline std::string to_string(const dim_t &dim) {
|
||||
return dim.is_fixed() ? std::to_string(dim.value) : "?";
|
||||
}
|
||||
|
||||
inline std::string to_string(const shape_t &shape) {
|
||||
std::string str{'['};
|
||||
if (shape.is_invalid()) {
|
||||
str += "invalid";
|
||||
} else if (shape.is_unranked()) {
|
||||
str += '*';
|
||||
} else {
|
||||
for (size_t i = 0; i < shape.rank(); i++) {
|
||||
if (i != 0) {
|
||||
namespace ir
|
||||
{
|
||||
inline std::string to_string(const shape_t &shape)
|
||||
{
|
||||
std::string str { '[' };
|
||||
for (size_t i = 0; i < shape.size(); i++)
|
||||
{
|
||||
if (i != 0)
|
||||
{
|
||||
str.append(",");
|
||||
}
|
||||
str.append(to_string(shape[i]));
|
||||
str.append(std::to_string(shape[i]));
|
||||
}
|
||||
|
||||
str += ']';
|
||||
return str;
|
||||
}
|
||||
|
||||
str += ']';
|
||||
return str;
|
||||
}
|
||||
|
||||
inline std::string to_string(const axis_t &axis) {
|
||||
std::string str{'['};
|
||||
for (size_t i = 0; i < axis.size(); i++) {
|
||||
if (i != 0) {
|
||||
str.append(",");
|
||||
inline std::string to_string(const axis_t &axis)
|
||||
{
|
||||
std::string str { '[' };
|
||||
for (size_t i = 0; i < axis.size(); i++)
|
||||
{
|
||||
if (i != 0)
|
||||
{
|
||||
str.append(",");
|
||||
}
|
||||
str.append(std::to_string(axis[i]));
|
||||
}
|
||||
str.append(std::to_string(axis[i]));
|
||||
|
||||
str += ']';
|
||||
return str;
|
||||
}
|
||||
|
||||
str += ']';
|
||||
return str;
|
||||
NNCASE_API void dump_graph(const ir::graph &src_graph, const std::filesystem::path &dst_path);
|
||||
}
|
||||
}
|
||||
|
||||
NNCASE_API void dump_function(const ir::function &func,
|
||||
const std::filesystem::path &dst_path);
|
||||
} // namespace ir
|
||||
} // namespace nncase
|
||||
|
|
|
@ -24,6 +24,13 @@ class target;
|
|||
|
||||
namespace nncase::ir
|
||||
{
|
||||
enum class eval_step
|
||||
{
|
||||
after_import,
|
||||
after_calib,
|
||||
after_quant
|
||||
};
|
||||
|
||||
class module_evaluate_context;
|
||||
class model_evaluate_context;
|
||||
|
||||
|
@ -53,7 +60,7 @@ public:
|
|||
|
||||
module_evaluate_context &module() const noexcept { return mod_eval_; }
|
||||
|
||||
void evaluate();
|
||||
void evaluate(eval_step step, size_t stage, bool record_output_buffers);
|
||||
|
||||
private:
|
||||
const schedule::function_schedule_result &sched_;
|
||||
|
@ -129,7 +136,7 @@ public:
|
|||
void end_sample();
|
||||
void end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress);
|
||||
|
||||
void evaluate();
|
||||
void evaluate(eval_step step, size_t stage, bool record_output_buffers);
|
||||
|
||||
private:
|
||||
const schedule::model_schedule_result &sched_;
|
||||
|
|
|
@ -26,7 +26,7 @@ public:
|
|||
evaluator(evaluator &&) = default;
|
||||
|
||||
void enable_ptq(target &target, ir::calibrate_method calib_method);
|
||||
void evaluate();
|
||||
void evaluate(eval_step step = nncase::ir::eval_step::after_import, size_t stage = 0, bool record_output_buffers = false);
|
||||
|
||||
ir::quantizer *quantizer(const module_type_t &module_type);
|
||||
void begin_collect_distribution();
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../object.h"
|
||||
#include "type.h"
|
||||
#include <range/v3/range/concepts.hpp>
|
||||
|
||||
namespace nncase::ir {
|
||||
/** @brief Expression node */
|
||||
class NNCASE_API expr_node : public object_node {
|
||||
DEFINE_OBJECT_KIND(object_node, object_expr)
|
||||
|
||||
public:
|
||||
expr_node();
|
||||
|
||||
/** @brief Get the checked type of the expression */
|
||||
const type &checked_type() const noexcept { return checked_type_; }
|
||||
/** @brief Get the mutable checked type of the expression */
|
||||
type &checked_type() noexcept { return checked_type_; }
|
||||
/** @brief Set the checked type of the expression */
|
||||
void checked_type(type value) noexcept { checked_type_ = std::move(value); }
|
||||
|
||||
private:
|
||||
type checked_type_;
|
||||
};
|
||||
|
||||
using expr = object_t<expr_node>;
|
||||
|
||||
template <class T>
|
||||
concept Expr = Object<T> &&
|
||||
(concepts::same_as<expr_node, typename T::node_type> ||
|
||||
concepts::derived_from<typename T::node_type, expr_node>);
|
||||
} // namespace nncase::ir
|
|
@ -1,65 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "expr.h"
|
||||
#include "var.h"
|
||||
|
||||
namespace nncase::ir {
|
||||
/** @brief Function node */
|
||||
class NNCASE_API function_node : public expr_node {
|
||||
DEFINE_OBJECT_KIND(expr_node, object_function)
|
||||
public:
|
||||
function_node(std::string name, std::vector<var> parameters, expr body);
|
||||
|
||||
/** @brief Get the name of the function expression */
|
||||
const std::string &name() const noexcept { return name_; }
|
||||
/** @brief Get the mutable name of the function expression */
|
||||
std::string &name() noexcept { return name_; }
|
||||
|
||||
/** @brief Get the parameters of the function expression */
|
||||
std::span<const var> parameters() const noexcept { return parameters_; }
|
||||
|
||||
/** @brief Get the body of the function expression */
|
||||
const expr &body() const noexcept { return body_; }
|
||||
/** @brief Set the body of the function expression */
|
||||
void body(expr value) noexcept { body_ = std::move(value); }
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::vector<var> parameters_;
|
||||
expr body_;
|
||||
};
|
||||
|
||||
/** @brief Function expression */
|
||||
class function : public object_t<function_node> {
|
||||
public:
|
||||
using object_t::object_t;
|
||||
|
||||
/** @brief Construct a named function expression with auto-generated name
|
||||
* @param[in] name The name of the function
|
||||
* @param[in] parameters The parameters of the function
|
||||
* @param[in] body The body of the function
|
||||
*/
|
||||
NNCASE_API function(std::vector<var> parameters, expr body);
|
||||
|
||||
/** @brief Construct a named function expression
|
||||
* @param[in] name The name of the function
|
||||
* @param[in] parameters The parameters of the function
|
||||
* @param[in] body The body of the function
|
||||
*/
|
||||
NNCASE_API function(std::string name, std::vector<var> parameters,
|
||||
expr body);
|
||||
};
|
||||
} // namespace nncase::ir
|
|
@ -1,37 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../object.h"
|
||||
#include "call.h"
|
||||
#include "constant.h"
|
||||
|
||||
namespace nncase::ir::F {
|
||||
namespace detail {
|
||||
template <class T> concept Scalar = requires {
|
||||
nncase::detail::cpp_type_to_datatype<T>::type;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
class fexpr : public expr {
|
||||
public:
|
||||
template <Expr T> fexpr(T &&other) noexcept : expr(std::move(other)) {}
|
||||
template <Expr T> fexpr(const T &other) noexcept : expr(other) {}
|
||||
|
||||
template <detail::Scalar T>
|
||||
fexpr(T scalar)
|
||||
: expr(constant(tensor_type(to_datatype<T>()),
|
||||
std::span<const T>(&scalar, 1))) {}
|
||||
};
|
||||
} // namespace nncase::ir::F
|
|
@ -0,0 +1,86 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "node.h"
|
||||
#include "placeholders.h"
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class graph;
|
||||
|
||||
struct split_graph_result
|
||||
{
|
||||
std::unique_ptr<graph> subgraph;
|
||||
std::unordered_map<input_node *, output_connector *> inputs;
|
||||
std::unordered_map<output_node *, std::vector<input_connector *>> outputs;
|
||||
};
|
||||
|
||||
class NNCASE_API graph
|
||||
{
|
||||
public:
|
||||
graph() noexcept;
|
||||
explicit graph(const module_type_t &module_type) noexcept
|
||||
: module_type_(module_type) { }
|
||||
|
||||
graph(graph &) = delete;
|
||||
graph(graph &&) = delete;
|
||||
|
||||
const std::string &name() const noexcept { return name_; }
|
||||
std::string escaped_name() const noexcept;
|
||||
void name(std::string value) { name_ = std::move(value); }
|
||||
const module_type_t &module_type() const noexcept { return module_type_; }
|
||||
void set_module_type(module_type_t type) { this->module_type_ = type; }
|
||||
|
||||
std::span<std::unique_ptr<node>> nodes() noexcept { return nodes_; }
|
||||
std::span<input_node *> inputs() noexcept { return inputs_; }
|
||||
std::span<output_node *> outputs() noexcept { return outputs_; }
|
||||
std::span<std::unique_ptr<graph>> subgraphs() noexcept { return subgraphs_; }
|
||||
std::vector<graph *> reachable_graphs() noexcept;
|
||||
|
||||
std::span<std::unique_ptr<node> const> nodes() const noexcept { return nodes_; }
|
||||
std::span<input_node *const> inputs() const noexcept { return inputs_; }
|
||||
std::span<output_node *const> outputs() const noexcept { return outputs_; }
|
||||
std::span<std::unique_ptr<graph> const> subgraphs() const noexcept { return subgraphs_; }
|
||||
|
||||
template <class T, class... TArgs>
|
||||
T *emplace(TArgs &&...args)
|
||||
{
|
||||
auto node = static_cast<T *>(nodes_.emplace_back(new T(std::forward<TArgs>(args)...)).get());
|
||||
if constexpr (std::is_same_v<T, input_node>)
|
||||
inputs_.emplace_back(node);
|
||||
else if constexpr (std::is_same_v<T, output_node>)
|
||||
outputs_.emplace_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
void assign_names();
|
||||
void dce();
|
||||
void cse();
|
||||
void merge_module_regions();
|
||||
split_graph_result split_subgraph(std::span<node *const> nodes);
|
||||
graph &add_subgraph(std::unique_ptr<graph> subgraph);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
module_type_t module_type_;
|
||||
std::vector<std::unique_ptr<node>> nodes_;
|
||||
std::vector<std::unique_ptr<graph>> subgraphs_;
|
||||
std::vector<input_node *> inputs_;
|
||||
std::vector<output_node *> outputs_;
|
||||
};
|
||||
}
|
|
@ -13,18 +13,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "shape.h"
|
||||
#include "type.h"
|
||||
#include <nncase/runtime/datatypes.h>
|
||||
#include <span>
|
||||
#include <type_traits>
|
||||
#include <xtensor/xshape.hpp>
|
||||
|
||||
namespace nncase::ir {
|
||||
class op_node;
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
using shape_t = xt::dynamic_shape<std::size_t>;
|
||||
using axis_t = xt::dynamic_shape<int32_t>;
|
||||
|
||||
enum node_attributes {
|
||||
enum node_attributes
|
||||
{
|
||||
node_attr_none = 0,
|
||||
node_attr_action = 1,
|
||||
node_attr_need_quantize = 2,
|
||||
|
@ -33,44 +33,22 @@ enum node_attributes {
|
|||
node_attr_skip_constant_folding = 16
|
||||
};
|
||||
|
||||
enum connector_attributes {
|
||||
enum connector_attributes
|
||||
{
|
||||
cnctr_attr_none = 0,
|
||||
cnctr_attr_need_quantize = 1,
|
||||
cnctr_attr_no_layout_strides = 2,
|
||||
cnctr_attr_no_buffer_fusion = 4,
|
||||
cnctr_attr_no_dummy_for_benchmark = 8
|
||||
cnctr_attr_buffer_slice = 8,
|
||||
cnctr_attr_no_dummy_for_benchmark = 16
|
||||
};
|
||||
|
||||
DEFINE_ENUM_BITMASK_OPERATORS(node_attributes)
|
||||
DEFINE_ENUM_BITMASK_OPERATORS(connector_attributes)
|
||||
|
||||
class NNCASE_API connector_info {
|
||||
public:
|
||||
connector_info(op_node &owner, std::string name, size_t index)
|
||||
: owner_(owner), name_(std::move(name)), index_(index) {}
|
||||
|
||||
connector_info(const connector_info &) = delete;
|
||||
connector_info(connector_info &&) = default;
|
||||
connector_info &operator=(const connector_info &) = delete;
|
||||
|
||||
op_node &owner() const noexcept { return owner_; }
|
||||
const std::string &name() const noexcept { return name_; }
|
||||
size_t index() const noexcept { return index_; }
|
||||
|
||||
connector_attributes attributes() const noexcept { return attributes_; }
|
||||
void attributes(connector_attributes value) noexcept {
|
||||
attributes_ = value;
|
||||
}
|
||||
|
||||
private:
|
||||
op_node &owner_;
|
||||
std::string name_;
|
||||
size_t index_;
|
||||
connector_attributes attributes_ = cnctr_attr_none;
|
||||
};
|
||||
|
||||
template <class T, class = std::enable_if_t<std::is_pointer_v<T>>>
|
||||
std::vector<std::decay_t<T>> dup(std::span<T> source) {
|
||||
return {source.begin(), source.end()};
|
||||
std::vector<std::decay_t<T>> dup(std::span<T> source)
|
||||
{
|
||||
return { source.begin(), source.end() };
|
||||
}
|
||||
}
|
||||
} // namespace nncase::ir
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
|
||||
namespace nncase::ir::math {
|
||||
/** @brief Binary operator node */
|
||||
class NNCASE_API binary_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_math_binary)
|
||||
public:
|
||||
binary_node(binary_op_t binary_op);
|
||||
|
||||
/** @brief Get the binary opcode of the binary expression */
|
||||
binary_op_t binary_op() const noexcept { return binary_op_; }
|
||||
/** @brief Set the binary opcode of the binary expression */
|
||||
void binary_op(binary_op_t value) noexcept { binary_op_ = value; }
|
||||
|
||||
/** @brief Get the lhs the binary expression */
|
||||
const connector_info &lhs() const noexcept { return parameter_at(0); }
|
||||
/** @brief Get the rhs the binary expression */
|
||||
const connector_info &rhs() const noexcept { return parameter_at(1); }
|
||||
|
||||
type infer_invoke_result_type(type_infer_context &context) override;
|
||||
|
||||
private:
|
||||
binary_op_t binary_op_;
|
||||
};
|
||||
|
||||
/** @brief Binary expression */
|
||||
class binary : public object_t<binary_node> {
|
||||
public:
|
||||
using object_t::object_t;
|
||||
|
||||
/** @brief Construct an binary expression
|
||||
* @param[in] binary_op The opcode of the binary
|
||||
*/
|
||||
NNCASE_API binary(binary_op_t binary_op);
|
||||
};
|
||||
} // namespace nncase::ir::math
|
|
@ -1,38 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
|
||||
namespace nncase::ir::math {
|
||||
/** @brief Clamp operator node */
|
||||
class NNCASE_API clamp_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_math_clamp)
|
||||
public:
|
||||
clamp_node();
|
||||
|
||||
/** @brief Get the input the clamp expression */
|
||||
const connector_info &input() const noexcept { return parameter_at(0); }
|
||||
/** @brief Get the min the clamp expression */
|
||||
const connector_info &min() const noexcept { return parameter_at(1); }
|
||||
/** @brief Get the max the clamp expression */
|
||||
const connector_info &max() const noexcept { return parameter_at(2); }
|
||||
|
||||
type infer_invoke_result_type(type_infer_context &context) override;
|
||||
};
|
||||
|
||||
using clamp = object_t<clamp_node>;
|
||||
} // namespace nncase::ir::math
|
|
@ -1,94 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../call.h"
|
||||
#include "../functional.h"
|
||||
|
||||
namespace nncase::ir::F {
|
||||
NNCASE_API call unary(unary_op_t unary_op, fexpr input);
|
||||
NNCASE_API call binary(binary_op_t binary_op, fexpr lhs, fexpr rhs);
|
||||
NNCASE_API call clamp(fexpr input, fexpr min, fexpr max);
|
||||
|
||||
#define DEFINE_UNARY_FUNC(name, unary_op) \
|
||||
inline call name(fexpr input) { return F::unary(unary_op, input); }
|
||||
|
||||
DEFINE_UNARY_FUNC(abs, unary_abs)
|
||||
DEFINE_UNARY_FUNC(ceil, unary_ceil)
|
||||
DEFINE_UNARY_FUNC(cos, unary_cos)
|
||||
DEFINE_UNARY_FUNC(exp, unary_exp)
|
||||
DEFINE_UNARY_FUNC(floor, unary_floor)
|
||||
DEFINE_UNARY_FUNC(log, unary_log)
|
||||
DEFINE_UNARY_FUNC(neg, unary_neg)
|
||||
DEFINE_UNARY_FUNC(round, unary_round)
|
||||
DEFINE_UNARY_FUNC(rsqrt, unary_rsqrt)
|
||||
DEFINE_UNARY_FUNC(sin, unary_sin)
|
||||
DEFINE_UNARY_FUNC(sqrt, unary_sqrt)
|
||||
DEFINE_UNARY_FUNC(square, unary_square)
|
||||
DEFINE_UNARY_FUNC(tanh, unary_tanh)
|
||||
DEFINE_UNARY_FUNC(bitwise_not, unary_bitwise_not)
|
||||
DEFINE_UNARY_FUNC(logical_not, unary_logical_not)
|
||||
|
||||
#undef DEFINE_UNARY_FUNC
|
||||
|
||||
#define DEFINE_BINARY_FUNC(name, binary_op) \
|
||||
inline call name(fexpr lhs, expr rhs) { \
|
||||
return F::binary(binary_op, lhs, rhs); \
|
||||
}
|
||||
|
||||
DEFINE_BINARY_FUNC(add, binary_add)
|
||||
DEFINE_BINARY_FUNC(sub, binary_sub)
|
||||
DEFINE_BINARY_FUNC(mul, binary_mul)
|
||||
DEFINE_BINARY_FUNC(div, binary_div)
|
||||
DEFINE_BINARY_FUNC(mod, binary_mod)
|
||||
DEFINE_BINARY_FUNC(min, binary_min)
|
||||
DEFINE_BINARY_FUNC(max, binary_max)
|
||||
DEFINE_BINARY_FUNC(pow, binary_pow)
|
||||
DEFINE_BINARY_FUNC(bitwise_and, binary_bitwise_and)
|
||||
DEFINE_BINARY_FUNC(bitwise_or, binary_bitwise_or)
|
||||
DEFINE_BINARY_FUNC(bitwise_xor, binary_bitwise_xor)
|
||||
DEFINE_BINARY_FUNC(logical_and, binary_logical_and)
|
||||
DEFINE_BINARY_FUNC(logical_or, binary_logical_or)
|
||||
DEFINE_BINARY_FUNC(logical_xor, binary_logical_xor)
|
||||
|
||||
#undef DEFINE_BINARY_FUNC
|
||||
} // namespace nncase::ir::F
|
||||
|
||||
namespace nncase::ir {
|
||||
inline call operator-(expr input) { return F::neg(input); }
|
||||
inline call operator~(expr input) { return F::bitwise_not(input); }
|
||||
inline call operator!(expr input) { return F::logical_not(input); }
|
||||
|
||||
#define DEFINE_BINARY_OPERATOR(op, impl) \
|
||||
inline call operator op(expr lhs, expr rhs) { return F::impl(lhs, rhs); } \
|
||||
inline call operator op(expr lhs, F::fexpr rhs) { \
|
||||
return F::impl(lhs, rhs); \
|
||||
} \
|
||||
inline call operator op(F::fexpr lhs, expr rhs) { \
|
||||
return F::impl(lhs, rhs); \
|
||||
}
|
||||
|
||||
DEFINE_BINARY_OPERATOR(+, add)
|
||||
DEFINE_BINARY_OPERATOR(-, sub)
|
||||
DEFINE_BINARY_OPERATOR(*, mul)
|
||||
DEFINE_BINARY_OPERATOR(/, div)
|
||||
DEFINE_BINARY_OPERATOR(%, mod)
|
||||
DEFINE_BINARY_OPERATOR(&, bitwise_and)
|
||||
DEFINE_BINARY_OPERATOR(|, bitwise_or)
|
||||
DEFINE_BINARY_OPERATOR(^, bitwise_xor)
|
||||
DEFINE_BINARY_OPERATOR(&&, logical_and)
|
||||
DEFINE_BINARY_OPERATOR(||, logical_or)
|
||||
|
||||
#undef DEFINE_BINARY_OPERATOR
|
||||
} // namespace nncase::ir
|
|
@ -1,4 +0,0 @@
|
|||
DEFINE_OPCODE(math, unary, Unary, 0x1001)
|
||||
DEFINE_OPCODE(math, binary, Binary, 0x1002)
|
||||
DEFINE_OPCODE(math, compare, Compare, 0x1003)
|
||||
DEFINE_OPCODE(math, clamp, Clamp, 0x1004)
|
|
@ -1,25 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../../object.h"
|
||||
|
||||
namespace nncase::ir::math {
|
||||
#define DEFINE_OPCODE(dialect, id, name, value) \
|
||||
NNCASE_INLINE_VAR constexpr object_kind op_##dialect##_##id{value, #name};
|
||||
|
||||
#include "opcode.def"
|
||||
|
||||
#undef DEFINE_OPCODE
|
||||
} // namespace nncase::ir::math
|
|
@ -1,52 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../call.h"
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
|
||||
namespace nncase::ir::math {
|
||||
/** @brief Unary operator node */
|
||||
class NNCASE_API unary_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_math_unary)
|
||||
public:
|
||||
unary_node(unary_op_t unary_op);
|
||||
|
||||
/** @brief Get the unary opcode of the unary expression */
|
||||
unary_op_t unary_op() const noexcept { return unary_op_; }
|
||||
/** @brief Set the unary opcode of the unary expression */
|
||||
void unary_op(unary_op_t value) noexcept { unary_op_ = value; }
|
||||
|
||||
/** @brief Get the input the unary expression */
|
||||
const connector_info &input() const noexcept { return parameter_at(0); }
|
||||
|
||||
type infer_invoke_result_type(type_infer_context &context) override;
|
||||
|
||||
private:
|
||||
unary_op_t unary_op_;
|
||||
};
|
||||
|
||||
/** @brief Unary expression */
|
||||
class unary : public object_t<unary_node> {
|
||||
public:
|
||||
using object_t::object_t;
|
||||
|
||||
/** @brief Construct an unary expression
|
||||
* @param[in] unary_op The opcode of the unary
|
||||
*/
|
||||
NNCASE_API unary(unary_op_t unary_op);
|
||||
};
|
||||
} // namespace nncase::ir::math
|
|
@ -1,51 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../object.h"
|
||||
#include "function.h"
|
||||
|
||||
namespace nncase::ir {
|
||||
/** @brief Module node*/
|
||||
class NNCASE_API module_node : public object_node {
|
||||
DEFINE_OBJECT_KIND(object_node, object_module)
|
||||
public:
|
||||
module_node();
|
||||
|
||||
/** @brief Get the functions of the module */
|
||||
const std::vector<function> &functions() const noexcept {
|
||||
return functions_;
|
||||
}
|
||||
|
||||
/** @brief Add new funtion to the module */
|
||||
const function &add_function(function func);
|
||||
|
||||
/** @brief Get the entry of the module */
|
||||
const function &entry() const noexcept { return entry_; }
|
||||
/** @brief Get the mutable entry of module */
|
||||
function &entry() noexcept { return entry_; }
|
||||
/** @brief Set the entry of the module */
|
||||
void entry(function value) noexcept { entry_ = std::move(value); }
|
||||
|
||||
private:
|
||||
std::vector<function> functions_;
|
||||
function entry_;
|
||||
};
|
||||
|
||||
/** @brief Module */
|
||||
class NNCASE_API module_t : public object_t<module_node> {
|
||||
public:
|
||||
using object_t::object_t;
|
||||
};
|
||||
} // namespace nncase::ir
|
|
@ -1,31 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../call.h"
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
|
||||
namespace nncase::ir::nn {
|
||||
/** @brief Conv1D operator node */
|
||||
class NNCASE_API conv1d_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_nn_conv1d)
|
||||
public:
|
||||
// conv1d_node();
|
||||
};
|
||||
|
||||
/** @brief Conv1D expression */
|
||||
using conv1d = object_t<conv1d_node>;
|
||||
} // namespace nncase::ir::nn
|
|
@ -1,63 +0,0 @@
|
|||
DEFINE_OPCODE(nn, conv1d, Conv1D, 0x3001)
|
||||
DEFINE_OPCODE(nn, conv2d, Conv2D, 0x3002)
|
||||
DEFINE_OPCODE(nn, conv3d, Conv3D, 0x3003)
|
||||
DEFINE_OPCODE(nn, conv_transpose1d, ConvTranspose1D, 0x3004)
|
||||
DEFINE_OPCODE(nn, conv_transpose2d, ConvTranspose2D, 0x3005)
|
||||
DEFINE_OPCODE(nn, conv_transpose3d, ConvTranspose3D, 0x3006)
|
||||
DEFINE_OPCODE(nn, avg_pool1d, AvgPool1D, 0x3007)
|
||||
DEFINE_OPCODE(nn, avg_pool2d, AvgPool2D, 0x3008)
|
||||
DEFINE_OPCODE(nn, avg_pool3d, AvgPool3D, 0x3009)
|
||||
DEFINE_OPCODE(nn, max_pool1d, MaxPool1D, 0x300A)
|
||||
DEFINE_OPCODE(nn, max_pool2d, MaxPool2D, 0x300B)
|
||||
DEFINE_OPCODE(nn, max_pool3d, MaxPool3D, 0x300C)
|
||||
DEFINE_OPCODE(nn, l2_pool1d, L2Pool1D, 0x300D)
|
||||
DEFINE_OPCODE(nn, l2_pool2d, L2Pool2D, 0x300E)
|
||||
DEFINE_OPCODE(nn, l2_pool3d, L2Pool3D, 0x300F)
|
||||
DEFINE_OPCODE(nn, adaptive_avg_pool1d, AdaptiveAvgPool1D, 0x3010)
|
||||
DEFINE_OPCODE(nn, adaptive_avg_pool2d, AdaptiveAvgPool2D, 0x3011)
|
||||
DEFINE_OPCODE(nn, adaptive_avg_pool3d, AdaptiveAvgPool3D, 0x3012)
|
||||
DEFINE_OPCODE(nn, adaptive_max_pool1d, AdaptiveMaxPool1D, 0x3013)
|
||||
DEFINE_OPCODE(nn, adaptive_max_pool2d, AdaptiveMaxPool2D, 0x3014)
|
||||
DEFINE_OPCODE(nn, adaptive_max_pool3d, AdaptiveMaxPool3D, 0x3015)
|
||||
DEFINE_OPCODE(nn, relu, ReLU, 0x3016)
|
||||
DEFINE_OPCODE(nn, hardtanh, Hardtanh, 0x3017)
|
||||
DEFINE_OPCODE(nn, hardswish, HardSwish, 0x3018)
|
||||
DEFINE_OPCODE(nn, relu6, ReLU6, 0x3019)
|
||||
DEFINE_OPCODE(nn, elu, ELU, 0x301A)
|
||||
DEFINE_OPCODE(nn, selu, SELU, 0x301B)
|
||||
DEFINE_OPCODE(nn, celu, CELU, 0x301C)
|
||||
DEFINE_OPCODE(nn, leaky_relu, LeakyReLU, 0x301D)
|
||||
DEFINE_OPCODE(nn, prelu, PReLU, 0x301E)
|
||||
DEFINE_OPCODE(nn, rrelu, RReLU, 0x301F)
|
||||
DEFINE_OPCODE(nn, glu, GLU, 0x3020)
|
||||
DEFINE_OPCODE(nn, gelu, GELU, 0x3021)
|
||||
DEFINE_OPCODE(nn, logsigmoid, LogSigmoid, 0x3022)
|
||||
DEFINE_OPCODE(nn, hardshrink, Hardshrink, 0x3023)
|
||||
DEFINE_OPCODE(nn, tanhshrink, Tanhshrink, 0x3024)
|
||||
DEFINE_OPCODE(nn, softsign, SoftSign, 0x3025)
|
||||
DEFINE_OPCODE(nn, softplus, Softplus, 0x3026)
|
||||
DEFINE_OPCODE(nn, softmin, Softmin, 0x3027)
|
||||
DEFINE_OPCODE(nn, softmax, Softmax, 0x3028)
|
||||
DEFINE_OPCODE(nn, softshrink, Softshrink, 0x3029)
|
||||
DEFINE_OPCODE(nn, gumbel_softmax, GumbelSoftmax, 0x302A)
|
||||
DEFINE_OPCODE(nn, log_softmax, LogSoftmax, 0x302B)
|
||||
DEFINE_OPCODE(nn, sigmoid, Sigmoid, 0x302C)
|
||||
DEFINE_OPCODE(nn, hardsigmoid, Hardsigmoid, 0x302D)
|
||||
DEFINE_OPCODE(nn, silu, SiLU, 0x302E)
|
||||
DEFINE_OPCODE(nn, mish, Mish, 0x302F)
|
||||
DEFINE_OPCODE(nn, batch_norm, BatchNorm, 0x3030)
|
||||
DEFINE_OPCODE(nn, group_norm, GroupNorm, 0x3031)
|
||||
DEFINE_OPCODE(nn, instance_norm, InstanceNorm, 0x3032)
|
||||
DEFINE_OPCODE(nn, layer_norm, LayerNorm, 0x3033)
|
||||
DEFINE_OPCODE(nn, local_response_norm, LocalResponseNorm, 0x3034)
|
||||
DEFINE_OPCODE(nn, normalize, Normalize, 0x3035)
|
||||
DEFINE_OPCODE(nn, linear, Linear, 0x3036)
|
||||
DEFINE_OPCODE(nn, bilinear, Bilinear, 0x3037)
|
||||
DEFINE_OPCODE(nn, embedding, Embedding, 0x3038)
|
||||
DEFINE_OPCODE(nn, embedding_bag, EmbeddingBag, 0x3039)
|
||||
DEFINE_OPCODE(nn, one_hot, OneHot, 0x303A)
|
||||
DEFINE_OPCODE(nn, pixel_shuffle, PixelShuffle, 0x3040)
|
||||
DEFINE_OPCODE(nn, pad, Pad, 0x3041)
|
||||
DEFINE_OPCODE(nn, interpolate, Interpolate, 0x3042)
|
||||
DEFINE_OPCODE(nn, grid_sample, GridSample, 0x3043)
|
||||
DEFINE_OPCODE(nn, affine_grid, AffineGrid, 0x3044)
|
|
@ -1,25 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../../object.h"
|
||||
|
||||
namespace nncase::ir::nn {
|
||||
#define DEFINE_OPCODE(dialect, id, name, value) \
|
||||
NNCASE_INLINE_VAR constexpr object_kind op_##dialect##_##id{value, #name};
|
||||
|
||||
#include "opcode.def"
|
||||
|
||||
#undef DEFINE_OPCODE
|
||||
} // namespace nncase::ir::nn
|
|
@ -0,0 +1,92 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "connectors.h"
|
||||
#include "opcode.h"
|
||||
#include <list>
|
||||
#include <span>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
#define DEFINE_NODE_OPCODE(value) \
|
||||
static constexpr node_opcode opcode() noexcept { return value; } \
|
||||
const node_opcode &runtime_opcode() const noexcept override { return value; }
|
||||
|
||||
class NNCASE_API node
|
||||
{
|
||||
public:
|
||||
node(std::string name = "");
|
||||
node(node &) = delete;
|
||||
node &operator=(node &) = delete;
|
||||
virtual ~node();
|
||||
|
||||
const std::string &name() const noexcept { return name_; }
|
||||
template <class TArg, class... TArgs>
|
||||
void name(TArg arg, TArgs... args) { name_.assign(std::forward<TArg>(arg), std::forward<TArgs>(args)...); }
|
||||
std::string escaped_name() const noexcept;
|
||||
|
||||
const module_type_t &module_type() const noexcept { return module_type_; }
|
||||
void module_type(const module_type_t &type) noexcept { module_type_ = type; }
|
||||
|
||||
std::span<input_connector *const> inputs() const noexcept { return input_connectors_; }
|
||||
std::span<output_connector *const> outputs() const noexcept { return output_connectors_; }
|
||||
|
||||
input_connector &input_at(size_t index) const { return *input_connectors_.at(index); }
|
||||
output_connector &output_at(size_t index) const { return *output_connectors_.at(index); }
|
||||
|
||||
virtual const node_opcode &runtime_opcode() const noexcept = 0;
|
||||
node_attributes attributes() const noexcept { return attributes_; }
|
||||
void attributes(node_attributes value) noexcept { attributes_ = value; }
|
||||
|
||||
bool equals(node &other) const;
|
||||
|
||||
void record_output_connectors_quant_map(output_connector &oc_after_quant, output_connector &oc_before_quant) noexcept { output_connectors_quant_map_.emplace(&oc_after_quant, &oc_before_quant); }
|
||||
std::unordered_map<output_connector *, output_connector *> get_output_connectors_quant_map() const noexcept { return output_connectors_quant_map_; }
|
||||
|
||||
void record_node_name_before_quant(std::string name) noexcept { node_name_before_quant_.assign(name); }
|
||||
std::string get_node_name_before_quant() const noexcept { return node_name_before_quant_; }
|
||||
|
||||
protected:
|
||||
template <class TName, class TShape>
|
||||
input_connector &add_input(TName &&name, datatype_t type, TShape &&shape)
|
||||
{
|
||||
auto ptr = input_connectors_storage_.emplace_back(std::make_unique<input_connector>(*this, std::forward<TName>(name), type, std::forward<TShape>(shape))).get();
|
||||
input_connectors_.emplace_back(ptr);
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
template <class TName, class TShape>
|
||||
output_connector &add_output(TName &&name, datatype_t type, TShape &&shape, memory_location_t memory_location = mem_data)
|
||||
{
|
||||
auto ptr = output_connectors_storage_.emplace_back(std::make_unique<output_connector>(*this, std::forward<TName>(name), type, std::forward<TShape>(shape), memory_location)).get();
|
||||
output_connectors_.emplace_back(ptr);
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
virtual bool properties_equal(node &other) const = 0;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
module_type_t module_type_;
|
||||
node_attributes attributes_ = node_attributes::node_attr_action;
|
||||
std::vector<input_connector *> input_connectors_;
|
||||
std::vector<output_connector *> output_connectors_;
|
||||
std::vector<std::unique_ptr<input_connector>> input_connectors_storage_;
|
||||
std::vector<std::unique_ptr<output_connector>> output_connectors_storage_;
|
||||
std::unordered_map<output_connector *, output_connector *> output_connectors_quant_map_;
|
||||
std::string node_name_before_quant_;
|
||||
};
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "expr.h"
|
||||
#include "ir_types.h"
|
||||
#include "type_infer.h"
|
||||
#include <utility>
|
||||
|
||||
namespace nncase::ir {
|
||||
/** @brief Operator node */
|
||||
class NNCASE_API op_node : public expr_node {
|
||||
DEFINE_OBJECT_KIND(expr_node, object_op)
|
||||
public:
|
||||
/** @brief Get the parameters of the function expression */
|
||||
std::span<const connector_info> parameters() const noexcept {
|
||||
return parameters_;
|
||||
}
|
||||
|
||||
/** @brief Get the parameter at the index */
|
||||
const connector_info ¶meter_at(size_t index) const noexcept {
|
||||
return parameters_.at(index);
|
||||
}
|
||||
|
||||
/** @brief Infer the invoke result type */
|
||||
virtual type infer_invoke_result_type(type_infer_context &context) = 0;
|
||||
|
||||
protected:
|
||||
connector_info &add_parameter(std::string name);
|
||||
|
||||
private:
|
||||
std::vector<connector_info> parameters_;
|
||||
};
|
||||
|
||||
using op = object_t<op_node>;
|
||||
} // namespace nncase::ir
|
|
@ -27,13 +27,20 @@ inline shape_t get_transposed_shape(const shape_t &input_shape, const axis_t &pe
|
|||
return new_shape;
|
||||
}
|
||||
|
||||
inline size_t get_windowed_output_size(int32_t size, int32_t filter, int32_t stride, int32_t dilation, bool same)
|
||||
inline size_t get_windowed_output_size(int32_t size, int32_t filter, int32_t stride, int32_t dilation, bool same, bool ceil_mode = false)
|
||||
{
|
||||
auto effective_filter_size = (filter - 1) * dilation + 1;
|
||||
if (same)
|
||||
return (size_t(size) + stride - 1) / stride;
|
||||
else
|
||||
return (size_t(size) - effective_filter_size + stride) / stride;
|
||||
{
|
||||
if (!ceil_mode)
|
||||
return (size_t(size) - effective_filter_size + stride) / stride;
|
||||
else
|
||||
{
|
||||
return static_cast<int>(ceil(static_cast<float>(size_t(size) - effective_filter_size + stride) / stride));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline padding get_windowed_padding(int32_t input_size, int32_t output_size, int32_t filter, int32_t stride, int32_t dilation)
|
||||
|
@ -131,7 +138,7 @@ inline shape_t normalize_reshape(const shape_t &in_shape, const axis_t &new_shap
|
|||
if (v == -1)
|
||||
{
|
||||
if (non_det_id)
|
||||
throw std::runtime_error("Reshap can only have 1 non-determined dimension at most");
|
||||
throw std::runtime_error("Reshape can only have 1 non-determined dimension at most");
|
||||
non_det_id = i;
|
||||
}
|
||||
else
|
||||
|
@ -230,7 +237,7 @@ inline shape_t get_padded_shape(const shape_t &in_shape, const xt::svector<paddi
|
|||
{
|
||||
auto new_shape = in_shape;
|
||||
for (size_t i = 0; i < in_shape.size(); i++)
|
||||
new_shape[i] = size_t(int32_t(new_shape[i]) + paddings[i].sum());
|
||||
new_shape[i] = size_t(int32_t(new_shape[i]) + paddings[i].sum() + (new_shape[i] - 1) * paddings[i].interior);
|
||||
return new_shape;
|
||||
}
|
||||
|
||||
|
@ -293,6 +300,50 @@ inline shape_t get_strided_slice_output_shape(const axis_t &begin, const axis_t
|
|||
return new_shape.size() ? new_shape : shape_t { 1 };
|
||||
}
|
||||
|
||||
inline bool is_copy_slice(const axis_t &strides)
|
||||
{
|
||||
return std::all_of(strides.begin(), strides.end(), [](int32_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
inline bool is_simple_slice(const axis_t &begin, const axis_t &end, const axis_t &strides, const shape_t &input_shape)
|
||||
{
|
||||
if (!is_copy_slice(strides))
|
||||
return false;
|
||||
|
||||
bool is_simple_slice = true;
|
||||
bool allow_not_equal = true;
|
||||
for (size_t i = 0; i < begin.size(); i++)
|
||||
{
|
||||
if (begin[i] != 0
|
||||
|| end[i] != input_shape[i])
|
||||
{
|
||||
if (allow_not_equal)
|
||||
{
|
||||
allow_not_equal = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
is_simple_slice = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
else if (input_shape[i] != 1)
|
||||
{
|
||||
allow_not_equal = false;
|
||||
}
|
||||
}
|
||||
|
||||
return is_simple_slice;
|
||||
}
|
||||
|
||||
inline bool is_axis0_squeeze_or_expand_dim_bitcast(const shape_t &in_shape, const shape_t &out_shape)
|
||||
{
|
||||
auto in_begin = std::find_if_not(in_shape.begin(), in_shape.end(), [](size_t dim) { return dim == 1; });
|
||||
auto out_begin = std::find_if_not(out_shape.begin(), out_shape.end(), [](size_t dim) { return dim == 1; });
|
||||
return std::distance(in_begin, in_shape.end()) == std::distance(out_begin, out_shape.end())
|
||||
&& std::equal(in_begin, in_shape.end(), out_begin);
|
||||
}
|
||||
|
||||
template <class U, class T>
|
||||
std::span<U> as_span(const std::span<T> &src) noexcept
|
||||
{
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
DEFINE_NEUTRAL_OPCODE(input_node, Input, 0x01)
|
||||
DEFINE_NEUTRAL_OPCODE(output_node, Output, 0x02)
|
||||
DEFINE_NEUTRAL_OPCODE(ignore_node, Ignore, 0x03)
|
||||
DEFINE_NEUTRAL_OPCODE(constant, Constant, 0x04)
|
||||
DEFINE_NEUTRAL_OPCODE(uninitialized, Uninitialized, 0x05)
|
||||
DEFINE_NEUTRAL_OPCODE(call, Call, 0x06)
|
||||
DEFINE_NEUTRAL_OPCODE(copy, Copy, 0x07)
|
||||
|
||||
DEFINE_NEUTRAL_OPCODE(conv2d, Conv2D, 0x100)
|
||||
DEFINE_NEUTRAL_OPCODE(matmul, MatMul, 0x101)
|
||||
DEFINE_NEUTRAL_OPCODE(transpose, Transpose, 0x102)
|
||||
DEFINE_NEUTRAL_OPCODE(reduce, Reduce, 0x103)
|
||||
DEFINE_NEUTRAL_OPCODE(reduce_window2d, ReduceWindow2D, 0x104)
|
||||
DEFINE_NEUTRAL_OPCODE(binary, Binary, 0x105)
|
||||
DEFINE_NEUTRAL_OPCODE(concat, Concat, 0x106)
|
||||
DEFINE_NEUTRAL_OPCODE(unary, Unary, 0x107)
|
||||
DEFINE_NEUTRAL_OPCODE(fused_unary, FusedUnary, 0x108)
|
||||
DEFINE_NEUTRAL_OPCODE(quantize, Quantize, 0x109)
|
||||
DEFINE_NEUTRAL_OPCODE(dequantize, Dequantize, 0x10A)
|
||||
DEFINE_NEUTRAL_OPCODE(pad, Pad, 0x10B)
|
||||
DEFINE_NEUTRAL_OPCODE(bitcast, Bitcast, 0x10C)
|
||||
DEFINE_NEUTRAL_OPCODE(resize_image, ResizeImage, 0x10D)
|
||||
DEFINE_NEUTRAL_OPCODE(slice, Slice, 0x10E)
|
||||
DEFINE_NEUTRAL_OPCODE(table_lookup1d, TableLookup1D, 0x10F)
|
||||
DEFINE_NEUTRAL_OPCODE(conv2d_transpose, Conv2DTranspose, 0x110)
|
||||
DEFINE_NEUTRAL_OPCODE(clamp, Clamp, 0x111)
|
||||
DEFINE_NEUTRAL_OPCODE(convert, Convert, 0x112)
|
||||
DEFINE_NEUTRAL_OPCODE(broadcast, Broadcast, 0x113)
|
||||
DEFINE_NEUTRAL_OPCODE(take, Take, 0x114)
|
||||
DEFINE_NEUTRAL_OPCODE(space_to_batch, SpaceToBatch, 0x115)
|
||||
DEFINE_NEUTRAL_OPCODE(batch_to_space, BatchToSpace, 0x116)
|
||||
DEFINE_NEUTRAL_OPCODE(split, Split, 0x117)
|
||||
DEFINE_NEUTRAL_OPCODE(gather, Gather, 0x118)
|
||||
DEFINE_NEUTRAL_OPCODE(gather_nd, GatherND, 0x119)
|
||||
DEFINE_NEUTRAL_OPCODE(onehot, OneHot, 0x11A)
|
||||
DEFINE_NEUTRAL_OPCODE(lstm, LSTM, 0x11B)
|
||||
DEFINE_NEUTRAL_OPCODE(reduce_arg, ReduceArg, 0x11C)
|
||||
DEFINE_NEUTRAL_OPCODE(cumsum, CumSum, 0x11D)
|
||||
DEFINE_NEUTRAL_OPCODE(hardmax, HardMax, 0x11E)
|
||||
DEFINE_NEUTRAL_OPCODE(random_normal, RandomNormal, 0x11F)
|
||||
DEFINE_NEUTRAL_OPCODE(random_uniform, RandomUniform, 0x120)
|
||||
DEFINE_NEUTRAL_OPCODE(reduce_prod, ReduceProd, 0x121)
|
||||
DEFINE_NEUTRAL_OPCODE(ternary, Ternary, 0x122)
|
||||
DEFINE_NEUTRAL_OPCODE(topk, TopK, 0x123)
|
||||
DEFINE_NEUTRAL_OPCODE(trilu, Trilu, 0x124)
|
||||
DEFINE_NEUTRAL_OPCODE(sigmoid, Sigmoid, 0x125)
|
||||
DEFINE_NEUTRAL_OPCODE(roi_align, RoiAlign, 0x126)
|
||||
DEFINE_NEUTRAL_OPCODE(equal, Equal, 0x127)
|
|
@ -0,0 +1,51 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <compare>
|
||||
#include <nncase/runtime/datatypes.h>
|
||||
#include <stdexcept>
|
||||
#include <string_view>
|
||||
#include <type_traits>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
struct node_opcode
|
||||
{
|
||||
uint32_t id;
|
||||
std::string_view name;
|
||||
};
|
||||
|
||||
constexpr inline bool operator==(const node_opcode &lhs, const node_opcode &rhs) noexcept { return lhs.id == rhs.id; }
|
||||
|
||||
#define DEFINE_NEUTRAL_OPCODE(id, name, value) NNCASE_INLINE_VAR constexpr node_opcode op_##id { value, #name };
|
||||
#define DEFINE_OPCODE(target, id, name, value) NNCASE_INLINE_VAR constexpr node_opcode op_##target##_##id { value, #name };
|
||||
|
||||
#include "opcode.def"
|
||||
|
||||
#undef DEFINE_NEUTRAL_OPCODE
|
||||
#undef DEFINE_OPCODE
|
||||
}
|
||||
|
||||
namespace std
|
||||
{
|
||||
template <>
|
||||
struct hash<nncase::ir::node_opcode>
|
||||
{
|
||||
[[nodiscard]] size_t operator()(const nncase::ir::node_opcode &opcode) const noexcept
|
||||
{
|
||||
return std::hash<uint32_t>()(opcode.id);
|
||||
}
|
||||
};
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xarray.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API batch_to_space : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_batch_to_space);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t block_size_h() const noexcept { return block_size_h_; }
|
||||
int32_t block_size_w() const noexcept { return block_size_w_; }
|
||||
|
||||
const axis_t &begin() const noexcept { return begin_; }
|
||||
const axis_t &end() const noexcept { return end_; }
|
||||
const axis_t &strides() const noexcept { return strides_; }
|
||||
int32_t begin_mask() const noexcept { return begin_mask_; }
|
||||
int32_t end_mask() const noexcept { return end_mask_; }
|
||||
int32_t ellipsis_mask() const noexcept { return ellipsis_mask_; }
|
||||
int32_t new_axis_mask() const noexcept { return new_axis_mask_; }
|
||||
int32_t shrink_axis_mask() const noexcept { return shrink_axis_mask_; }
|
||||
std::array<int32_t, 2> crop_h() const noexcept { return crop_h_; }
|
||||
std::array<int32_t, 2> crop_w() const noexcept { return crop_w_; }
|
||||
|
||||
batch_to_space(datatype_t input_type, shape_t input_shape, int32_t block_shape_h, int32_t block_shape_w, axis_t stride, axis_t begin, axis_t end, std::array<int32_t, 2> crop_h_, std::array<int32_t, 2> crop_w_);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t block_size_h_;
|
||||
int32_t block_size_w_;
|
||||
axis_t begin_;
|
||||
axis_t end_;
|
||||
axis_t strides_;
|
||||
int32_t begin_mask_;
|
||||
int32_t end_mask_;
|
||||
int32_t ellipsis_mask_;
|
||||
int32_t new_axis_mask_;
|
||||
int32_t shrink_axis_mask_;
|
||||
std::array<int32_t, 2> crop_h_;
|
||||
std::array<int32_t, 2> crop_w_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API binary : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_binary);
|
||||
|
||||
input_connector &input_a() { return input_at(0); }
|
||||
input_connector &input_b() { return input_at(1); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
binary_op_t binary_op() const noexcept { return binary_op_; }
|
||||
value_range<float> fused_activation() const noexcept { return fused_activation_; }
|
||||
|
||||
binary(binary_op_t binary_op, shape_t input_a_shape, shape_t input_b_shape, value_range<float> input_fused_activation);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
binary_op_t binary_op_;
|
||||
value_range<float> fused_activation_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API bitcast : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_bitcast);
|
||||
|
||||
const input_connector &input() const { return input_at(0); }
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
bool is_reshape() const noexcept { return new_type_ == input().type(); }
|
||||
datatype_t new_type() const noexcept { return new_type_; }
|
||||
const shape_t &new_shape() const noexcept { return new_shape_; }
|
||||
|
||||
bitcast(datatype_t input_type, shape_t input_shape, axis_t new_shape);
|
||||
bitcast(datatype_t input_type, shape_t input_shape, shape_t new_shape);
|
||||
bitcast(datatype_t input_type, shape_t input_shape, datatype_t new_type, axis_t new_shape);
|
||||
bitcast(datatype_t input_type, shape_t input_shape, datatype_t new_type, shape_t new_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
datatype_t new_type_;
|
||||
shape_t new_shape_;
|
||||
};
|
||||
}
|
|
@ -13,17 +13,28 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir::tensors {
|
||||
/** @brief Broadcast operator node */
|
||||
class NNCASE_API broadcast_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_tensors_broadcast)
|
||||
public:
|
||||
broadcast_node();
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API broadcast : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_broadcast);
|
||||
|
||||
const input_connector &input() const { return input_at(0); }
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
const shape_t &new_shape() const noexcept { return new_shape_; }
|
||||
|
||||
broadcast(datatype_t input_type, shape_t input_shape, shape_t new_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
shape_t new_shape_;
|
||||
};
|
||||
|
||||
using broadcast = object_t<broadcast_node>;
|
||||
} // namespace nncase::ir::tensors
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../graph.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API call : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_call);
|
||||
|
||||
graph &target() const noexcept { return target_; }
|
||||
|
||||
call(graph &target);
|
||||
|
||||
input_connector &outer_connector(input_node &target_input);
|
||||
input_connector &outer_connector(input_connector &target_input);
|
||||
output_connector &outer_connector(output_node &target_output);
|
||||
output_connector &outer_connector(output_connector &target_output);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
graph &target_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API clamp : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_clamp);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &input_low() { return input_at(1); }
|
||||
const input_connector &input_low() const { return input_at(1); }
|
||||
input_connector &input_high() { return input_at(2); }
|
||||
const input_connector &input_high() const { return input_at(2); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
const output_connector &output() const { return output_at(0); }
|
||||
|
||||
clamp(shape_t input_shape, shape_t input_low_shape, shape_t input_high_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
|
||||
};
|
||||
}
|
|
@ -13,25 +13,28 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir::tensors {
|
||||
/** @brief Split operator node */
|
||||
class NNCASE_API split_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_tensors_split)
|
||||
public:
|
||||
split_node(int32_t axis);
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API concat : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_concat);
|
||||
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
/** @brief Get the axis of the concat expression */
|
||||
int32_t axis() const noexcept { return axis_; }
|
||||
/** @brief Set the axis of the concat expression */
|
||||
void axis(int32_t value) noexcept { axis_ = value; }
|
||||
std::span<const size_t> concat_dims() const noexcept { return concat_dims_; }
|
||||
|
||||
private:
|
||||
concat(datatype_t type, std::span<shape_t> input_shapes, int32_t axis);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t axis_;
|
||||
std::vector<size_t> concat_dims_;
|
||||
};
|
||||
|
||||
using split = object_t<split_node>;
|
||||
} // namespace nncase::ir::tensors
|
||||
}
|
|
@ -0,0 +1,141 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../debug.h"
|
||||
#include "../node.h"
|
||||
#include "../op_utils.h"
|
||||
#include <nncase/runtime/debug.h>
|
||||
#include <vector>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API constant : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_constant);
|
||||
|
||||
output_connector &output() { return output_at(0); }
|
||||
const output_connector &output() const { return output_at(0); }
|
||||
|
||||
size_t alignment() const noexcept { return alignment_; }
|
||||
void alignment(size_t value) { alignment_ = value; }
|
||||
|
||||
std::span<const std::byte> data() const noexcept { return data_; }
|
||||
datatype_t data_type() { return datatype_; }
|
||||
|
||||
template <class TShape>
|
||||
constant(datatype_t type, TShape &&shape, std::span<const std::byte> data)
|
||||
: constant(type, std::forward<TShape>(shape), data.begin(), data.end())
|
||||
{
|
||||
}
|
||||
|
||||
template <class TShape>
|
||||
constant(datatype_t type, TShape &&shape, gsl::span<const gsl::byte> data)
|
||||
: constant(type, std::forward<TShape>(shape), reinterpret_cast<const std::byte *>(data.begin()), reinterpret_cast<const std::byte *>(data.end()))
|
||||
{
|
||||
}
|
||||
|
||||
template <class TShape, class T>
|
||||
constant(datatype_t type, TShape &&shape, std::span<const T> data)
|
||||
: constant(type, std::forward<TShape>(shape), std::as_bytes(data))
|
||||
{
|
||||
}
|
||||
|
||||
template <class TShape, class T>
|
||||
constant(datatype_t type, TShape &&shape, gsl::span<const T> data)
|
||||
: constant(type, std::forward<TShape>(shape), gsl::as_bytes(data))
|
||||
{
|
||||
}
|
||||
|
||||
template <class TShape, class T>
|
||||
constant(datatype_t type, TShape &&shape, std::span<T> data)
|
||||
: constant(type, std::forward<TShape>(shape), std::as_bytes(data))
|
||||
{
|
||||
}
|
||||
|
||||
template <class TShape, class T>
|
||||
constant(datatype_t type, TShape &&shape, gsl::span<T> data)
|
||||
: constant(type, std::forward<TShape>(shape), gsl::as_bytes(data))
|
||||
{
|
||||
}
|
||||
|
||||
template <class TShape, class T>
|
||||
constant(datatype_t type, TShape &&shape, const std::vector<T> &data)
|
||||
: constant(type, std::forward<TShape>(shape), std::as_bytes(std::span<const T>(data)))
|
||||
{
|
||||
}
|
||||
|
||||
template <class TShape, class... TDataArgs>
|
||||
constant(datatype_t type, TShape &&shape, TDataArgs... data_args)
|
||||
: data_(std::forward<TDataArgs>(data_args)...), datatype_(type)
|
||||
{
|
||||
if (ir::get_bytes(type, shape) != data_.size())
|
||||
throw std::invalid_argument("Shape and data size don't match");
|
||||
add_output("output", type, std::forward<TShape>(shape), mem_rdata)
|
||||
.attributes(cnctr_attr_no_layout_strides);
|
||||
}
|
||||
|
||||
template <class TScalar>
|
||||
constant(TScalar scalar)
|
||||
: constant(to_datatype<TScalar>(), shape_t { 1 }, std::span<const TScalar>(&scalar, 1))
|
||||
{
|
||||
}
|
||||
|
||||
std::string to_string() const
|
||||
{
|
||||
auto shape = this->output().shape();
|
||||
auto dtype = this->output().type();
|
||||
auto total_size = 1;
|
||||
for (auto i : shape)
|
||||
{
|
||||
total_size *= i;
|
||||
}
|
||||
|
||||
if (total_size == 1)
|
||||
{
|
||||
switch (dtype)
|
||||
{
|
||||
case dt_int8:
|
||||
return std::to_string(*(to_cpp_type_t<dt_int8> *)data_.data());
|
||||
case dt_uint8:
|
||||
return std::to_string(*(to_cpp_type_t<dt_uint8> *)data_.data());
|
||||
#define DT_TO_STRING_CASE(dt) \
|
||||
case dt: \
|
||||
return std::to_string(*(to_cpp_type_t<dt> *)data_.data());
|
||||
|
||||
DT_TO_STRING_CASE(dt_uint32);
|
||||
DT_TO_STRING_CASE(dt_float32);
|
||||
DT_TO_STRING_CASE(dt_bfloat16);
|
||||
DT_TO_STRING_CASE(dt_int32);
|
||||
#undef DT_TO_STRING_CASE
|
||||
default:
|
||||
throw "un supported dtype to_string: " + std::string(nncase::datatype_names(dtype));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return "[...]";
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::byte> data_;
|
||||
datatype_t datatype_;
|
||||
size_t alignment_ = 8;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xarray.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API conv2d : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_conv2d);
|
||||
|
||||
const input_connector &weights() const { return input_at(1); }
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &weights() { return input_at(1); }
|
||||
input_connector &bias() { return input_at(2); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t filter_h() const noexcept { return (int32_t)weights().shape()[2]; }
|
||||
int32_t filter_w() const noexcept { return (int32_t)weights().shape()[3]; }
|
||||
int32_t input_channels() const noexcept { return (int32_t)weights().shape()[1] * groups(); }
|
||||
int32_t output_channels() const noexcept { return (int32_t)weights().shape()[0]; }
|
||||
int32_t groups() const noexcept { return groups_; }
|
||||
bool is_depthwise() const noexcept { return input_channels() == output_channels() && output_channels() == groups() && groups() != 1; }
|
||||
padding padding_h() const noexcept { return padding_h_; }
|
||||
padding padding_w() const noexcept { return padding_w_; }
|
||||
int32_t stride_h() const noexcept { return stride_h_; }
|
||||
int32_t stride_w() const noexcept { return stride_w_; }
|
||||
int32_t dilation_h() const noexcept { return dilation_h_; }
|
||||
int32_t dilation_w() const noexcept { return dilation_w_; }
|
||||
value_range<float> fused_activation() const noexcept { return fused_activation_; }
|
||||
|
||||
conv2d(shape_t input_shape, shape_t weights_shape, int32_t groups, padding padding_h, padding padding_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w, value_range<float> fused_activation);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t groups_;
|
||||
padding padding_h_;
|
||||
padding padding_w_;
|
||||
int32_t stride_h_;
|
||||
int32_t stride_w_;
|
||||
int32_t dilation_h_;
|
||||
int32_t dilation_w_;
|
||||
value_range<float> fused_activation_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API conv2d_transpose : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_conv2d_transpose);
|
||||
|
||||
const input_connector &weights() const { return input_at(1); }
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &weights() { return input_at(1); }
|
||||
input_connector &bias() { return input_at(2); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t filter_h() const noexcept { return (int32_t)weights().shape()[2]; }
|
||||
int32_t filter_w() const noexcept { return (int32_t)weights().shape()[3]; }
|
||||
int32_t input_channels() const noexcept { return (int32_t)weights().shape()[1] * groups(); }
|
||||
int32_t output_channels() const noexcept { return (int32_t)weights().shape()[0]; }
|
||||
int32_t groups() const noexcept { return groups_; }
|
||||
padding padding_h() const noexcept { return padding_h_; }
|
||||
padding padding_w() const noexcept { return padding_w_; }
|
||||
int32_t output_padding_h() const noexcept { return output_padding_h_; }
|
||||
int32_t output_padding_w() const noexcept { return output_padding_w_; }
|
||||
int32_t stride_h() const noexcept { return stride_h_; }
|
||||
int32_t stride_w() const noexcept { return stride_w_; }
|
||||
int32_t dilation_h() const noexcept { return dilation_h_; }
|
||||
int32_t dilation_w() const noexcept { return dilation_w_; }
|
||||
value_range<float> fused_activation() const noexcept { return fused_activation_; }
|
||||
|
||||
conv2d_transpose(shape_t input_shape, shape_t weights_shape, shape_t output_shape, int32_t groups, padding padding_h, padding padding_w, int32_t output_padding_h, int32_t output_padding_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w, value_range<float> fused_activation);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t groups_;
|
||||
padding padding_h_;
|
||||
padding padding_w_;
|
||||
int32_t output_padding_h_;
|
||||
int32_t output_padding_w_;
|
||||
int32_t stride_h_;
|
||||
int32_t stride_w_;
|
||||
int32_t dilation_h_;
|
||||
int32_t dilation_w_;
|
||||
value_range<float> fused_activation_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API convert : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_convert);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
datatype_t new_type() const noexcept { return new_type_; }
|
||||
|
||||
convert(datatype_t input_type, shape_t input_shape, datatype_t new_type);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
datatype_t new_type_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2019-2020 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API copy : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_copy);
|
||||
|
||||
const input_connector &input() const { return input_at(0); }
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
copy(datatype_t input_type, shape_t input_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API cumsum : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_cumsum);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t axis() const noexcept { return axis_; }
|
||||
bool exclusive() const noexcept { return exclusive_; }
|
||||
bool reverse() const noexcept { return reverse_; }
|
||||
|
||||
cumsum(datatype_t input_type, shape_t input_shape, int32_t axis, bool exclusive = false, bool reverse = false);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t axis_;
|
||||
bool exclusive_;
|
||||
bool reverse_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API dequantize : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_dequantize);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
const quant_param_t quant_param() const noexcept { return quant_param_; }
|
||||
|
||||
dequantize(datatype_t input_type, shape_t input_shape, datatype_t output_type, quant_param_t quant_param);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
quant_param_t quant_param_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API equal : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_equal);
|
||||
|
||||
input_connector &input_a() { return input_at(0); }
|
||||
input_connector &input_b() { return input_at(1); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
equal(datatype_t input_type, shape_t input_a_shape, shape_t input_b_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
};
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../graph.h"
|
||||
#include "../node.h"
|
||||
#include <nncase/codegen/nnil_builder.h>
|
||||
#include <nncase/runtime/nnil.h>
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
enum fused_unary_opcode
|
||||
{
|
||||
fu_constant,
|
||||
fu_identity,
|
||||
fu_ldx,
|
||||
fu_unary,
|
||||
fu_binary,
|
||||
fu_clamp
|
||||
};
|
||||
|
||||
struct fused_unary_arg
|
||||
{
|
||||
size_t op_id;
|
||||
};
|
||||
|
||||
struct fused_unary_constant
|
||||
{
|
||||
float value;
|
||||
};
|
||||
|
||||
struct fused_unary_identity
|
||||
{
|
||||
fused_unary_arg input;
|
||||
};
|
||||
|
||||
struct fused_unary_ldx
|
||||
{
|
||||
};
|
||||
|
||||
struct fused_unary_unary
|
||||
{
|
||||
unary_op_t unary_op;
|
||||
fused_unary_arg input;
|
||||
};
|
||||
|
||||
struct fused_unary_binary
|
||||
{
|
||||
binary_op_t binary_op;
|
||||
fused_unary_arg input_a;
|
||||
fused_unary_arg input_b;
|
||||
};
|
||||
|
||||
struct fused_unary_clamp
|
||||
{
|
||||
fused_unary_arg input;
|
||||
fused_unary_arg low;
|
||||
fused_unary_arg high;
|
||||
};
|
||||
|
||||
struct fused_unary_op
|
||||
{
|
||||
fused_unary_opcode opcode;
|
||||
|
||||
union
|
||||
{
|
||||
fused_unary_constant constant;
|
||||
fused_unary_identity identity;
|
||||
fused_unary_ldx ldx;
|
||||
fused_unary_unary unary;
|
||||
fused_unary_binary binary;
|
||||
fused_unary_clamp clamp;
|
||||
};
|
||||
|
||||
static fused_unary_op make_ldx() noexcept
|
||||
{
|
||||
fused_unary_op op { fu_ldx, {} };
|
||||
return op;
|
||||
}
|
||||
|
||||
static fused_unary_op make_constant(float value) noexcept
|
||||
{
|
||||
fused_unary_op op { fu_constant, {} };
|
||||
op.constant.value = value;
|
||||
return op;
|
||||
}
|
||||
|
||||
static fused_unary_op make_unary(unary_op_t unary_op, fused_unary_arg input) noexcept
|
||||
{
|
||||
fused_unary_op op { fu_unary, {} };
|
||||
op.unary = { unary_op, input };
|
||||
return op;
|
||||
}
|
||||
|
||||
static fused_unary_op make_binary(binary_op_t binary_op, fused_unary_arg input_a, fused_unary_arg input_b) noexcept
|
||||
{
|
||||
fused_unary_op op { fu_binary, {} };
|
||||
op.binary = { binary_op, input_a, input_b };
|
||||
return op;
|
||||
}
|
||||
|
||||
static fused_unary_op make_clamp(fused_unary_arg input, fused_unary_arg low, fused_unary_arg high) noexcept
|
||||
{
|
||||
fused_unary_op op { fu_clamp, {} };
|
||||
op.clamp = { input, low, high };
|
||||
return op;
|
||||
}
|
||||
|
||||
static fused_unary_op make_identity(fused_unary_arg input) noexcept
|
||||
{
|
||||
fused_unary_op op { fu_identity, {} };
|
||||
op.identity = { input };
|
||||
return op;
|
||||
}
|
||||
};
|
||||
|
||||
NNCASE_API std::vector<fused_unary_op> concat_subgraph(const std::vector<fused_unary_op> &src1, const std::vector<fused_unary_op> &src2);
|
||||
|
||||
class NNCASE_API fused_unary : public node
|
||||
{
|
||||
public:
|
||||
static void compile_graph(const std::vector<fused_unary_op> &subgraph, codegen::nnil_builder &builder);
|
||||
|
||||
DEFINE_NODE_OPCODE(op_fused_unary);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
std::vector<fused_unary_op> &subgraph() noexcept { return subgraph_; }
|
||||
|
||||
fused_unary(std::vector<fused_unary_op> subgraph, shape_t in_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
std::vector<fused_unary_op> subgraph_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API gather : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_gather);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &indices() { return input_at(1); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t axis() const noexcept { return axis_; }
|
||||
|
||||
gather(datatype_t in_type, shape_t input_shape, shape_t indices_shape, shape_t output_shape, int32_t axis);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t axis_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API gather_nd : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_gather_nd);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &indices() { return input_at(1); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t batch_dims() const noexcept { return batch_dims_; }
|
||||
|
||||
gather_nd(datatype_t type, shape_t input_shape, shape_t indices_shape, shape_t output_shape, int32_t batch_dims);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t batch_dims_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API hardmax : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_hardmax);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t axis() const noexcept { return axis_; }
|
||||
hardmax(datatype_t input_type, shape_t input_shape, int32_t axis);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t axis_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API lstm : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_lstm);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &w_xc() { return input_at(1); }
|
||||
input_connector &b_xc() { return input_at(2); }
|
||||
input_connector &w_rc() { return input_at(3); }
|
||||
input_connector &b_rc() { return input_at(4); }
|
||||
input_connector &initial_h() { return input_at(5); }
|
||||
input_connector &initial_c() { return input_at(6); }
|
||||
input_connector &w_static() { return input_at(7); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
output_connector &output_h() { return output_at(1); }
|
||||
output_connector &output_c() { return output_at(2); }
|
||||
|
||||
int32_t num_output() const noexcept { return num_output_; }
|
||||
bool has_static() const noexcept { return has_static_; }
|
||||
std::string framework() const noexcept { return framework_; }
|
||||
|
||||
lstm(shape_t input_shape, shape_t w_xc_shape, shape_t b_xc_shape, shape_t w_rc_shape, shape_t b_rc_shape, shape_t initial_h_shape, shape_t initial_c_shape, int32_t num_output, bool has_static, std::string framework);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t num_output_;
|
||||
bool has_static_;
|
||||
std::string framework_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API matmul : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_matmul);
|
||||
|
||||
input_connector &input_a() { return input_at(0); }
|
||||
input_connector &input_b() { return input_at(1); }
|
||||
input_connector &bias() { return input_at(2); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
value_range<float> fused_activation() const noexcept { return fused_activation_; }
|
||||
|
||||
matmul(shape_t input_a_shape, shape_t input_b_shape, value_range<float> fused_activation);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
value_range<float> fused_activation_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API onehot : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_onehot);
|
||||
|
||||
input_connector &indices() { return input_at(0); }
|
||||
input_connector &depth() { return input_at(1); }
|
||||
input_connector &on_value() { return input_at(2); }
|
||||
input_connector &off_value() { return input_at(3); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t axis() const noexcept { return axis_; }
|
||||
onehot_mode_t mode() const noexcept { return mode_; }
|
||||
|
||||
onehot(datatype_t type, shape_t indices_shape, shape_t output_shape, int32_t axis, onehot_mode_t mode = onehot_mode_t::onehot_normal);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t axis_;
|
||||
onehot_mode_t mode_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API pad : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_pad);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
const xt::svector<padding> &paddings() const noexcept { return paddings_; }
|
||||
pad_mode_t pad_mode() const noexcept { return pad_mode_; }
|
||||
const scalar &pad_value() const noexcept { return pad_value_; }
|
||||
|
||||
pad(datatype_t type, shape_t input_shape, xt::svector<padding> paddings, pad_mode_t pad_mode, scalar pad_value);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
xt::svector<padding> paddings_;
|
||||
pad_mode_t pad_mode_;
|
||||
scalar pad_value_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API quantize : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_quantize);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
const quant_param_t quant_param() const noexcept { return quant_param_; }
|
||||
|
||||
quantize(datatype_t input_type, shape_t input_shape, datatype_t output_type, quant_param_t quant_param);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
quant_param_t quant_param_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API random_normal : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_random_normal);
|
||||
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
float mean() const noexcept { return mean_; }
|
||||
float std() const noexcept { return std_; }
|
||||
float seed() const noexcept { return seed_; }
|
||||
random_normal(datatype_t output_type, shape_t output_shape, float mean, float std, float seed);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
float mean_;
|
||||
float std_;
|
||||
float seed_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API random_uniform : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_random_uniform);
|
||||
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
float low() const noexcept { return low_; }
|
||||
float high() const noexcept { return high_; }
|
||||
float seed() const noexcept { return seed_; }
|
||||
random_uniform(datatype_t output_type, shape_t output_shape, float low, float high, float seed);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
float low_;
|
||||
float high_;
|
||||
float seed_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API reduce : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_reduce);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
reduce_op_t reduce_op() const noexcept { return reduce_op_; }
|
||||
const axis_t &axis() const noexcept { return axis_; }
|
||||
float init_value() const noexcept { return init_value_; }
|
||||
bool keep_dims() const noexcept { return keep_dims_; }
|
||||
|
||||
reduce(reduce_op_t reduce_op, shape_t input_shape, axis_t axis, float init_value, bool keep_dims);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
reduce_op_t reduce_op_;
|
||||
axis_t axis_;
|
||||
float init_value_;
|
||||
bool keep_dims_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API reduce_arg : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_reduce_arg);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
reduce_arg_op_t reduce_arg_op() const noexcept { return reduce_arg_op_; }
|
||||
int32_t axis() const noexcept { return axis_; }
|
||||
bool keep_dims() const noexcept { return keep_dims_; }
|
||||
bool select_last_index() const noexcept { return select_last_index_; }
|
||||
|
||||
reduce_arg(reduce_arg_op_t op, datatype_t input_type, shape_t input_shape, datatype_t output_type, int32_t axis, bool keep_dims = true, bool select_last_index = false);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
reduce_arg_op_t reduce_arg_op_;
|
||||
int32_t axis_;
|
||||
bool keep_dims_;
|
||||
bool select_last_index_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API reduce_prod : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_reduce_prod);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
const axis_t &axis() const noexcept { return axis_; }
|
||||
bool keep_dims() const noexcept { return keep_dims_; }
|
||||
|
||||
reduce_prod(shape_t input_shape, axis_t axis, bool keep_dims);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
axis_t axis_;
|
||||
bool keep_dims_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API reduce_window2d : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_reduce_window2d);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
reduce_op_t reduce_op() const noexcept { return reduce_op_; }
|
||||
float init_value() const noexcept { return init_value_; }
|
||||
int32_t filter_h() const noexcept { return filter_h_; }
|
||||
int32_t filter_w() const noexcept { return filter_w_; }
|
||||
padding padding_h() const noexcept { return padding_h_; }
|
||||
padding padding_w() const noexcept { return padding_w_; }
|
||||
int32_t stride_h() const noexcept { return stride_h_; }
|
||||
int32_t stride_w() const noexcept { return stride_w_; }
|
||||
int32_t dilation_h() const noexcept { return dilation_h_; }
|
||||
int32_t dilation_w() const noexcept { return dilation_w_; }
|
||||
value_range<float> fused_activation() const noexcept { return fused_activation_; }
|
||||
bool ceil_mode() const noexcept { return ceil_mode_; }
|
||||
bool count_include_pad() const noexcept { return count_include_pad_; }
|
||||
std::vector<int32_t> padding_h_w_after() const noexcept { return padding_h_w_after_; }
|
||||
bool strict_inside_input() const noexcept { return strict_inside_input_; }
|
||||
|
||||
reduce_window2d(reduce_op_t reduce_op, shape_t input_shape, float init_value, int32_t filter_h, int32_t filter_w, padding padding_h, padding padding_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w, value_range<float> fused_activation, bool ceil_mode = false, bool count_include_pad = false, std::vector<int32_t> padding_h_w_after = { 0, 0 }, bool strict_inside_input = false);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
reduce_op_t reduce_op_;
|
||||
float init_value_;
|
||||
int32_t filter_h_;
|
||||
int32_t filter_w_;
|
||||
padding padding_h_;
|
||||
padding padding_w_;
|
||||
int32_t stride_h_;
|
||||
int32_t stride_w_;
|
||||
int32_t dilation_h_;
|
||||
int32_t dilation_w_;
|
||||
value_range<float> fused_activation_;
|
||||
bool ceil_mode_;
|
||||
bool count_include_pad_;
|
||||
std::vector<int32_t> padding_h_w_after_;
|
||||
bool strict_inside_input_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API resize_image : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_resize_image);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
const std::array<int32_t, 2> &new_size() const noexcept { return new_size_; }
|
||||
image_resize_mode_t mode() const noexcept { return mode_; }
|
||||
bool align_corners() const noexcept { return align_corners_; }
|
||||
bool half_pixel_centers() const noexcept { return half_pixel_centers_; }
|
||||
resize_image(datatype_t type, image_resize_mode_t mode, shape_t input_shape, std::array<int32_t, 2> new_size,
|
||||
bool align_corners = false, bool half_pixel_centers = false);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
std::array<int32_t, 2> new_size_;
|
||||
image_resize_mode_t mode_;
|
||||
bool align_corners_;
|
||||
bool half_pixel_centers_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API roi_align : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_roi_align);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &rois() { return input_at(1); }
|
||||
input_connector &batch_indices() { return input_at(2); }
|
||||
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
roi_align_mode_t mode() const noexcept { return mode_; }
|
||||
const float &spatial_scale() const noexcept { return spatial_scale_; }
|
||||
const int64_t &sampling_ratio() const noexcept { return sampling_ratio_; }
|
||||
|
||||
roi_align(datatype_t input_type, shape_t input_shape, shape_t rois, shape_t batch_indices, roi_align_mode_t mode,
|
||||
float spatial_scale, int64_t output_height, int64_t output_width, int64_t sampling_ratio);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
roi_align_mode_t mode_;
|
||||
float spatial_scale_;
|
||||
int64_t sampling_ratio_;
|
||||
};
|
||||
}
|
|
@ -13,18 +13,23 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "egraph.h"
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir::transforms {
|
||||
class NNCASE_API egraph_pattern {
|
||||
public:
|
||||
virtual bool match() = 0;
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API sigmoid : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_sigmoid);
|
||||
|
||||
private:
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
sigmoid(datatype_t input_type, shape_t input_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
namespace patterns {
|
||||
class NNCASE_API wildcard : public egraph_pattern {};
|
||||
} // namespace patterns
|
||||
|
||||
} // namespace nncase::ir::transforms
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API slice : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_slice);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
const axis_t &begin() const noexcept { return begin_; }
|
||||
const axis_t &end() const noexcept { return end_; }
|
||||
const axis_t &strides() const noexcept { return strides_; }
|
||||
int32_t begin_mask() const noexcept { return begin_mask_; }
|
||||
int32_t end_mask() const noexcept { return end_mask_; }
|
||||
int32_t ellipsis_mask() const noexcept { return ellipsis_mask_; }
|
||||
int32_t new_axis_mask() const noexcept { return new_axis_mask_; }
|
||||
|
||||
slice(datatype_t type, shape_t input_shape, axis_t begin, axis_t end);
|
||||
slice(datatype_t type, shape_t input_shape, axis_t begin, axis_t end, axis_t strides, int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, int32_t new_axis_mask);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
axis_t begin_;
|
||||
axis_t end_;
|
||||
axis_t strides_;
|
||||
int32_t begin_mask_;
|
||||
int32_t end_mask_;
|
||||
int32_t ellipsis_mask_;
|
||||
int32_t new_axis_mask_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xarray.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API space_to_batch : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_space_to_batch);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t block_size_h() const noexcept { return block_size_h_; }
|
||||
int32_t block_size_w() const noexcept { return block_size_w_; }
|
||||
padding padding_h() const noexcept { return padding_h_; }
|
||||
padding padding_w() const noexcept { return padding_w_; }
|
||||
const scalar &pad_value() const noexcept { return pad_value_; }
|
||||
|
||||
space_to_batch(datatype_t input_type, shape_t input_shape, int32_t block_shape_h, int32_t block_shape_w, padding padding_h, padding padding_w, scalar pad_value);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t block_size_h_;
|
||||
int32_t block_size_w_;
|
||||
padding padding_h_;
|
||||
padding padding_w_;
|
||||
scalar pad_value_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API split : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_split);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
std::vector<size_t> indices_or_sections() const noexcept { return indices_or_sections_; }
|
||||
int32_t axis() const noexcept { return axis_; }
|
||||
bool is_indices() const noexcept { return is_indices_; }
|
||||
|
||||
split(datatype_t type, shape_t inputs_shape, std::vector<shape_t> outputs_shape, std::vector<size_t> indices_or_sections, int32_t axis, bool is_indices);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
std::vector<size_t> indices_or_sections_;
|
||||
int32_t axis_;
|
||||
bool is_indices_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API table_lookup1d : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_table_lookup1d);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &table() { return input_at(1); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
table_lookup1d(datatype_t type, shape_t input_shape, size_t table_size);
|
||||
|
||||
protected:
|
||||
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
|
||||
};
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API take : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_take);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
input_connector &indices() { return input_at(1); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
int32_t axis() const noexcept { return axis_; }
|
||||
const std::string &mode() const noexcept { return mode_; }
|
||||
|
||||
take(datatype_t type, shape_t input_shape, shape_t indices_shape, shape_t output_shape, int32_t axis, std::string mode);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int32_t axis_;
|
||||
std::string mode_;
|
||||
};
|
||||
}
|
|
@ -13,24 +13,26 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../call.h"
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir::nn {
|
||||
/** @brief Sigmoid operator node */
|
||||
class NNCASE_API sigmoid_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_nn_sigmoid)
|
||||
public:
|
||||
sigmoid_node();
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API ternary : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_ternary);
|
||||
|
||||
/** @brief Get the input the sigmoid expression */
|
||||
const connector_info &input() const noexcept { return parameter_at(0); }
|
||||
input_connector &input_a() { return input_at(0); }
|
||||
input_connector &input_b() { return input_at(1); }
|
||||
input_connector &input_c() { return input_at(2); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
type infer_invoke_result_type(type_infer_context &context) override;
|
||||
ternary(datatype_t input_a_type, datatype_t input_bc_type, shape_t input_a_shape, shape_t input_b_shape, shape_t input_c_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
/** @brief Sigmoid expression */
|
||||
using sigmoid = object_t<sigmoid_node>;
|
||||
} // namespace nncase::ir::nn
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API topk : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_topk);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
|
||||
// output largest values
|
||||
output_connector &output_a() { return output_at(0); }
|
||||
|
||||
// output indices of largest values
|
||||
output_connector &output_b() { return output_at(1); }
|
||||
|
||||
const int64_t &k() const noexcept { return k_; }
|
||||
const int32_t &axis() const noexcept { return axis_; }
|
||||
bool largest() const noexcept { return largest_; }
|
||||
bool sorted() const noexcept { return sorted_; }
|
||||
|
||||
topk(datatype_t input_type, shape_t input_shape, int64_t k, int32_t axis, bool largest, bool sorted);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
int64_t k_;
|
||||
int32_t axis_;
|
||||
bool largest_;
|
||||
bool sorted_;
|
||||
};
|
||||
}
|
|
@ -13,17 +13,27 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir::tensors {
|
||||
/** @brief Transpose operator node */
|
||||
class NNCASE_API transpose_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_tensors_transpose)
|
||||
public:
|
||||
transpose_node();
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API transpose : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_transpose);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
const axis_t &perm() const noexcept { return perm_; }
|
||||
|
||||
transpose(datatype_t type, shape_t input_shape, axis_t perm);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
axis_t perm_;
|
||||
};
|
||||
|
||||
using transpose = object_t<transpose_node>;
|
||||
} // namespace nncase::ir::tensors
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API trilu : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_trilu);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
bool upper() const noexcept { return upper_; }
|
||||
const int64_t &k() const noexcept { return k_; }
|
||||
|
||||
trilu(datatype_t input_type, shape_t input_shape, bool upper, int64_t k);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
bool upper_;
|
||||
int64_t k_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../node.h"
|
||||
#include <xtensor/xtensor.hpp>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API unary : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_unary);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
unary_op_t unary_op() const noexcept { return unary_op_; }
|
||||
|
||||
unary(unary_op_t unary_op, shape_t input_shape);
|
||||
|
||||
protected:
|
||||
bool properties_equal(node &other) const override;
|
||||
|
||||
private:
|
||||
unary_op_t unary_op_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "node.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API input_node : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_input_node);
|
||||
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
template <class TShape>
|
||||
input_node(datatype_t type, TShape &&shape)
|
||||
{
|
||||
add_output("output", type, std::forward<TShape>(shape), mem_input)
|
||||
.attributes(cnctr_attr_no_layout_strides);
|
||||
}
|
||||
|
||||
protected:
|
||||
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
|
||||
};
|
||||
|
||||
class NNCASE_API output_node : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_output_node);
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
|
||||
template <class TShape>
|
||||
output_node(datatype_t type, TShape &&shape)
|
||||
{
|
||||
attributes(attributes() | node_attr_skip_constant_folding);
|
||||
add_input("input", type, std::forward<TShape>(shape));
|
||||
}
|
||||
|
||||
protected:
|
||||
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
|
||||
};
|
||||
|
||||
class NNCASE_API ignore_node : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_ignore_node);
|
||||
~ignore_node() = default;
|
||||
|
||||
input_connector &input() { return input_at(0); }
|
||||
|
||||
template <class TShape>
|
||||
ignore_node(datatype_t type, TShape &&shape)
|
||||
{
|
||||
add_input("input", type, std::forward<TShape>(shape));
|
||||
}
|
||||
|
||||
protected:
|
||||
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
|
||||
};
|
||||
|
||||
class NNCASE_API uninitialized : public node
|
||||
{
|
||||
public:
|
||||
DEFINE_NODE_OPCODE(op_uninitialized);
|
||||
|
||||
output_connector &output() { return output_at(0); }
|
||||
|
||||
template <class TShape>
|
||||
uninitialized(datatype_t type, TShape &&shape, memory_location_t memory_location = mem_data)
|
||||
{
|
||||
add_output("output", type, std::forward<TShape>(shape), memory_location);
|
||||
}
|
||||
|
||||
protected:
|
||||
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
|
||||
};
|
||||
}
|
|
@ -46,6 +46,7 @@ class NNCASE_API quantizer
|
|||
|
||||
void record(std::span<const float> data);
|
||||
void record(std::span<const bfloat16> data);
|
||||
void record(std::span<const half> data);
|
||||
void finish();
|
||||
value_range<float> optimal_range() const noexcept { return optimal_range_; }
|
||||
|
||||
|
@ -90,27 +91,30 @@ public:
|
|||
}
|
||||
else
|
||||
{
|
||||
if (range.min < -1e3)
|
||||
range.min = -1e3;
|
||||
if (range.max > 1e3)
|
||||
range.max = 1e3;
|
||||
if (range.max < 0)
|
||||
range.max = 0;
|
||||
if (range.min > 0)
|
||||
range.min = 0;
|
||||
|
||||
auto r = range.max - range.min;
|
||||
if (r == 0)
|
||||
r = 0.1f;
|
||||
else if (r < 0.01f)
|
||||
r = 0.01f;
|
||||
range.max = range.min + r;
|
||||
|
||||
if (range.max < 0)
|
||||
range.max = 0;
|
||||
if (range.min > 0)
|
||||
range.min = 0;
|
||||
}
|
||||
|
||||
return range;
|
||||
}
|
||||
|
||||
static quant_param_t get_quant_param(value_range<float> range, int32_t bits);
|
||||
enum class quant_mode
|
||||
{
|
||||
unsigned_mode,
|
||||
signed_symmetric_mode,
|
||||
signed_asymmetric_mode
|
||||
};
|
||||
|
||||
static quant_param_t get_quant_param(value_range<float> range, int32_t bits, quant_mode qm);
|
||||
static fixed_mul get_fixed_mul(float value, int32_t max_bits, uint8_t max_shift, bool is_signed);
|
||||
|
||||
void record(ir::output_connector &connector, value_range<float> range);
|
||||
|
@ -118,6 +122,13 @@ public:
|
|||
bool has_record(ir::output_connector &connector) const;
|
||||
void record(ir::output_connector &connector, std::span<const float> data);
|
||||
void record(ir::output_connector &connector, std::span<const bfloat16> data);
|
||||
void record(ir::output_connector &connector, std::span<const half> data);
|
||||
void record_buffers(ir::output_connector &connector, std::span<const float> data);
|
||||
void record_buffers(ir::output_connector &connector, std::span<const bfloat16> data);
|
||||
void record_buffers(ir::output_connector &connector, std::span<const half> data);
|
||||
void record_quant_buffers(ir::output_connector &connector, std::span<const float> data);
|
||||
void record_quant_buffers(ir::output_connector &connector, std::span<const bfloat16> data);
|
||||
void record_quant_buffers(ir::output_connector &connector, std::span<const half> data);
|
||||
value_range<float> get(ir::output_connector &connector) const;
|
||||
void broadcast_output(ir::graph &graph, const std::unordered_set<node_opcode> &ops);
|
||||
void broadcast_output(ir::node &node, const value_range<float> &range, const std::unordered_set<node_opcode> &ops);
|
||||
|
@ -125,6 +136,10 @@ public:
|
|||
void end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress);
|
||||
size_t histograms_count() const noexcept { return histograms_.size(); }
|
||||
void end_sample() { has_record_.clear(); }
|
||||
std::unordered_map<ir::output_connector *, std::vector<float>> output_buffers() const noexcept { return output_buffers_; }
|
||||
std::vector<ir::output_connector *> quant_buffers_insert_order() const noexcept { return quant_buffers_insert_order_; }
|
||||
std::unordered_map<ir::output_connector *, value_range<float>> ranges() const noexcept { return quant_ranges_; }
|
||||
std::vector<ir::output_connector *> ranges_insert_order() const noexcept { return ranges_insert_order_; }
|
||||
|
||||
private:
|
||||
calibrate_method cali_method_;
|
||||
|
@ -133,5 +148,8 @@ private:
|
|||
std::unordered_map<ir::output_connector *, value_range<float>> quant_ranges_;
|
||||
std::unordered_map<ir::output_connector *, histogram> histograms_;
|
||||
std::unordered_map<ir::output_connector *, bool> has_record_;
|
||||
std::unordered_map<ir::output_connector *, std::vector<float>> output_buffers_;
|
||||
std::vector<ir::output_connector *> quant_buffers_insert_order_;
|
||||
std::vector<ir::output_connector *> ranges_insert_order_;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,166 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../object.h"
|
||||
#include <algorithm>
|
||||
#include <nncase/runtime/datatypes.h>
|
||||
#include <nncase/runtime/small_vector.hpp>
|
||||
#include <optional>
|
||||
#include <range/v3/algorithm/any_of.hpp>
|
||||
#include <range/v3/range/concepts.hpp>
|
||||
|
||||
namespace nncase::ir {
|
||||
struct unknown_dim_t {};
|
||||
|
||||
inline constexpr unknown_dim_t unknown_dim;
|
||||
|
||||
enum dim_kind_t { dim_fixed = 0, dim_unknown = 1 };
|
||||
|
||||
using dim_value_t = int64_t;
|
||||
|
||||
/** @brief Dimension */
|
||||
struct dim_t {
|
||||
/** @brief Initialize an unknown dim */
|
||||
constexpr dim_t(unknown_dim_t = unknown_dim) noexcept
|
||||
: kind(dim_unknown), value(0) {}
|
||||
|
||||
/** @brief Initialize an fixed dim */
|
||||
constexpr dim_t(dim_value_t value) noexcept
|
||||
: kind(dim_fixed), value(value) {}
|
||||
|
||||
/** @brief Is this a fixed dim */
|
||||
bool is_fixed() const noexcept { return kind == dim_fixed; }
|
||||
/** @brief Is this an unknown dim */
|
||||
bool is_unknown() const noexcept { return kind == dim_unknown; }
|
||||
|
||||
dim_value_t fixed_value() const {
|
||||
assert(is_fixed());
|
||||
return value;
|
||||
}
|
||||
|
||||
dim_kind_t kind;
|
||||
dim_value_t value;
|
||||
};
|
||||
|
||||
struct scalar_shape_t {};
|
||||
|
||||
inline constexpr scalar_shape_t scalar_shape;
|
||||
|
||||
struct unranked_shape_t {};
|
||||
|
||||
inline constexpr unranked_shape_t unranked_shape;
|
||||
|
||||
struct invalid_shape_t {};
|
||||
|
||||
inline constexpr invalid_shape_t invalid_shape;
|
||||
|
||||
/** @brief Shape type */
|
||||
class NNCASE_API shape_t {
|
||||
enum shape_kind_t {
|
||||
shape_kind_fixed,
|
||||
shape_kind_has_unknown_dim,
|
||||
shape_kind_unranked,
|
||||
shape_kind_invalid
|
||||
};
|
||||
|
||||
public:
|
||||
using value_type = dim_t;
|
||||
|
||||
/** @brief Initialize a scalar shape */
|
||||
shape_t(scalar_shape_t) noexcept : kind_(shape_kind_fixed) {}
|
||||
|
||||
/** @brief Initialize an unranked shape */
|
||||
shape_t(unranked_shape_t) noexcept : kind_(shape_kind_unranked) {}
|
||||
|
||||
/** @brief Initialize an invalid shape */
|
||||
shape_t(invalid_shape_t) noexcept : kind_(shape_kind_invalid) {}
|
||||
|
||||
/** @brief Initialize a ranked shape */
|
||||
template <ranges::range R>
|
||||
shape_t(R dims) : kind_(kind_of(dims)), dims_(dims.begin(), dims.end()) {}
|
||||
|
||||
/** @brief Initialize a fixed shape */
|
||||
shape_t(std::initializer_list<dim_value_t> dims) : kind_(shape_kind_fixed) {
|
||||
dims_.reserve(dims.size());
|
||||
std::transform(dims.begin(), dims.end(), std::back_inserter(dims_),
|
||||
[](dim_value_t dim) -> dim_t { return dim; });
|
||||
}
|
||||
|
||||
/** @brief Get kind */
|
||||
shape_kind_t kind() const noexcept { return kind_; }
|
||||
|
||||
/** @brief Is this a fixed shape */
|
||||
bool is_fixed() const noexcept { return kind() == shape_kind_fixed; }
|
||||
/** @brief Is this a scalar */
|
||||
bool is_scalar() const noexcept {
|
||||
return kind() == shape_kind_fixed && dims_.empty();
|
||||
}
|
||||
/** @brief Is this an ranked shape */
|
||||
bool is_ranked() const noexcept { return is_fixed() || has_unknown_dim(); }
|
||||
/** @brief Is this an unranked shape */
|
||||
bool is_unranked() const noexcept { return kind() == shape_kind_unranked; }
|
||||
/** @brief Has at least one unknown dimension */
|
||||
bool has_unknown_dim() const noexcept {
|
||||
return kind() == shape_kind_has_unknown_dim;
|
||||
}
|
||||
/** @brief Is this an invalid shape */
|
||||
bool is_invalid() const noexcept { return kind() == shape_kind_invalid; }
|
||||
|
||||
/** @brief Get dimensions */
|
||||
std::span<const dim_t> dims() const noexcept { return dims_; }
|
||||
|
||||
/** @brief Get rank */
|
||||
std::optional<size_t> rank() const noexcept {
|
||||
return is_ranked() ? std::make_optional(dims_.size()) : std::nullopt;
|
||||
}
|
||||
|
||||
auto begin() const noexcept { return dims_.cbegin(); }
|
||||
auto end() const noexcept { return dims_.cend(); }
|
||||
|
||||
const dim_t &front() const { return dims_.front(); }
|
||||
const dim_t &back() const { return dims_.back(); }
|
||||
|
||||
/** @brief Get dimension */
|
||||
const dim_t &dim(size_t index) const { return dims_.at(index); }
|
||||
const dim_t &operator[](size_t index) const { return dim(index); }
|
||||
|
||||
/** @brief Set dimension */
|
||||
void dim(size_t index, dim_t value);
|
||||
|
||||
/** @brief Place a new dim at back */
|
||||
void push_back(dim_t value);
|
||||
const dim_t &emplace_back(dim_t value);
|
||||
/** @brief Place a new dim */
|
||||
const dim_t *emplace(const dim_t *position, dim_t value);
|
||||
|
||||
/** @brief Remove the dim at back */
|
||||
void pop_back();
|
||||
|
||||
private:
|
||||
template <ranges::range R> static shape_kind_t kind_of(R &&range) noexcept {
|
||||
return ranges::any_of(range,
|
||||
[](const dim_t &dim) { return dim.is_unknown(); })
|
||||
? shape_kind_has_unknown_dim
|
||||
: shape_kind_fixed;
|
||||
}
|
||||
|
||||
void update_kind(shape_kind_t before_kind,
|
||||
dim_kind_t new_dim_kind) noexcept;
|
||||
|
||||
private:
|
||||
shape_kind_t kind_;
|
||||
itlib::small_vector<dim_t, 4> dims_;
|
||||
};
|
||||
} // namespace nncase::ir
|
|
@ -1,45 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
|
||||
namespace nncase::ir::tensors {
|
||||
/** @brief Cast operator node */
|
||||
class NNCASE_API cast_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_tensors_cast)
|
||||
public:
|
||||
cast_node(datatype_t new_type);
|
||||
|
||||
/** @brief Get the new type of the cast expression */
|
||||
datatype_t new_type() const noexcept { return new_type_; }
|
||||
/** @brief Set the new type of the cast expression */
|
||||
void new_type(datatype_t value) noexcept { new_type_ = value; }
|
||||
|
||||
/** @brief Get the input the unary expression */
|
||||
const connector_info &input() const noexcept { return parameter_at(0); }
|
||||
|
||||
type infer_invoke_result_type(type_infer_context &context) override;
|
||||
|
||||
private:
|
||||
datatype_t new_type_;
|
||||
};
|
||||
|
||||
class cast : public object_t<cast_node> {
|
||||
public:
|
||||
NNCASE_API cast(datatype_t new_type);
|
||||
};
|
||||
} // namespace nncase::ir::tensors
|
|
@ -1,22 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../call.h"
|
||||
#include "broadcast.h"
|
||||
#include "cast.h"
|
||||
|
||||
namespace nncase::ir::F {
|
||||
NNCASE_API call broadcast(expr input, expr shape);
|
||||
} // namespace nncase::ir::F
|
|
@ -1,29 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
|
||||
namespace nncase::ir::tensors {
|
||||
/** @brief Gather operator node */
|
||||
class NNCASE_API gather_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_tensors_gather)
|
||||
public:
|
||||
gather_node();
|
||||
};
|
||||
|
||||
using gather = object_t<gather_node>;
|
||||
} // namespace nncase::ir::tensors
|
|
@ -1,29 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
|
||||
namespace nncase::ir::tensors {
|
||||
/** @brief GatherND operator node */
|
||||
class NNCASE_API gather_nd_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_tensors_gather_nd)
|
||||
public:
|
||||
gather_nd_node();
|
||||
};
|
||||
|
||||
using gather_nd = object_t<gather_nd_node>;
|
||||
} // namespace nncase::ir::tensors
|
|
@ -1,10 +0,0 @@
|
|||
DEFINE_OPCODE(tensors, broadcast, Broadcast, 0x012001)
|
||||
DEFINE_OPCODE(tensors, concat, Concat, 0x012002)
|
||||
DEFINE_OPCODE(tensors, cast, Cast, 0x012003)
|
||||
DEFINE_OPCODE(tensors, copy, Copy, 0x012004)
|
||||
DEFINE_OPCODE(tensors, gather, Gather, 0x012005)
|
||||
DEFINE_OPCODE(tensors, gather_nd, GatherND, 0x012006)
|
||||
DEFINE_OPCODE(tensors, reshape, Reshape, 0x012007)
|
||||
DEFINE_OPCODE(tensors, slice, Slice, 0x012008)
|
||||
DEFINE_OPCODE(tensors, split, Split, 0x012009)
|
||||
DEFINE_OPCODE(tensors, transpose, Transpose, 0x01200A)
|
|
@ -1,25 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../../object_kind.h"
|
||||
|
||||
namespace nncase::ir::tensors {
|
||||
#define DEFINE_OPCODE(dialect, id, name, value) \
|
||||
NNCASE_INLINE_VAR constexpr object_kind op_##dialect##_##id{value, #name};
|
||||
|
||||
#include "opcode.def"
|
||||
|
||||
#undef DEFINE_OPCODE
|
||||
} // namespace nncase::ir::tensors
|
|
@ -1,29 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../op.h"
|
||||
#include "nncase/runtime/datatypes.h"
|
||||
#include "opcode.h"
|
||||
|
||||
namespace nncase::ir::tensors {
|
||||
/** @brief Reshape operator node */
|
||||
class NNCASE_API reshape_node : public op_node {
|
||||
DEFINE_OBJECT_KIND(op_node, op_tensors_reshape)
|
||||
public:
|
||||
reshape_node();
|
||||
};
|
||||
|
||||
using reshape = object_t<reshape_node>;
|
||||
} // namespace nncase::ir::tensors
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue