add csharp runtime export

pull/517/head
郑启航 2022-01-19 15:34:58 +08:00
parent a736e2df03
commit 7fcd4efc4e
425 changed files with 19068 additions and 8343 deletions

View File

@ -32,6 +32,8 @@ project(nncase
option(ENABLE_OPENMP "OpenMP support" ON)
option(ENABLE_HALIDE "halide kernels support" ON)
option(BUILD_PYTHON_BINDING "Build python binding" ON)
option(BUILD_CSHARP_BINDING "Build csharp binding" ON)
option(BUILD_BENCHMARK "Build benchmark programs" ON)
option(BUILD_TESTING "Build test programs" OFF)
if (BUILDING_RUNTIME)
@ -57,7 +59,7 @@ else() # in user space
message(STATUS "Auto Cmake Conan Installation")
include(${CMAKE_SOURCE_DIR}/cmake/conan.cmake)
conan_cmake_run(CONANFILE conanfile.py
BASIC_SETUP TARGETS
BASIC_SETUP
OPTIONS ${CONAN_OPTS}
SETTINGS ${CONAN_SETTINGS}
BUILD missing)
@ -74,13 +76,15 @@ if (BUILDING_RUNTIME)
if (MSVC)
add_definitions(/D_CRT_SECURE_NO_WARNINGS /DNOMINMAX)
add_compile_options(/wd4267 /wd4251 /FC /utf-8 /W3 /WX)
add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX)
else()
add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits)
if (APPLE)
add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
add_compile_options(-Wno-uninitialized -Wno-unused-private-field)
else ()
add_compile_options(-Wno-maybe-uninitialized)
add_compile_options(-Wno-maybe-uninitialized -Wno-unused-private-field)
endif()
endif()
@ -90,12 +94,20 @@ if (BUILDING_RUNTIME)
add_subdirectory(include/nncase)
add_subdirectory(src/kernels)
add_subdirectory(src/runtime)
add_subdirectory(benchmark)
add_subdirectory(src/functional)
if(BUILD_BENCHMARK)
add_subdirectory(benchmark)
endif()
# Python binding
if(BUILD_PYTHON_BINDING)
add_subdirectory(python/nncaseruntime/native)
endif()
# Csharp binding
if(BUILD_CSHARP_BINDING)
add_subdirectory(csharp)
endif()
install(DIRECTORY ${NNCASE_INCLUDE_DIR}/nncase
DESTINATION include
@ -145,12 +157,15 @@ else()
if (MSVC)
add_definitions(/D_SILENCE_ALL_CXX17_DEPRECATION_WARNINGS /D_CRT_SECURE_NO_WARNINGS /DNOMINMAX)
add_compile_options(/wd4267 /wd4251 /FC /utf-8 /W3 /WX)
add_compile_options(/wd4267 /wd4251 /wd4244 /FC /utf-8 /W3 /WX)
set(PYBIND11_CPP_STANDARD "/std:c++latest")
else()
add_compile_options(-fvisibility=hidden)
add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits -Wno-unused-local-typedefs -Wno-sign-compare)
if (APPLE)
add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized -Wno-deprecated)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
add_compile_options(-Wno-uninitialized)
else ()
add_compile_options(-Wno-maybe-uninitialized -Wno-deprecated-copy)
add_link_options(-Wl,--exclude-libs,ALL)
@ -165,23 +180,24 @@ else()
add_subdirectory(src/data)
add_subdirectory(src/ir)
add_subdirectory(src/importer)
#add_subdirectory(src/schedule)
#add_subdirectory(src/evaluator)
add_subdirectory(src/schedule)
add_subdirectory(src/evaluator)
add_subdirectory(src/functional)
add_subdirectory(src/transforms)
#add_subdirectory(src/codegen)
#add_subdirectory(src/kernels)
#add_subdirectory(src/runtime)
add_subdirectory(src/codegen)
add_subdirectory(src/kernels)
add_subdirectory(src/runtime)
add_subdirectory(src/targets)
#add_subdirectory(src/plugin)
add_subdirectory(src/plugin)
add_subdirectory(src/cli)
if(BUILD_TESTING)
# add_subdirectory(tests/kernels)
add_subdirectory(tests/kernels)
endif()
# Python binding
if(BUILD_PYTHON_BINDING)
# add_subdirectory(python/nncase/native)
add_subdirectory(python/nncase/native)
endif()
# Third party
@ -220,11 +236,11 @@ else()
)
# Targets
#add_subdirectory(targets/cpu)
#add_subdirectory(targets/k210)
#add_subdirectory(targets/vulkan)
add_subdirectory(targets/cpu)
add_subdirectory(targets/k210)
add_subdirectory(targets/vulkan)
endif()
# Modules
#add_subdirectory(modules/k210)
#add_subdirectory(modules/vulkan)
add_subdirectory(modules/k210)
add_subdirectory(modules/vulkan)

View File

@ -33,7 +33,7 @@
# but it is only necessary on the end-user side. It is not necessary to create conan
# packages; in fact it shouldn't be used for that. Check the project documentation.
# version: 0.16.1
# version: 0.18.0-dev
include(CMakeParseArguments)
@ -55,6 +55,8 @@ function(_get_msvc_ide_version result)
set(${result} 15 PARENT_SCOPE)
elseif(NOT MSVC_VERSION VERSION_LESS 1920 AND MSVC_VERSION VERSION_LESS 1930)
set(${result} 16 PARENT_SCOPE)
elseif(NOT MSVC_VERSION VERSION_LESS 1930 AND MSVC_VERSION VERSION_LESS 1940)
set(${result} 17 PARENT_SCOPE)
else()
message(FATAL_ERROR "Conan: Unknown MSVC compiler version [${MSVC_VERSION}]")
endif()
@ -126,6 +128,10 @@ macro(_conan_detect_compiler)
set(_CONAN_SETTING_ARCH ${ARGUMENTS_ARCH})
endif()
if(USING_CXX)
set(_CONAN_SETTING_COMPILER_CPPSTD ${CMAKE_CXX_STANDARD})
endif()
if (${CMAKE_${LANGUAGE}_COMPILER_ID} STREQUAL GNU)
# using GCC
# TODO: Handle other params
@ -415,7 +421,8 @@ endfunction()
function(_collect_settings result)
set(ARGUMENTS_PROFILE_AUTO arch build_type compiler compiler.version
compiler.runtime compiler.libcxx compiler.toolset)
compiler.runtime compiler.libcxx compiler.toolset
compiler.cppstd)
foreach(ARG ${ARGUMENTS_PROFILE_AUTO})
string(TOUPPER ${ARG} _arg_name)
string(REPLACE "." "_" _arg_name ${_arg_name})
@ -427,10 +434,10 @@ function(_collect_settings result)
endfunction()
function(conan_cmake_autodetect detected_settings)
_conan_detect_build_type()
_conan_detect_build_type(${ARGV})
_conan_check_system_name()
_conan_check_language()
_conan_detect_compiler()
_conan_detect_compiler(${ARGV})
_collect_settings(collected_settings)
set(${detected_settings} ${collected_settings} PARENT_SCOPE)
endfunction()
@ -794,7 +801,6 @@ macro(conan_check)
if(NOT CONAN_CMD AND CONAN_REQUIRED)
message(FATAL_ERROR "Conan executable not found! Please install conan.")
endif()
if(NOT CONAN_DETECT_QUIET)
message(STATUS "Conan: Found program ${CONAN_CMD}")
endif()
@ -803,13 +809,13 @@ macro(conan_check)
OUTPUT_VARIABLE CONAN_VERSION_OUTPUT
ERROR_VARIABLE CONAN_VERSION_OUTPUT)
message("a $ENV{PATH}")
if(NOT "${return_code}" STREQUAL "0")
message(FATAL_ERROR "Conan --version failed='${return_code}'")
endif()
if(NOT CONAN_DETECT_QUIET)
message(STATUS "Conan: Version found ${CONAN_VERSION_OUTPUT}")
string(STRIP "${CONAN_VERSION_OUTPUT}" _CONAN_VERSION_OUTPUT)
message(STATUS "Conan: Version found ${_CONAN_VERSION_OUTPUT}")
endif()
if(DEFINED CONAN_VERSION)
@ -839,7 +845,7 @@ function(conan_add_remote)
if(DEFINED CONAN_COMMAND)
set(CONAN_CMD ${CONAN_COMMAND})
else()
conan_check(REQUIRED)
conan_check(REQUIRED DETECT_QUIET)
endif()
set(CONAN_VERIFY_SSL_ARG "True")
if(DEFINED CONAN_VERIFY_SSL)

View File

@ -20,7 +20,7 @@ _SET_CONANOPT(CONAN_OPTS "halide" ENABLE_HALIDE)
if (NOT DEFINED CMAKE_CXX_STANDARD)
if (BUILDING_RUNTIME)
set (CMAKE_CXX_STANDARD 14)
set (CMAKE_CXX_STANDARD 17)
else ()
set (CMAKE_CXX_STANDARD 20)
endif ()

View File

@ -26,7 +26,6 @@ if (NOT BUILDING_RUNTIME)
find_package(libzippp REQUIRED)
find_package(inja REQUIRED)
find_package(shaderc REQUIRED)
find_package(range-v3 REQUIRED)
endif ()
if (BUILD_TESTING)

View File

@ -0,0 +1 @@
include(${CMAKE_CURRENT_LIST_DIR}/nncasefunctionalTargets.cmake)

9
csharp/CMakeLists.txt Normal file
View File

@ -0,0 +1,9 @@
# C# interop runtime: a MODULE library loaded by the .NET side via P/Invoke.
cmake_minimum_required (VERSION 3.18)

set(SRCS interpreter.cpp)
add_library(nncaseruntime_csharp MODULE ${SRCS})
target_link_libraries(nncaseruntime_csharp PRIVATE nncaseruntime)
# MODULE targets cannot be linked against, so usage requirements never
# propagate to consumers; PRIVATE is the accurate scope here.
target_include_directories(nncaseruntime_csharp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
install(TARGETS nncaseruntime_csharp DESTINATION lib)

230
csharp/RuntimeTensor.h Normal file
View File

@ -0,0 +1,230 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <array>
#include <cmath>
#include <iostream>
#include <memory>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
#include <nncase/runtime/runtime_tensor.h>
#include <nncase/runtime/runtime_tensor_impl.h>
#include <nncase/version.h>
#include <type_traits>
#include <vector>
#include "stdprefix.h"
using namespace nncase;
using namespace nncase::runtime;
/**
 * @brief Map a C++ element type to the nncase runtime datatype enum.
 * @throws std::runtime_error for types with no runtime equivalent.
 */
template <typename T>
datatype_t from_dtype()
{
    // Bug fix: the first branch used to test `float` and return dt_uint8,
    // so every float tensor was mis-tagged as uint8 and the real float32
    // branch below was unreachable.
    if (std::is_same_v<T, uint8_t>)
        return dt_uint8;
    else if (std::is_same_v<T, uint16_t>)
        return dt_uint16;
    else if (std::is_same_v<T, uint32_t>)
        return dt_uint32;
    else if (std::is_same_v<T, uint64_t>)
        return dt_uint64;
    else if (std::is_same_v<T, int8_t>)
        return dt_int8;
    else if (std::is_same_v<T, int16_t>)
        return dt_int16;
    else if (std::is_same_v<T, int32_t>)
        return dt_int32;
    else if (std::is_same_v<T, int64_t>)
        return dt_int64;
    else if (std::is_same_v<T, std::bfloat16>)
        // NOTE(review): tests bfloat16 but the message says float16, and a
        // dt_bfloat16 value exists — confirm whether this should return it.
        throw std::runtime_error("Unsupported float16");
    else if (std::is_same_v<T, float>)
        return dt_float32;
    else if (std::is_same_v<T, double>)
        return dt_float64;
    throw std::runtime_error("Unsupported dtype");
}
// Widen each int dimension to size_t for the runtime shape container.
inline runtime_shape_t to_rt_shape(const int *shape_ptr, int shape_size)
{
    runtime_shape_t dims(shape_size);
    std::transform(shape_ptr, shape_ptr + shape_size, dims.begin(),
        [](int d) { return (size_t)d; });
    return dims;
}
// Convert byte strides from the caller into the element strides the
// runtime expects (divide each entry by the element size).
inline runtime_shape_t to_rt_strides(size_t elemsize, const int *stride_ptr, int stride_size)
{
    runtime_shape_t result(stride_size);
    for (int i = 0; i < stride_size; ++i)
        result[i] = (size_t)stride_ptr[i] / elemsize;
    return result;
}
inline std::vector<size_t> to_py_strides(size_t elemsize, const runtime_shape_t &value)
{
std::vector<size_t> strides(value.size());
for (size_t i = 0; i < strides.size(); i++)
strides[i] = (size_t)value[i] * elemsize;
return strides;
}
extern "C"
{
    // POD handle passed by value across the P/Invoke boundary; impl points
    // at a nncase::runtime::detail::runtime_tensor_impl (see to_rt_impl).
    struct RuntimeTensor
    {
        void *impl;
    };
}
// Rewrap the raw handle so shared_ptr-based runtime_tensor APIs can be used.
// NOTE(review): this constructs an *owning* shared_ptr from a pointer this
// code does not own; when the returned pointer's refcount drops to zero the
// impl is deleted even though the interpreter (or another handle) may still
// reference it — double-free / dangling-handle risk. Confirm the intended
// ownership model (e.g. a no-op deleter or an intrusive refcount).
inline std::shared_ptr<nncase::runtime::detail::runtime_tensor_impl> to_rt_impl(RuntimeTensor rt)
{
    return std::shared_ptr<nncase::runtime::detail::runtime_tensor_impl>(static_cast<nncase::runtime::detail::runtime_tensor_impl *>(rt.impl));
}
// Build a runtime_tensor over the handle (shares the impl via to_rt_impl).
inline runtime_tensor to_rt(RuntimeTensor rt)
{
    return runtime_tensor(to_rt_impl(rt));
}
/**
 * @brief create a tensor from an existing buffer (zero-copy)
 *
 * @param buffer_ptr the buffer pointer; element type given by datatype
 * @param datatype the datatype enum value.
 * @param shape_ptr the shape pointer
 * @param shape_size the shape array size
 * @param total_items the total element count
 * @param item_size the size of each element in bytes
 * @param stride_ptr the stride pointer (in bytes; converted to element strides)
 * @return RuntimeTensor handle wrapping the new tensor's impl
 *
 * NOTE(review): the no-op deleter means the tensor aliases buffer_ptr without
 * owning it — the C# caller must keep the buffer alive while the tensor is in
 * use. Also, the local `tensor`'s shared_ptr releases when this function
 * returns while only the raw impl() pointer escapes — confirm lifetime
 * management together with to_rt_impl().
 */
EXPORT_API(RuntimeTensor) RuntimeTensor_from_buffer(
    const uint8_t *buffer_ptr, uint8_t datatype, int *shape_ptr,
    int shape_size, size_t total_items,
    size_t item_size, int *stride_ptr)
{
    auto tensor = host_runtime_tensor::create(
        (datatype_t)datatype,
        to_rt_shape(shape_ptr, shape_size),
        to_rt_strides(item_size, stride_ptr, shape_size),
        gsl::make_span((gsl::byte *)(buffer_ptr),
            total_items * item_size),
        [=](gsl::byte *) {})
                      .unwrap_or_throw();
    return RuntimeTensor { .impl = tensor.impl() };
}
/**
 * @brief copy the tensor's contents into a caller-provided buffer
 *
 * @param rt tensor handle
 * @param buffer_ptr destination buffer (must be large enough for the tensor)
 * @param datatype_ptr out-parameter receiving the tensor's datatype enum value
 */
EXPORT_API(void) RuntimeTensor_to_buffer(RuntimeTensor rt, uint8_t *buffer_ptr, uint8_t *datatype_ptr)
{
    auto rtensor = to_rt(rt);
    if (!rtensor.is_contiguous())
    {
        // Bug fix: was `throw std::errc::not_supported;` — throwing a raw
        // enum value bypasses `catch (const std::exception &)` handlers in
        // the interop layer and would typically terminate the process.
        throw std::runtime_error("RuntimeTensor_to_buffer: tensor is not contiguous");
    }
    *datatype_ptr = rtensor.datatype();
    auto host = rtensor.as_host().unwrap_or_throw();
    auto src_map = std::move(hrt::map(host, hrt::map_read).unwrap_or_throw());
    auto src_buffer = src_map.buffer();
    memcpy(buffer_ptr, src_buffer.data(), src_buffer.size_bytes());
}
/**
 * @brief copy the contents of one runtime tensor into another
 *
 * @param from source tensor handle
 * @param to destination tensor handle
 */
EXPORT_API(void) RuntimeTensor_copy_to(RuntimeTensor from, RuntimeTensor to)
{
    auto dest = to_rt(to);
    auto source = to_rt_impl(from);
    source->copy_to(dest).unwrap_or_throw();
}
/**
 * @brief get the runtime tensor datatype as the C#-side enum ordinal
 *        (int8..int64 = 0..3, uint8..uint64 = 4..7, fp16/fp32/fp64 = 8..10,
 *        bf16 = 11).
 *
 * @return uint8_t
 */
EXPORT_API(uint8_t) RuntimeTensor_dtype(RuntimeTensor rt)
{
    // Return directly from each case; the dt_* values are remapped because
    // the managed enum ordering differs from the native one.
    switch (to_rt_impl(rt)->datatype())
    {
    case dt_int8:
        return 0;
    case dt_int16:
        return 1;
    case dt_int32:
        return 2;
    case dt_int64:
        return 3;
    case dt_uint8:
        return 4;
    case dt_uint16:
        return 5;
    case dt_uint32:
        return 6;
    case dt_uint64:
        return 7;
    case dt_float16:
        return 8;
    case dt_float32:
        return 9;
    case dt_float64:
        return 10;
    case dt_bfloat16:
        return 11;
    default:
        throw std::runtime_error("Not Support The DataType");
    }
}
/**
 * @brief get the shape; if shape_ptr is null only the rank is returned
 *
 * @param rt tensor handle
 * @param shape_ptr optional output array (must hold at least rank entries)
 * @return int the rank (number of dimensions)
 */
EXPORT_API(int) RuntimeTensor_shape(RuntimeTensor rt, size_t *shape_ptr)
{
    auto dims = to_rt_impl(rt)->shape();
    if (shape_ptr != nullptr)
        std::copy(dims.begin(), dims.end(), shape_ptr);
    return dims.size();
}

14
csharp/build.sh Executable file
View File

@ -0,0 +1,14 @@
#! /bin/bash
# Configure and install the nncase runtime (runtime-only build, halide /
# OpenMP / Vulkan disabled) into ../runtime_install, building in
# ../build/runtime.
# NOTE(review): all paths are relative — run this from the csharp/ directory.
# NOTE(review): clang paths are hardcoded to /usr/bin — confirm for CI hosts.
mkdir -p ../build/runtime
cmake -S .. -B ../build/runtime \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_COMPILER=/usr/bin/clang \
-DCMAKE_CXX_COMPILER=/usr/bin/clang++ \
-DENABLE_HALIDE=false \
-DENABLE_OPENMP=false \
-DCMAKE_EXPORT_COMPILE_COMMANDS=true \
-DENABLE_VULKAN_RUNTIME=false \
-DBUILDING_RUNTIME=true \
-G "Unix Makefiles" \
-DCMAKE_INSTALL_PREFIX:PATH=../runtime_install
cmake --build ../build/runtime --target install

128
csharp/interpreter.cpp Normal file
View File

@ -0,0 +1,128 @@
#include "RuntimeTensor.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <iostream>
#include <memory>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
#include <nncase/version.h>
#include <type_traits>
#include <vector>
using namespace nncase;
using namespace nncase::runtime;
// Process-wide interpreter singleton shared by every exported entry point.
// NOTE(review): nothing here is synchronized — confirm the C# side
// serializes calls if used from multiple threads.
static std::unique_ptr<interpreter> _interp;

/**
 * @brief init the interpreter (idempotent: creates the singleton once,
 *        subsequent calls are no-ops)
 */
EXPORT_API(void) interpreter_init()
{
    if (!_interp)
    {
        _interp = std::make_unique<nncase::runtime::interpreter>();
    }
}
/**
 * @brief load model into the global interpreter
 *
 * @param buffer_ptr the serialized model bytes
 * @param size buffer length in bytes
 *
 * NOTE(review): load_model receives a span over buffer_ptr — confirm whether
 * the interpreter copies it or the caller must keep the buffer alive.
 * Requires interpreter_init() to have been called first.
 */
EXPORT_API(void) interpreter_load_model(const uint8_t *buffer_ptr, int size)
{
    auto buffer = gsl::span<const gsl::byte>((const gsl::byte *)(buffer_ptr), size);
    _interp->load_model(buffer).unwrap_or_throw();
}
/**
 * @brief get the number of model inputs
 *
 * @return size_t
 */
EXPORT_API(size_t) interpreter_inputs_size()
{
    return _interp->inputs_size();
}
/**
 * @brief get the number of model outputs
 *
 * @return size_t
 */
EXPORT_API(size_t) interpreter_outputs_size()
{
    return _interp->outputs_size();
}
/**
 * @brief get the input memory range desc (returned by value to the caller)
 *
 * @param index input number
 * @return memory_range
 */
EXPORT_API(memory_range) interpreter_get_input_desc(size_t index)
{
    return _interp->input_desc(index);
}
/**
 * @brief get the output memory range desc (returned by value to the caller)
 *
 * @param index output number
 * @return memory_range
 */
EXPORT_API(memory_range) interpreter_get_output_desc(size_t index)
{
    return _interp->output_desc(index);
}
/**
 * @brief get the input tensor handle
 *
 * @param index input number
 * @return RuntimeTensor wrapping the runtime_tensor impl pointer
 *
 * NOTE(review): impl() hands out a raw pointer while ownership stays with the
 * interpreter-side shared_ptr — see the lifetime note on to_rt_impl() in
 * RuntimeTensor.h.
 */
EXPORT_API(RuntimeTensor) interpreter_get_input_tensor(size_t index)
{
    return RuntimeTensor { .impl = _interp->input_tensor(index).unwrap_or_throw().impl() };
}
/**
 * @brief bind a tensor to the given input slot
 *
 * @param index input slot number
 * @param rt tensor handle to bind
 */
EXPORT_API(void) interpreter_set_input_tensor(size_t index, RuntimeTensor rt)
{
    _interp->input_tensor(index, to_rt(rt)).unwrap_or_throw();
}
/**
 * @brief get the output tensor handle for the given output slot
 *
 * @param index output number
 * @return RuntimeTensor wrapping the runtime_tensor impl pointer
 */
EXPORT_API(RuntimeTensor) interpreter_get_output_tensor(size_t index)
{
    return RuntimeTensor { .impl = _interp->output_tensor(index).unwrap_or_throw().impl() };
}
/**
 * @brief bind a tensor to the given output slot
 *
 * @param index output slot number
 * @param rt tensor handle to bind
 */
EXPORT_API(void) interpreter_set_output_tensor(size_t index, RuntimeTensor rt)
{
    // Bug fix: previously called input_tensor(), so the tensor was bound to
    // the input slot instead of the requested output slot (copy-paste from
    // interpreter_set_input_tensor above).
    _interp->output_tensor(index, to_rt(rt)).unwrap_or_throw();
}
/**
 * @brief run inference on the currently bound inputs; throws via
 *        unwrap_or_throw() on failure
 */
EXPORT_API(void) interpreter_run()
{
    _interp->run().unwrap_or_throw();
}

20
csharp/stdprefix.h Normal file
View File

@ -0,0 +1,20 @@
// Common prelude for the C# interop sources: stdlib includes, warning
// helpers, and the EXPORT_API macro used on every P/Invoke entry point.
#pragma once
#include <limits>
#include <assert.h>
#include <cmath>
#include <cstring>
// Silence unused-variable/parameter warnings portably.
#define UNUSED(x) (void)(x)
#define DEBUG_ONLY(x) (void)(x)
#ifdef _WIN32
#include <intrin.h>
// C linkage keeps the exported symbol name stable for P/Invoke lookup.
#define EXPORT_API(ret) extern "C" __declspec(dllexport) ret
#else
#include "unixprefix.h"
#define EXPORT_API(ret) extern "C" __attribute__((visibility("default"))) ret
// Shim for the MSVC keyword on non-Windows compilers.
#define __forceinline __attribute__((always_inline)) inline
#endif

117
csharp/unixprefix.h Normal file
View File

@ -0,0 +1,117 @@
// Ignore SAL annotations for non-Windows build.
#define _In_
#define _Out_
#define _Inout_
#define _In_z_
#define _Inout_z_
#define _In_reads_(s)
#define _In_reads_bytes_(s)
#define _In_reads_z_(s)
#define _In_reads_or_z_(s)
#define _Out_writes_(s)
#define _Out_writes_bytes_(s)
#define _Out_writes_z_(s)
#define _Inout_updates_(s)
#define _Inout_updates_bytes_(s)
#define _Inout_updates_z_(s)
#define _Out_writes_to_(s,c)
#define _Out_writes_bytes_to_(s,c)
#define _Out_writes_all_(s)
#define _Out_writes_bytes_all_(s)
#define _Inout_updates_to_(s,c)
#define _Inout_updates_bytes_to_(s,c)
#define _Inout_updates_all_(s)
#define _Inout_updates_bytes_all_(s)
#define _In_reads_to_ptr_(p)
#define _In_reads_to_ptr_z_(p)
#define _Out_writes_to_ptr_(p)
#define _Out_writes_to_ptr_z_(p)
#define _In_opt_
#define _Out_opt_
#define _Inout_opt_
#define _In_opt_z_
#define _Inout_opt_z_
#define _In_reads_opt_(s)
#define _In_reads_bytes_opt_(s)
#define _In_reads_opt_z_(s)
#define _Out_writes_opt_(s)
#define _Out_writes_opt_z_(s)
#define _Inout_updates_opt_(s)
#define _Inout_updates_bytes_opt_(s)
#define _Inout_updates_opt_z_(s)
#define _Out_writes_to_opt_(s,c)
#define _Out_writes_bytes_to_opt_(s,c)
#define _Out_writes_all_opt_(s)
#define _Out_writes_bytes_all_opt_(s)
#define _Inout_updates_to_opt_(s,c)
#define _Inout_updates_bytes_to_opt_(s,c)
#define _Inout_updates_all_opt_(s)
#define _Inout_updates_bytes_all_opt_(s)
#define _In_reads_to_ptr_opt_(p)
#define _In_reads_to_ptr_opt_z_(p)
#define _Out_writes_to_ptr_opt_(p)
#define _Out_writes_to_ptr_opt_z_(p)
#define _Outptr_
#define _Outptr_opt_
#define _Outptr_result_maybenull_
#define _Outptr_opt_result_maybenull_
#define _Outptr_result_z_
#define _Outptr_opt_result_z_
#define _Outptr_result_maybenull_z_
// Bug fix: "_Ouptr_..." is a misspelling of the real SAL annotation name;
// the correctly spelled macro below is what SAL-annotated code actually
// uses. The misspelled one is kept in case anything referenced it.
#define _Ouptr_opt_result_maybenull_z_
#define _Outptr_opt_result_maybenull_z_
#define _COM_Outptr_
#define _COM_Outptr_opt_
#define _COM_Outptr_result_maybenull_
#define _COM_Outptr_opt_result_maybenull_
#define _Outptr_result_buffer_(s)
#define _Outptr_result_bytebuffer_(s)
#define _Outptr_opt_result_buffer_(s)
#define _Outptr_opt_result_bytebuffer_(s)
#define _Outptr_result_buffer_to_(s, c)
#define _Outptr_result_bytebuffer_to_(s, c)
#define _Outptr_opt_result_buffer_to_(s,c)
#define _Outptr_opt_result_bytebuffer_to_(s,c)
#define _Result_nullonfailure_
#define _Result_zeroonfailure_
#define _Outptr_result_nullonfailure_
#define _Outptr_opt_result_nullonfailure_
#define _Outref_result_nullonfailure_
#define _Outref_
#define _Outref_result_maybenull_
#define _Outref_result_buffer_(s)
#define _Outref_result_bytebuffer_(s)
#define _Outref_result_buffer_to_(s, c)
#define _Outref_result_bytebuffer_to_(s, c)
#define _Outref_result_buffer_all_(s)
#define _Outref_result_bytebuffer_all_(s)
#define _Outref_result_buffer_maybenull_(s)
#define _Outref_result_bytebuffer_maybenull_(s)
#define _Outref_result_buffer_to_maybenull_(s, c)
#define _Outref_result_bytebuffer_to_maybenull_(s,c)
#define _Outref_result_buffer_all_maybenull_(s)
#define _Outref_result_bytebuffer_all_maybenull_(s)
#define _Ret_z_
#define _Ret_writes_(s)
#define _Ret_writes_bytes_(s)
#define _Ret_writes_z_(s)
#define _Ret_writes_to_(s,c)
#define _Ret_writes_maybenull_(s)
#define _Ret_writes_to_maybenull_(s)
#define _Ret_writes_maybenull_z_(s)
#define _Ret_maybenull_
#define _Ret_maybenull_z_
#define _Ret_null_
#define _Ret_notnull_
#define _Ret_writes_bytes_to_
#define _Ret_writes_bytes_maybenull_
#define _Ret_writes_bytes_to_maybenull_
#define _In_range_(low, hi)
#define _Out_range_(low, hi)
#define _Ret_range_(low, hi)
#define _Deref_in_range_(low, hi)
#define _Deref_out_range_(low, hi)
#define _Deref_inout_range_(low, hi)
#define _Field_range_(low, hi)
#define _Pre_equal_to_(expr)
#define _Post_equal_to_(expr)
#define _Struct_size_bytes_(size)

View File

@ -1,24 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "runtime/datatypes.h"
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
namespace nncase {
class NNCASE_API attribute_map {};
} // namespace nncase

View File

@ -90,6 +90,7 @@ public:
virtual std::unique_ptr<section_decompiler> create_decompiler(std::string_view section_name);
protected:
section *find_section(std::string_view section_name);
void merge_to_rdata_section(std::string_view from);
function_call_id function_id(ir::graph *graph);
void set_current_entry_point(std::streampos pos);

View File

@ -39,6 +39,8 @@ public:
}
void emit_abs() { emit_opcode(runtime::nnil_abs); }
void emit_acos() { emit_opcode(runtime::nnil_acos); }
void emit_asin() { emit_opcode(runtime::nnil_asin); }
void emit_ceil() { emit_opcode(runtime::nnil_ceil); }
void emit_cos() { emit_opcode(runtime::nnil_cos); }
void emit_exp() { emit_opcode(runtime::nnil_exp); }
@ -46,6 +48,7 @@ public:
void emit_log() { emit_opcode(runtime::nnil_log); }
void emit_neg() { emit_opcode(runtime::nnil_neg); }
void emit_rsqrt() { emit_opcode(runtime::nnil_rsqrt); }
void emit_sign() { emit_opcode(runtime::nnil_sign); }
void emit_sin() { emit_opcode(runtime::nnil_sin); }
void emit_sqrt() { emit_opcode(runtime::nnil_sqrt); }
void emit_square() { emit_opcode(runtime::nnil_square); }

View File

@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/9/15 16:49:15 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 1/14/2022 2:00:39 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@ -1004,6 +1004,21 @@ struct op_writer<nncase::runtime::stackvm::tensor_convert_op_t>
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_cumsum_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_cumsum_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src);
writer.write(op.axis);
writer.write(op.exclusive);
writer.write(op.reverse);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_dequantize_op_t>
{
@ -1019,6 +1034,22 @@ struct op_writer<nncase::runtime::stackvm::tensor_dequantize_op_t>
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_equal_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_equal_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src1);
writer.write(op.rstride_src1);
writer.write(op.rshape_src2);
writer.write(op.rstride_src2);
writer.write(op.rstride_dest);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_gather_op_t>
{
@ -1053,6 +1084,20 @@ struct op_writer<nncase::runtime::stackvm::tensor_gather_nd_op_t>
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_hardmax_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_hardmax_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src);
writer.write(op.rstride_src);
writer.write(op.axis);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_lut1d_op_t>
{
@ -1068,6 +1113,20 @@ struct op_writer<nncase::runtime::stackvm::tensor_lut1d_op_t>
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_matmul_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_matmul_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(op.rshape_src1);
writer.write(op.rshape_src2);
writer.write(op.fused_clamp_low);
writer.write(op.fused_clamp_high);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_onehot_op_t>
{
@ -1115,6 +1174,36 @@ struct op_writer<nncase::runtime::stackvm::tensor_quantize_op_t>
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_random_normal_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_random_normal_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype_dest));
writer.write(op.rshape_dest);
writer.write(op.mean);
writer.write(op.std);
writer.write(op.seed);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_random_uniform_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_random_uniform_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype_dest));
writer.write(op.rshape_dest);
writer.write(op.low);
writer.write(op.high);
writer.write(op.seed);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_reduce_op_t>
{
@ -1132,6 +1221,40 @@ struct op_writer<nncase::runtime::stackvm::tensor_reduce_op_t>
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_reduce_arg_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_reduce_arg_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype_src));
writer.write(op.rshape_src);
writer.write(op.rstride_src);
writer.write(static_cast<uint8_t>(op.datatype_dest));
writer.write(op.rstride_dest);
writer.write(static_cast<uint8_t>(op.reduce_arg_op));
writer.write(op.rshape_axis);
writer.write(op.keep_dims);
writer.write(op.select_last_idx);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_reduce_prod_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_reduce_prod_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(op.rshape_src);
writer.write(op.rstride_src);
writer.write(op.rstride_dest);
writer.write(op.rshape_axes);
writer.write(op.keep_dims);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_reduce_window2d_op_t>
{
@ -1172,6 +1295,35 @@ struct op_writer<nncase::runtime::stackvm::tensor_resize_image_op_t>
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_roi_align_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_roi_align_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src);
writer.write(op.rshape_dest);
writer.write(static_cast<uint8_t>(op.mode));
writer.write(op.spatial_scale);
writer.write(op.sampling_ratio);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_sigmoid_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_sigmoid_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src);
writer.write(op.rstride_src);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_slice_op_t>
{
@ -1189,6 +1341,59 @@ struct op_writer<nncase::runtime::stackvm::tensor_slice_op_t>
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_ternary_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_ternary_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src1);
writer.write(op.rstride_src1);
writer.write(op.rshape_src2);
writer.write(op.rstride_src2);
writer.write(op.rshape_src3);
writer.write(op.rstride_src3);
writer.write(op.rstride_dest);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_topk_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_topk_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src);
writer.write(op.rstride_src);
writer.write(op.rshape_dest1);
writer.write(op.rstride_dest1);
writer.write(op.rshape_dest2);
writer.write(op.rstride_dest2);
writer.write(op.k);
writer.write(op.axis);
writer.write(op.largest);
writer.write(op.sorted);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_trilu_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_trilu_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src);
writer.write(op.upper);
writer.write(op.k);
}
};
template <>
struct op_writer<nncase::runtime::stackvm::tensor_unary_op_t>
{
@ -1218,7 +1423,7 @@ struct op_writer<nncase::runtime::stackvm::tensor_transpose_op_t>
writer.write(op.rshape_perm);
}
};
class NNCASE_API op_builder
{
public:
@ -1327,17 +1532,30 @@ public:
void tensor_conv2d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_kernel, uint8_t rstride_kernel, uint8_t rstride_bias, uint8_t rstride_dest, uint16_t groups, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
void tensor_copy_(datatype_t datatype, uint8_t rshape, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_convert_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_cumsum_(datatype_t datatype, uint8_t rshape_src, int32_t axis, bool exclusive, bool reverse);
void tensor_dequantize_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_equal_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rstride_dest);
void tensor_gather_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t axis);
void tensor_gather_nd_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t batch_dims);
void tensor_hardmax_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, int32_t axis);
void tensor_lut1d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint16_t table_len);
void tensor_matmul_(uint8_t rshape_src1, uint8_t rshape_src2, float fused_clamp_low, float fused_clamp_high);
void tensor_onehot_(datatype_t datatype, uint8_t rshape_indices, uint8_t rshape_dest, uint8_t rstride_dest, uint8_t axis, onehot_mode_t onehot_mode);
void tensor_pad_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rpaddings, pad_mode_t pad_mode);
void tensor_quantize_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_random_normal_(datatype_t datatype_dest, uint8_t rshape_dest, float mean, float std, float seed);
void tensor_random_uniform_(datatype_t datatype_dest, uint8_t rshape_dest, float low, float high, float seed);
void tensor_reduce_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, reduce_op_t reduce_op, uint8_t rshape_axis, bool keep_dims);
void tensor_reduce_arg_(datatype_t datatype_src, uint8_t rshape_src, uint8_t rstride_src, datatype_t datatype_dest, uint8_t rstride_dest, reduce_arg_op_t reduce_arg_op, uint8_t rshape_axis, bool keep_dims, bool select_last_idx);
void tensor_reduce_prod_(uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_axes, bool keep_dims);
void tensor_reduce_window2d_(datatype_t datatype, reduce_op_t reduce_op, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint16_t filter_h, uint16_t filter_w, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
void tensor_resize_image_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, bool align_corners, bool half_pixel_centers, image_resize_mode_t image_resize_mode);
void tensor_roi_align_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, roi_align_mode_t mode, float spatial_scale, int64_t sampling_ratio);
void tensor_sigmoid_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src);
void tensor_slice_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rbegins, uint8_t rends, uint8_t rstrides);
void tensor_ternary_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_src3, uint8_t rstride_src3, uint8_t rstride_dest);
void tensor_topk_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_dest1, uint8_t rstride_dest1, uint8_t rshape_dest2, uint8_t rstride_dest2, int64_t k, int32_t axis, bool largest, bool sorted);
void tensor_trilu_(datatype_t datatype, uint8_t rshape_src, bool upper, int64_t k);
void tensor_unary_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, unary_op_t unary_op);
void tensor_transpose_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_perm);

View File

@ -36,21 +36,31 @@ struct compile_options
{
bool dump_ir;
bool dump_asm;
bool dump_quant_error;
bool dump_import_op_range;
bool is_fpga;
bool use_dataset_as_input_stat = false;
bool benchmark_only = false;
bool preprocess = false;
bool swapRB = false;
std::string target;
std::filesystem::path dump_dir;
std::string input_type = "default";
std::string output_type = "float32";
std::string quant_type = "uint8";
std::vector<float> mean { 0.f, 0.f, 0.f };
std::vector<float> std { 1.f, 1.f, 1.f };
std::vector<float> input_range { 0.f, 1.f };
float letterbox_value = 0.f;
std::vector<int32_t> input_shape {};
std::string w_quant_type = "uint8";
bool use_mse_quant_w = false;
std::string input_layout = "NCHW";
std::string output_layout = "NCHW";
};
// Options controlling model import (tflite/onnx front ends).
struct import_options
{
// Layout the imported model's inputs are expressed in; defaults to NCHW.
std::string input_layout = "NCHW";
// Layout the imported model's outputs are expressed in; defaults to NCHW.
std::string output_layout = "NCHW";
// Names of output arrays to keep. Non-owning view: the caller must keep
// the underlying strings alive for the duration of the import call.
std::span<const std::string> output_arrays;
};
@ -58,9 +68,6 @@ struct ptq_options_base
{
std::string calibrate_method = "no_clip";
std::function<void(size_t cnt, size_t total)> progress;
float input_mean = 0.f;
float input_std = 1.f;
};
struct ptq_dataset_options : ptq_options_base
@ -75,6 +82,23 @@ struct ptq_tensor_options : ptq_options_base
size_t samples_count;
};
// Common options for the dump-range passes (shared by dataset/tensor variants).
struct dump_range_options_base
{
// Calibration method used while collecting ranges; defaults to "no_clip".
std::string calibrate_method = "no_clip";
// Optional progress callback, invoked with (processed count, total count).
std::function<void(size_t cnt, size_t total)> progress;
};
// Dump-range options when samples come from an on-disk dataset.
struct dump_range_dataset_options : dump_range_options_base
{
// Path to the calibration dataset.
std::filesystem::path dataset;
// Dataset format string interpreted by the dataset loader.
// NOTE(review): accepted values not visible here — confirm against the loader.
std::string dataset_format;
};
// Dump-range options when samples are supplied as an in-memory buffer.
struct dump_range_tensor_options : dump_range_options_base
{
// Raw bytes of the calibration samples; presumably samples_count samples
// concatenated in one buffer — TODO confirm layout against the consumer.
std::vector<uint8_t> tensor_data;
// Number of samples contained in tensor_data.
size_t samples_count;
};
class NNCASE_API compiler
{
public:
@ -83,8 +107,11 @@ public:
virtual ~compiler();
virtual void import_tflite(std::span<const uint8_t> model, const import_options &options) = 0;
virtual void import_onnx(std::span<const uint8_t> model, const import_options &options) = 0;
virtual void import_caffe(std::span<const uint8_t> model, std::span<const uint8_t> prototxt) = 0;
virtual void use_ptq(ptq_dataset_options options) = 0;
virtual void use_ptq(ptq_tensor_options options) = 0;
virtual void dump_range_options(dump_range_dataset_options options) = 0;
virtual void dump_range_options(dump_range_tensor_options options) = 0;
virtual ir::graph &graph(uint32_t stage) = 0;
virtual nncase::target &target() = 0;
virtual void compile() = 0;

View File

@ -116,6 +116,7 @@ protected:
virtual void process(const std::vector<uint8_t> &src, float *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
virtual void process(const std::vector<uint8_t> &src, uint8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
virtual void process(const std::vector<uint8_t> &src, int8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
virtual bool do_normalize() const noexcept { return true; }
private:
template <class T>
@ -130,6 +131,7 @@ private:
process(file, batch.data(), batch.shape(), input_layout_);
std::span<const std::filesystem::path> filenames(filenames_.data() + start, filenames_.data() + from);
return data_batch<T> { std::move(batch), filenames };
}
@ -162,5 +164,6 @@ protected:
void process(const std::vector<uint8_t> &src, float *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
void process(const std::vector<uint8_t> &src, uint8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
void process(const std::vector<uint8_t> &src, int8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
bool do_normalize() const noexcept override { return false; }
};
}

View File

@ -0,0 +1,293 @@
/**
* @file ops.h
* @date 2021-09-27
*
* @copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#pragma once
#include <nncase/runtime/runtime_tensor.h>
#ifndef NNCASE_FUNCTIONAL_IMPL_PLATFORM_HEADER
#include <nncase/functional/ops.platform.h>
#else
#include NNCASE_FUNCTIONAL_IMPL_PLATFORM_HEADER
#endif
namespace nncase::F
{
/**
 * @brief unary_square — applies the unary_square op to the input tensor
 * (dispatched through the platform implementation impl::unary).
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> square(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_square);
}
/**
 * @brief unary_sqrt — applies the unary_sqrt op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> sqrt(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_sqrt);
}
/**
 * @brief unary_log — applies the unary_log op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> log(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_log);
}
/**
 * @brief unary_exp — applies the unary_exp op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> exp(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_exp);
}
/**
 * @brief unary_sin — applies the unary_sin op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> sin(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_sin);
}
/**
 * @brief unary_cos — applies the unary_cos op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> cos(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_cos);
}
/**
 * @brief unary_round — applies the unary_round op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> round(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_round);
}
/**
 * @brief unary_floor — applies the unary_floor op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> floor(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_floor);
}
/**
 * @brief unary_ceil — applies the unary_ceil op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> ceil(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_ceil);
}
/**
 * @brief unary_abs — applies the unary_abs op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> abs(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_abs);
}
/**
 * @brief unary_neg — applies the unary_neg op via impl::unary.
 *
 * @param input source runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime::runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> neg(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::unary(input, dtype, unary_neg);
}
/**
 * @brief binary add — applies the binary_add op via impl::binary.
 * @note original note read "temporary not support"; some input combinations
 * may be unimplemented — confirm against impl::binary.
 * @param input_a left-hand runtime_tensor
 * @param input_b right-hand runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> add(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
{
return impl::binary(input_a, input_b, dtype, binary_add);
}
/**
 * @brief binary sub — applies the binary_sub op via impl::binary.
 * @note original note read "temporary not support"; some input combinations
 * may be unimplemented — confirm against impl::binary.
 * @param input_a left-hand runtime_tensor
 * @param input_b right-hand runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> sub(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
{
return impl::binary(input_a, input_b, dtype, binary_sub);
}
/**
 * @brief binary mul — applies the binary_mul op via impl::binary.
 * @note original note read "temporary not support"; some input combinations
 * may be unimplemented — confirm against impl::binary.
 * @param input_a left-hand runtime_tensor
 * @param input_b right-hand runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> mul(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
{
return impl::binary(input_a, input_b, dtype, binary_mul);
}
/**
 * @brief binary div — applies the binary_div op via impl::binary.
 * @note original note read "temporary not support"; some input combinations
 * may be unimplemented — confirm against impl::binary.
 * @param input_a left-hand runtime_tensor
 * @param input_b right-hand runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> div(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
{
return impl::binary(input_a, input_b, dtype, binary_div);
}
/**
 * @brief binary min — applies the binary_min op via impl::binary.
 * @note original note read "temporary not support"; some input combinations
 * may be unimplemented — confirm against impl::binary.
 * @param input_a left-hand runtime_tensor
 * @param input_b right-hand runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> min(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
{
return impl::binary(input_a, input_b, dtype, binary_min);
}
/**
 * @brief binary max — applies the binary_max op via impl::binary.
 * @note original note read "temporary not support"; some input combinations
 * may be unimplemented — confirm against impl::binary.
 * @param input_a left-hand runtime_tensor
 * @param input_b right-hand runtime_tensor
 * @param dtype datatype of the output tensor
 * @return result<runtime_tensor> output tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> max(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype) noexcept
{
return impl::binary(input_a, input_b, dtype, binary_max);
}
/**
 * @brief quantize a float or bfloat tensor to uint8 or int8.
 *
 * @param input source runtime_tensor (float/bfloat)
 * @param dtype target datatype of the output tensor (uint8/int8)
 * @return result<runtime_tensor> quantized tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> quantize(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::quantize(input, dtype);
}
/**
 * @brief dequantize a uint8 or int8 tensor to float or bfloat.
 *
 * @param input source runtime_tensor (uint8/int8)
 * @param dtype target datatype of the output tensor (float/bfloat)
 * @return result<runtime_tensor> dequantized tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> dequantize(runtime::runtime_tensor &input, datatype_t dtype) noexcept
{
return impl::dequantize(input, dtype);
}
/**
 * @brief given bboxes, crop new tensors from the current tensor and resize
 * them to (out_h, out_w).
 *
 * @param input source runtime_tensor
 * @param bbox runtime tensor, shape should be [1,1,roi_amounts,4], layout should be [y0, x0, y1, x1]
 * @param out_h output tensor height
 * @param out_w output tensor width
 * @param resize_mode resize mode used for the resampling
 * @param align_corners coordinate-transform flag forwarded to impl::crop — TODO confirm exact semantics
 * @param half_pixel_centers coordinate-transform flag forwarded to impl::crop — TODO confirm exact semantics
 * @return result<runtime_tensor> cropped tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> crop(runtime::runtime_tensor &input, runtime::runtime_tensor &bbox, size_t out_h, size_t out_w, image_resize_mode_t resize_mode, bool align_corners, bool half_pixel_centers) noexcept
{
return impl::crop(input, bbox, out_h, out_w, resize_mode, align_corners, half_pixel_centers);
}
/**
 * @brief resize the tensor to a new height and width.
 *
 * @param input source runtime_tensor
 * @param out_h output tensor height
 * @param out_w output tensor width
 * @param resize_mode resize mode used for the resampling
 * @param align_corners coordinate-transform flag forwarded to impl::resize — TODO confirm exact semantics
 * @param half_pixel_centers coordinate-transform flag forwarded to impl::resize — TODO confirm exact semantics
 * @return result<runtime_tensor> resized tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> resize(runtime::runtime_tensor &input, size_t out_h, size_t out_w, image_resize_mode_t resize_mode, bool align_corners, bool half_pixel_centers) noexcept
{
return impl::resize(input, out_h, out_w, resize_mode, align_corners, half_pixel_centers);
}
/**
 * @brief pad values around the input tensor.
 * @note original note read "temporary not support"; confirm support status
 * against impl::pad.
 * @param input source runtime_tensor
 * @param paddings vector of padding params, ordered from last dim to first.
 * e.g. vector [ {2,3}, {1,3} ] means pad {2,3} in the last dim and {1,3} in
 * the second-to-last dim
 * @param pad_mode padding mode
 * @param fill_v constant fill value
 * @return result<runtime_tensor> padded tensor on success, error otherwise
 */
NNCASE_API inline result<runtime::runtime_tensor> pad(runtime::runtime_tensor &input, runtime_paddings_t &paddings, pad_mode_t pad_mode, float fill_v) noexcept
{
return impl::pad(input, paddings, pad_mode, fill_v);
}
}

View File

@ -0,0 +1,37 @@
/**
* @file ops.platform.h
* @copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#pragma once
#include <nncase/runtime/runtime_tensor.h>
// Platform-specific kernels backing the public nncase::F wrappers (ops.h).
// Each function returns a result<> holding the output tensor or an error code.
namespace nncase::F::impl
{
// Apply a unary op (op_type) to input, producing a tensor of datatype dtype.
result<runtime::runtime_tensor> unary(runtime::runtime_tensor &input, datatype_t dtype, unary_op_t op_type) noexcept;
// Apply a binary op (op_type) to input_a/input_b, producing a tensor of datatype dtype.
result<runtime::runtime_tensor> binary(runtime::runtime_tensor &input_a, runtime::runtime_tensor &input_b, datatype_t dtype, binary_op_t op_type) noexcept;
// Quantize input to the integral datatype dtype.
result<runtime::runtime_tensor> quantize(runtime::runtime_tensor &input, datatype_t dtype) noexcept;
// Dequantize input to the floating datatype dtype.
result<runtime::runtime_tensor> dequantize(runtime::runtime_tensor &input, datatype_t dtype) noexcept;
// Crop regions given by bbox out of input and resize them to (out_h, out_w).
result<runtime::runtime_tensor> crop(runtime::runtime_tensor &input, runtime::runtime_tensor &bbox, size_t out_h, size_t out_w, image_resize_mode_t resize_mode, bool align_corners, bool half_pixel_centers) noexcept;
// Resize input to (out_h, out_w) with the given resize mode and flags.
result<runtime::runtime_tensor> resize(runtime::runtime_tensor &input, size_t out_h, size_t out_w, image_resize_mode_t resize_mode, bool align_corners, bool half_pixel_centers) noexcept;
// Pad input according to paddings/pad_mode, filling with fill_v where constant.
result<runtime::runtime_tensor> pad(runtime::runtime_tensor &input, runtime_paddings_t &paddings, pad_mode_t pad_mode, float fill_v) noexcept;
}

View File

@ -14,7 +14,7 @@
*/
#pragma once
#include <filesystem>
#include <nncase/ir/module.h>
#include <nncase/ir/graph.h>
#include <span>
#include <unordered_map>
#include <vector>
@ -23,11 +23,10 @@ namespace nncase::importer
{
struct import_options
{
std::string input_layout = "NCHW";
std::string output_layout = "NCHW";
std::span<const std::string> output_arrays;
};
NNCASE_API ir::module_t import_tflite(std::span<const uint8_t> model, const import_options &options);
//ir::module import_onnx(std::span<const uint8_t> model, const import_options &options);
void import_tflite(ir::graph &graph, std::span<const uint8_t> model, const import_options &options, std::string &real_inlayout, std::string &real_outlayout);
void import_onnx(ir::graph &graph, std::span<const uint8_t> model, const import_options &options, std::string &real_inlayout, std::string &real_outlayout);
void import_caffe(ir::graph &graph, std::span<const uint8_t> model, std::span<const uint8_t> prototxt, std::string &real_inlayout, std::string &real_outlayout);
}

View File

@ -16,6 +16,7 @@
#include <nncase/ir/debug.h>
#include <nncase/ir/ir_types.h>
#include <nncase/ir/ops/convert.h>
#include <nncase/runtime/debug.h>
namespace nncase::importer
{

View File

@ -1,48 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "function.h"
#include "op.h"
#include <variant>
namespace nncase::ir {
/** @brief Call node: an expression that applies a target expression to a list
 * of argument expressions. */
class NNCASE_API call_node : public expr_node {
DEFINE_OBJECT_KIND(expr_node, object_call)
public:
call_node(expr target, std::vector<expr> arguments);
/** @brief Get the arguments of the call expression */
std::span<const expr> arguments() const noexcept { return arguments_; }
/** @brief Get the mutable arguments of the call expression */
std::span<expr> arguments() noexcept { return arguments_; }
/** @brief Get the target of the call expression */
const expr &target() const noexcept { return target_; }
/** @brief Set the target of the call expression */
void target(expr value) noexcept { target_ = std::move(value); }
private:
expr target_;
std::vector<expr> arguments_;
};
/** @brief Call expression (handle to a call_node) */
class call : public object_t<call_node> {
public:
using object_t::object_t;
/** @brief Construct a call of target with the given arguments */
NNCASE_API call(expr target, std::vector<expr> arguments);
};
} // namespace nncase::ir

View File

@ -0,0 +1,91 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "ir_types.h"
#include <optional>
#include <span>
#include <string>
#include <vector>
#include <xtensor/xshape.hpp>
namespace nncase::ir
{
class node;
class output_connector;
// Common state shared by input/output connectors of a graph node:
// owning node, name, element datatype, shape, and attribute flags.
class NNCASE_API base_connector
{
public:
// Construct a connector attached to `owner`; name and shape are forwarded.
template <class TName, class TShape>
base_connector(node &owner, TName &&name, datatype_t type, TShape &&shape)
: owner_(owner), name_(std::forward<TName>(name)), type_(type), shape_(std::forward<TShape>(shape))
{
}
// Non-copyable (holds a reference to its owning node), but movable.
base_connector(base_connector &) = delete;
base_connector(base_connector &&) = default;
// Node this connector belongs to.
node &owner() const noexcept { return owner_; }
const std::string &name() const noexcept { return name_; }
datatype_t type() const noexcept { return type_; }
const shape_t &shape() const noexcept { return shape_; }
connector_attributes attributes() const noexcept { return attributes_; }
void attributes(connector_attributes value) noexcept { attributes_ = value; }
private:
node &owner_;
std::string name_;
datatype_t type_;
shape_t shape_;
// Defaults to no attributes set.
connector_attributes attributes_ = cnctr_attr_none;
};
// Input side of a graph edge: holds at most one connection to an output.
class NNCASE_API input_connector : public base_connector
{
public:
using base_connector::base_connector;
// Connected output, or nullptr if unconnected.
output_connector *connection() const noexcept { return connection_; }
void connect(output_connector &connector);
void clear_connection();
private:
output_connector *connection_ = nullptr;
};
// Output side of a graph edge: may fan out to multiple inputs and carries
// a memory location for its produced buffer.
class NNCASE_API output_connector : public base_connector
{
public:
// memory_location defaults to mem_data.
template <class TName, class TShape>
output_connector(node &owner, TName &&name, datatype_t type, TShape &&shape, memory_location_t memory_location = mem_data)
: base_connector(owner, std::forward<TName>(name), type, std::forward<TShape>(shape)), memory_location_(memory_location)
{
}
// All inputs currently connected to this output.
std::span<input_connector *const> connections() const noexcept { return connections_; }
void connect(input_connector &connector);
void disconnect(input_connector &connector);
void clear_connections();
memory_location_t memory_location() const noexcept { return memory_location_; }
void memory_location(memory_location_t value) noexcept { memory_location_ = value; }
private:
std::vector<input_connector *> connections_;
memory_location_t memory_location_;
};
}

View File

@ -1,61 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "expr.h"
#include "type.h"
namespace nncase::ir {
/** @brief Constant node: an expression holding a typed, owned byte buffer. */
class NNCASE_API constant_node : public expr_node {
DEFINE_OBJECT_KIND(expr_node, object_constant);
public:
constant_node(type value_type, std::vector<std::byte> data);
/** @brief Get the type of the constant expression */
const type &value_type() const noexcept { return value_type_; }
/** @brief Get the mutable type of the constant expression */
type &value_type() noexcept { return value_type_; }
/** @brief Get the data of the constant expression */
std::span<const std::byte> data() const noexcept { return data_; }
/** @brief Get the mutable data of the constant expression */
std::span<std::byte> data() noexcept { return data_; }
private:
type value_type_;
std::vector<std::byte> data_;
};
/** @brief Constant expression (handle to a constant_node) */
class constant : public object_t<constant_node> {
public:
using object_t::object_t;
/** @brief Construct a constant taking ownership of the byte buffer */
NNCASE_API constant(type value_type, std::vector<std::byte> data);
/** @brief Construct a constant by copying the given bytes */
constant(type value_type, std::span<const std::byte> data)
: constant(std::move(value_type),
std::vector<std::byte>(data.begin(), data.end())) {}
/** @brief Construct a constant from a span of values, reinterpreted as bytes */
template <class T>
constant(type value_type, std::span<const T> data)
: constant(std::move(value_type), std::as_bytes(data)) {}
/** @brief Construct a scalar constant; the datatype is deduced from TScalar */
template <class TScalar>
constant(TScalar scalar)
: constant(prim_type(to_datatype<TScalar>()),
std::span<const TScalar>(&scalar, 1)) {}
};
} // namespace nncase::ir

View File

@ -13,37 +13,28 @@
* limitations under the License.
*/
#pragma once
#include "graph.h"
#include "ir_types.h"
#include "module.h"
#include <filesystem>
#include <map>
#include <string>
namespace nncase {
constexpr std::string_view datatype_names(datatype_t dt) {
switch (dt) {
#define DEFINE_DATATYPE(id, t, name, value) \
case dt_##id: \
return #name;
#include <nncase/runtime/datatypes.def>
#undef DEFINE_DATATYPE
default:
throw std::invalid_argument("invalid datatype");
}
namespace nncase
{
inline std::string to_string(const padding &value)
{
return "{" + std::to_string(value.before) + ", " + std::to_string(value.after) + "}";
}
inline std::string to_string(const padding &value) {
return "{" + std::to_string(value.before) + ", " +
std::to_string(value.after) + "}";
inline std::string to_string(const quant_param_t &value)
{
return "(q - " + std::to_string(value.zero_point) + ") * " + std::to_string(value.scale);
}
inline std::string to_string(const quant_param_t &value) {
return "(q - " + std::to_string(value.zero_point) + ") * " +
std::to_string(value.scale);
}
inline std::string to_string(memory_location_t location) {
switch (location) {
inline std::string to_string(memory_location_t location)
{
switch (location)
{
case mem_input:
return "input";
case mem_output:
@ -58,53 +49,52 @@ inline std::string to_string(memory_location_t location) {
}
template <typename Tv, typename T>
static size_t index_of(const Tv &v, const T &e) {
for (size_t i = 0; i < v.size(); i++) {
if (&v[i] == &e) {
static size_t index_of(const Tv &v, const T &e)
{
for (size_t i = 0; i < v.size(); i++)
{
if (&v[i] == &e)
{
return i;
}
}
return SIZE_MAX;
}
namespace ir {
inline std::string to_string(const dim_t &dim) {
return dim.is_fixed() ? std::to_string(dim.value) : "?";
}
inline std::string to_string(const shape_t &shape) {
std::string str{'['};
if (shape.is_invalid()) {
str += "invalid";
} else if (shape.is_unranked()) {
str += '*';
} else {
for (size_t i = 0; i < shape.rank(); i++) {
if (i != 0) {
namespace ir
{
inline std::string to_string(const shape_t &shape)
{
std::string str { '[' };
for (size_t i = 0; i < shape.size(); i++)
{
if (i != 0)
{
str.append(",");
}
str.append(to_string(shape[i]));
str.append(std::to_string(shape[i]));
}
str += ']';
return str;
}
str += ']';
return str;
}
inline std::string to_string(const axis_t &axis) {
std::string str{'['};
for (size_t i = 0; i < axis.size(); i++) {
if (i != 0) {
str.append(",");
inline std::string to_string(const axis_t &axis)
{
std::string str { '[' };
for (size_t i = 0; i < axis.size(); i++)
{
if (i != 0)
{
str.append(",");
}
str.append(std::to_string(axis[i]));
}
str.append(std::to_string(axis[i]));
str += ']';
return str;
}
str += ']';
return str;
NNCASE_API void dump_graph(const ir::graph &src_graph, const std::filesystem::path &dst_path);
}
}
NNCASE_API void dump_function(const ir::function &func,
const std::filesystem::path &dst_path);
} // namespace ir
} // namespace nncase

View File

@ -24,6 +24,13 @@ class target;
namespace nncase::ir
{
enum class eval_step
{
after_import,
after_calib,
after_quant
};
class module_evaluate_context;
class model_evaluate_context;
@ -53,7 +60,7 @@ public:
module_evaluate_context &module() const noexcept { return mod_eval_; }
void evaluate();
void evaluate(eval_step step, size_t stage, bool record_output_buffers);
private:
const schedule::function_schedule_result &sched_;
@ -129,7 +136,7 @@ public:
void end_sample();
void end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress);
void evaluate();
void evaluate(eval_step step, size_t stage, bool record_output_buffers);
private:
const schedule::model_schedule_result &sched_;

View File

@ -26,7 +26,7 @@ public:
evaluator(evaluator &&) = default;
void enable_ptq(target &target, ir::calibrate_method calib_method);
void evaluate();
void evaluate(eval_step step = nncase::ir::eval_step::after_import, size_t stage = 0, bool record_output_buffers = false);
ir::quantizer *quantizer(const module_type_t &module_type);
void begin_collect_distribution();

View File

@ -1,45 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../object.h"
#include "type.h"
#include <range/v3/range/concepts.hpp>
namespace nncase::ir {
/** @brief Expression node: base class of all IR expressions, carrying the
 * type assigned by type checking. */
class NNCASE_API expr_node : public object_node {
DEFINE_OBJECT_KIND(object_node, object_expr)
public:
expr_node();
/** @brief Get the checked type of the expression */
const type &checked_type() const noexcept { return checked_type_; }
/** @brief Get the mutable checked type of the expression */
type &checked_type() noexcept { return checked_type_; }
/** @brief Set the checked type of the expression */
void checked_type(type value) noexcept { checked_type_ = std::move(value); }
private:
type checked_type_;
};
/** @brief Handle type for expression nodes */
using expr = object_t<expr_node>;
/** @brief Concept satisfied by any object handle whose node type is
 * expr_node or derives from it */
template <class T>
concept Expr = Object<T> &&
(concepts::same_as<expr_node, typename T::node_type> ||
concepts::derived_from<typename T::node_type, expr_node>);
} // namespace nncase::ir

View File

@ -1,65 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "expr.h"
#include "var.h"
namespace nncase::ir {
/** @brief Function node */
class NNCASE_API function_node : public expr_node {
DEFINE_OBJECT_KIND(expr_node, object_function)
public:
function_node(std::string name, std::vector<var> parameters, expr body);
/** @brief Get the name of the function expression */
const std::string &name() const noexcept { return name_; }
/** @brief Get the mutable name of the function expression */
std::string &name() noexcept { return name_; }
/** @brief Get the parameters of the function expression */
std::span<const var> parameters() const noexcept { return parameters_; }
/** @brief Get the body of the function expression */
const expr &body() const noexcept { return body_; }
/** @brief Set the body of the function expression */
void body(expr value) noexcept { body_ = std::move(value); }
private:
std::string name_;
std::vector<var> parameters_;
expr body_;
};
/** @brief Function expression */
/** @brief Function expression */
class function : public object_t<function_node> {
  public:
    using object_t::object_t;

    /** @brief Construct a function expression with an auto-generated name
     * @param[in] parameters The parameters of the function
     * @param[in] body The body of the function
     */
    NNCASE_API function(std::vector<var> parameters, expr body);

    /** @brief Construct a named function expression
     * @param[in] name The name of the function
     * @param[in] parameters The parameters of the function
     * @param[in] body The body of the function
     */
    NNCASE_API function(std::string name, std::vector<var> parameters,
                        expr body);
};
} // namespace nncase::ir

View File

@ -1,37 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../object.h"
#include "call.h"
#include "constant.h"
namespace nncase::ir::F {
namespace detail {
/** @brief Concept: T is a C++ scalar that maps to an nncase datatype */
template <class T> concept Scalar = requires {
    nncase::detail::cpp_type_to_datatype<T>::type;
};
} // namespace detail

/** @brief Functional expression
 *
 * Wraps expr and additionally accepts arithmetic scalars, which are wrapped
 * into single-element constant expressions — this lets the functional API
 * take literals directly (e.g. F::add(x, 1.f)). */
class fexpr : public expr {
  public:
    /** @brief Move-construct from any expression type */
    template <Expr T> fexpr(T &&other) noexcept : expr(std::move(other)) {}
    /** @brief Copy-construct from any expression type */
    template <Expr T> fexpr(const T &other) noexcept : expr(other) {}
    /** @brief Implicitly wrap a scalar as a 1-element constant expression */
    template <detail::Scalar T>
    fexpr(T scalar)
        : expr(constant(tensor_type(to_datatype<T>()),
                        std::span<const T>(&scalar, 1))) {}
};
} // namespace nncase::ir::F

86
include/nncase/ir/graph.h Normal file
View File

@ -0,0 +1,86 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "node.h"
#include "placeholders.h"
#include <memory>
#include <unordered_map>
#include <vector>
namespace nncase::ir
{
class graph;

/** @brief Result of graph::split_subgraph */
struct split_graph_result
{
    std::unique_ptr<graph> subgraph;
    // NOTE(review): mapping semantics inferred from names — connectors on the
    // parent-graph side that each new input/output node was cut from; confirm
    // against graph::split_subgraph's implementation.
    std::unordered_map<input_node *, output_connector *> inputs;
    std::unordered_map<output_node *, std::vector<input_connector *>> outputs;
};

/** @brief Dataflow graph: owns its nodes and any nested subgraphs */
class NNCASE_API graph
{
public:
    graph() noexcept;
    explicit graph(const module_type_t &module_type) noexcept
        : module_type_(module_type) { }

    // Graphs own their nodes by unique_ptr; copying/moving is disallowed.
    graph(graph &) = delete;
    graph(graph &&) = delete;

    /** @brief Graph name */
    const std::string &name() const noexcept { return name_; }
    // presumably name() with unsafe characters escaped — confirm in graph.cpp
    std::string escaped_name() const noexcept;
    void name(std::string value) { name_ = std::move(value); }

    /** @brief Target module type this graph belongs to */
    const module_type_t &module_type() const noexcept { return module_type_; }
    void set_module_type(module_type_t type) { this->module_type_ = type; }

    std::span<std::unique_ptr<node>> nodes() noexcept { return nodes_; }
    std::span<input_node *> inputs() noexcept { return inputs_; }
    std::span<output_node *> outputs() noexcept { return outputs_; }
    std::span<std::unique_ptr<graph>> subgraphs() noexcept { return subgraphs_; }
    /** @brief This graph plus its transitively reachable subgraphs */
    std::vector<graph *> reachable_graphs() noexcept;

    std::span<std::unique_ptr<node> const> nodes() const noexcept { return nodes_; }
    std::span<input_node *const> inputs() const noexcept { return inputs_; }
    std::span<output_node *const> outputs() const noexcept { return outputs_; }
    std::span<std::unique_ptr<graph> const> subgraphs() const noexcept { return subgraphs_; }

    /** @brief Construct a node of type T in place and take ownership;
     * input_node/output_node instances are also registered in
     * inputs()/outputs() */
    template <class T, class... TArgs>
    T *emplace(TArgs &&...args)
    {
        auto node = static_cast<T *>(nodes_.emplace_back(new T(std::forward<TArgs>(args)...)).get());
        if constexpr (std::is_same_v<T, input_node>)
            inputs_.emplace_back(node);
        else if constexpr (std::is_same_v<T, output_node>)
            outputs_.emplace_back(node);
        return node;
    }

    void assign_names();
    void dce(); // dead code elimination
    void cse(); // common subexpression elimination
    void merge_module_regions();
    /** @brief Extract the given nodes into a new subgraph */
    split_graph_result split_subgraph(std::span<node *const> nodes);
    graph &add_subgraph(std::unique_ptr<graph> subgraph);

private:
    std::string name_;
    module_type_t module_type_;
    std::vector<std::unique_ptr<node>> nodes_;
    std::vector<std::unique_ptr<graph>> subgraphs_;
    std::vector<input_node *> inputs_;
    std::vector<output_node *> outputs_;
};
}

View File

@ -13,18 +13,18 @@
* limitations under the License.
*/
#pragma once
#include "shape.h"
#include "type.h"
#include <nncase/runtime/datatypes.h>
#include <span>
#include <type_traits>
#include <xtensor/xshape.hpp>
namespace nncase::ir {
class op_node;
namespace nncase::ir
{
using shape_t = xt::dynamic_shape<std::size_t>;
using axis_t = xt::dynamic_shape<int32_t>;
enum node_attributes {
enum node_attributes
{
node_attr_none = 0,
node_attr_action = 1,
node_attr_need_quantize = 2,
@ -33,44 +33,22 @@ enum node_attributes {
node_attr_skip_constant_folding = 16
};
enum connector_attributes {
enum connector_attributes
{
cnctr_attr_none = 0,
cnctr_attr_need_quantize = 1,
cnctr_attr_no_layout_strides = 2,
cnctr_attr_no_buffer_fusion = 4,
cnctr_attr_no_dummy_for_benchmark = 8
cnctr_attr_buffer_slice = 8,
cnctr_attr_no_dummy_for_benchmark = 16
};
DEFINE_ENUM_BITMASK_OPERATORS(node_attributes)
DEFINE_ENUM_BITMASK_OPERATORS(connector_attributes)
class NNCASE_API connector_info {
public:
connector_info(op_node &owner, std::string name, size_t index)
: owner_(owner), name_(std::move(name)), index_(index) {}
connector_info(const connector_info &) = delete;
connector_info(connector_info &&) = default;
connector_info &operator=(const connector_info &) = delete;
op_node &owner() const noexcept { return owner_; }
const std::string &name() const noexcept { return name_; }
size_t index() const noexcept { return index_; }
connector_attributes attributes() const noexcept { return attributes_; }
void attributes(connector_attributes value) noexcept {
attributes_ = value;
}
private:
op_node &owner_;
std::string name_;
size_t index_;
connector_attributes attributes_ = cnctr_attr_none;
};
template <class T, class = std::enable_if_t<std::is_pointer_v<T>>>
std::vector<std::decay_t<T>> dup(std::span<T> source) {
return {source.begin(), source.end()};
std::vector<std::decay_t<T>> dup(std::span<T> source)
{
return { source.begin(), source.end() };
}
}
} // namespace nncase::ir

View File

@ -1,53 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
namespace nncase::ir::math {
/** @brief Binary operator node (element-wise, two operands) */
class NNCASE_API binary_node : public op_node {
    DEFINE_OBJECT_KIND(op_node, op_math_binary)
  public:
    binary_node(binary_op_t binary_op);

    /** @brief Get the binary opcode of the binary expression */
    binary_op_t binary_op() const noexcept { return binary_op_; }
    /** @brief Set the binary opcode of the binary expression */
    void binary_op(binary_op_t value) noexcept { binary_op_ = value; }
    /** @brief Get the lhs of the binary expression (parameter 0) */
    const connector_info &lhs() const noexcept { return parameter_at(0); }
    /** @brief Get the rhs of the binary expression (parameter 1) */
    const connector_info &rhs() const noexcept { return parameter_at(1); }

    type infer_invoke_result_type(type_infer_context &context) override;

  private:
    binary_op_t binary_op_;
};

/** @brief Binary expression */
class binary : public object_t<binary_node> {
  public:
    using object_t::object_t;
    /** @brief Construct a binary expression
     * @param[in] binary_op The opcode of the binary
     */
    NNCASE_API binary(binary_op_t binary_op);
};
} // namespace nncase::ir::math

View File

@ -1,38 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
namespace nncase::ir::math {
/** @brief Clamp operator node: limits input into the range [min, max] */
class NNCASE_API clamp_node : public op_node {
    DEFINE_OBJECT_KIND(op_node, op_math_clamp)
  public:
    clamp_node();

    /** @brief Get the input of the clamp expression (parameter 0) */
    const connector_info &input() const noexcept { return parameter_at(0); }
    /** @brief Get the min of the clamp expression (parameter 1) */
    const connector_info &min() const noexcept { return parameter_at(1); }
    /** @brief Get the max of the clamp expression (parameter 2) */
    const connector_info &max() const noexcept { return parameter_at(2); }

    type infer_invoke_result_type(type_infer_context &context) override;
};

/** @brief Clamp expression */
using clamp = object_t<clamp_node>;
} // namespace nncase::ir::math

View File

@ -1,94 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../call.h"
#include "../functional.h"
namespace nncase::ir::F {
/** @brief Create a unary operator call
 * @param[in] unary_op The opcode of the unary
 * @param[in] input The input expression */
NNCASE_API call unary(unary_op_t unary_op, fexpr input);
/** @brief Create a binary operator call
 * @param[in] binary_op The opcode of the binary
 * @param[in] lhs The left-hand operand
 * @param[in] rhs The right-hand operand */
NNCASE_API call binary(binary_op_t binary_op, fexpr lhs, fexpr rhs);
/** @brief Create a clamp call limiting input into [min, max] */
NNCASE_API call clamp(fexpr input, fexpr min, fexpr max);

// Named unary helpers: F::abs(x) == F::unary(unary_abs, x), etc.
#define DEFINE_UNARY_FUNC(name, unary_op)                                      \
    inline call name(fexpr input) { return F::unary(unary_op, input); }
DEFINE_UNARY_FUNC(abs, unary_abs)
DEFINE_UNARY_FUNC(ceil, unary_ceil)
DEFINE_UNARY_FUNC(cos, unary_cos)
DEFINE_UNARY_FUNC(exp, unary_exp)
DEFINE_UNARY_FUNC(floor, unary_floor)
DEFINE_UNARY_FUNC(log, unary_log)
DEFINE_UNARY_FUNC(neg, unary_neg)
DEFINE_UNARY_FUNC(round, unary_round)
DEFINE_UNARY_FUNC(rsqrt, unary_rsqrt)
DEFINE_UNARY_FUNC(sin, unary_sin)
DEFINE_UNARY_FUNC(sqrt, unary_sqrt)
DEFINE_UNARY_FUNC(square, unary_square)
DEFINE_UNARY_FUNC(tanh, unary_tanh)
DEFINE_UNARY_FUNC(bitwise_not, unary_bitwise_not)
DEFINE_UNARY_FUNC(logical_not, unary_logical_not)
#undef DEFINE_UNARY_FUNC

// Named binary helpers: F::add(x, y) == F::binary(binary_add, x, y), etc.
// Both operands are fexpr so scalar literals implicitly convert to constants
// on either side — previously rhs was a plain expr, which made the named
// helpers reject scalar rhs (e.g. F::add(x, 1.f)) while the operator
// overloads below accepted it.
#define DEFINE_BINARY_FUNC(name, binary_op)                                    \
    inline call name(fexpr lhs, fexpr rhs) {                                   \
        return F::binary(binary_op, lhs, rhs);                                 \
    }
DEFINE_BINARY_FUNC(add, binary_add)
DEFINE_BINARY_FUNC(sub, binary_sub)
DEFINE_BINARY_FUNC(mul, binary_mul)
DEFINE_BINARY_FUNC(div, binary_div)
DEFINE_BINARY_FUNC(mod, binary_mod)
DEFINE_BINARY_FUNC(min, binary_min)
DEFINE_BINARY_FUNC(max, binary_max)
DEFINE_BINARY_FUNC(pow, binary_pow)
DEFINE_BINARY_FUNC(bitwise_and, binary_bitwise_and)
DEFINE_BINARY_FUNC(bitwise_or, binary_bitwise_or)
DEFINE_BINARY_FUNC(bitwise_xor, binary_bitwise_xor)
DEFINE_BINARY_FUNC(logical_and, binary_logical_and)
DEFINE_BINARY_FUNC(logical_or, binary_logical_or)
DEFINE_BINARY_FUNC(logical_xor, binary_logical_xor)
#undef DEFINE_BINARY_FUNC
} // namespace nncase::ir::F

namespace nncase::ir {
// Operator overloads forwarding to the functional helpers above; the fexpr
// overloads allow one operand to be a scalar literal.
inline call operator-(expr input) { return F::neg(input); }
inline call operator~(expr input) { return F::bitwise_not(input); }
inline call operator!(expr input) { return F::logical_not(input); }

#define DEFINE_BINARY_OPERATOR(op, impl)                                       \
    inline call operator op(expr lhs, expr rhs) { return F::impl(lhs, rhs); }  \
    inline call operator op(expr lhs, F::fexpr rhs) {                          \
        return F::impl(lhs, rhs);                                              \
    }                                                                          \
    inline call operator op(F::fexpr lhs, expr rhs) {                          \
        return F::impl(lhs, rhs);                                              \
    }
DEFINE_BINARY_OPERATOR(+, add)
DEFINE_BINARY_OPERATOR(-, sub)
DEFINE_BINARY_OPERATOR(*, mul)
DEFINE_BINARY_OPERATOR(/, div)
DEFINE_BINARY_OPERATOR(%, mod)
DEFINE_BINARY_OPERATOR(&, bitwise_and)
DEFINE_BINARY_OPERATOR(|, bitwise_or)
DEFINE_BINARY_OPERATOR(^, bitwise_xor)
DEFINE_BINARY_OPERATOR(&&, logical_and)
DEFINE_BINARY_OPERATOR(||, logical_or)
#undef DEFINE_BINARY_OPERATOR
} // namespace nncase::ir

View File

@ -1,4 +0,0 @@
// Math dialect opcode table (X-macro, expanded by math/opcode.h).
// Columns: dialect, identifier, display name, numeric id.
DEFINE_OPCODE(math, unary, Unary, 0x1001)
DEFINE_OPCODE(math, binary, Binary, 0x1002)
DEFINE_OPCODE(math, compare, Compare, 0x1003)
DEFINE_OPCODE(math, clamp, Clamp, 0x1004)

View File

@ -1,25 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../../object.h"
namespace nncase::ir::math {
// X-macro expansion of opcode.def: declares a constexpr object_kind constant
// op_math_<id> for every math-dialect operator.
#define DEFINE_OPCODE(dialect, id, name, value)                                \
    NNCASE_INLINE_VAR constexpr object_kind op_##dialect##_##id{value, #name};
#include "opcode.def"
#undef DEFINE_OPCODE
} // namespace nncase::ir::math

View File

@ -1,52 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../call.h"
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
namespace nncase::ir::math {
/** @brief Unary operator node (element-wise, single operand) */
class NNCASE_API unary_node : public op_node {
    DEFINE_OBJECT_KIND(op_node, op_math_unary)
  public:
    unary_node(unary_op_t unary_op);

    /** @brief Get the unary opcode of the unary expression */
    unary_op_t unary_op() const noexcept { return unary_op_; }
    /** @brief Set the unary opcode of the unary expression */
    void unary_op(unary_op_t value) noexcept { unary_op_ = value; }
    /** @brief Get the input of the unary expression (parameter 0) */
    const connector_info &input() const noexcept { return parameter_at(0); }

    type infer_invoke_result_type(type_infer_context &context) override;

  private:
    unary_op_t unary_op_;
};

/** @brief Unary expression */
class unary : public object_t<unary_node> {
  public:
    using object_t::object_t;
    /** @brief Construct a unary expression
     * @param[in] unary_op The opcode of the unary
     */
    NNCASE_API unary(unary_op_t unary_op);
};
} // namespace nncase::ir::math

View File

@ -1,51 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../object.h"
#include "function.h"
namespace nncase::ir {
/** @brief Module node
 *
 * A module is a collection of functions with a designated entry function. */
class NNCASE_API module_node : public object_node {
    DEFINE_OBJECT_KIND(object_node, object_module)
  public:
    module_node();

    /** @brief Get the functions of the module */
    const std::vector<function> &functions() const noexcept {
        return functions_;
    }
    /** @brief Add a new function to the module */
    const function &add_function(function func);

    /** @brief Get the entry of the module */
    const function &entry() const noexcept { return entry_; }
    /** @brief Get the mutable entry of the module */
    function &entry() noexcept { return entry_; }
    /** @brief Set the entry of the module */
    void entry(function value) noexcept { entry_ = std::move(value); }

  private:
    std::vector<function> functions_;
    function entry_;
};

/** @brief Module */
class NNCASE_API module_t : public object_t<module_node> {
  public:
    using object_t::object_t;
};
} // namespace nncase::ir

View File

@ -1,31 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../call.h"
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
namespace nncase::ir::nn {
/** @brief Conv1D operator node
 *
 * NOTE(review): placeholder — the constructor is still commented out, so this
 * operator declares no parameters yet. */
class NNCASE_API conv1d_node : public op_node {
    DEFINE_OBJECT_KIND(op_node, op_nn_conv1d)
  public:
    // conv1d_node();
};

/** @brief Conv1D expression */
using conv1d = object_t<conv1d_node>;
} // namespace nncase::ir::nn

View File

@ -1,63 +0,0 @@
// NN dialect opcode table (X-macro, expanded by nn/opcode.h).
// Columns: dialect, identifier, display name, numeric id (0x3xxx range).
DEFINE_OPCODE(nn, conv1d, Conv1D, 0x3001)
DEFINE_OPCODE(nn, conv2d, Conv2D, 0x3002)
DEFINE_OPCODE(nn, conv3d, Conv3D, 0x3003)
DEFINE_OPCODE(nn, conv_transpose1d, ConvTranspose1D, 0x3004)
DEFINE_OPCODE(nn, conv_transpose2d, ConvTranspose2D, 0x3005)
DEFINE_OPCODE(nn, conv_transpose3d, ConvTranspose3D, 0x3006)
DEFINE_OPCODE(nn, avg_pool1d, AvgPool1D, 0x3007)
DEFINE_OPCODE(nn, avg_pool2d, AvgPool2D, 0x3008)
DEFINE_OPCODE(nn, avg_pool3d, AvgPool3D, 0x3009)
DEFINE_OPCODE(nn, max_pool1d, MaxPool1D, 0x300A)
DEFINE_OPCODE(nn, max_pool2d, MaxPool2D, 0x300B)
DEFINE_OPCODE(nn, max_pool3d, MaxPool3D, 0x300C)
DEFINE_OPCODE(nn, l2_pool1d, L2Pool1D, 0x300D)
DEFINE_OPCODE(nn, l2_pool2d, L2Pool2D, 0x300E)
DEFINE_OPCODE(nn, l2_pool3d, L2Pool3D, 0x300F)
DEFINE_OPCODE(nn, adaptive_avg_pool1d, AdaptiveAvgPool1D, 0x3010)
DEFINE_OPCODE(nn, adaptive_avg_pool2d, AdaptiveAvgPool2D, 0x3011)
DEFINE_OPCODE(nn, adaptive_avg_pool3d, AdaptiveAvgPool3D, 0x3012)
DEFINE_OPCODE(nn, adaptive_max_pool1d, AdaptiveMaxPool1D, 0x3013)
DEFINE_OPCODE(nn, adaptive_max_pool2d, AdaptiveMaxPool2D, 0x3014)
DEFINE_OPCODE(nn, adaptive_max_pool3d, AdaptiveMaxPool3D, 0x3015)
DEFINE_OPCODE(nn, relu, ReLU, 0x3016)
DEFINE_OPCODE(nn, hardtanh, Hardtanh, 0x3017)
DEFINE_OPCODE(nn, hardswish, HardSwish, 0x3018)
DEFINE_OPCODE(nn, relu6, ReLU6, 0x3019)
DEFINE_OPCODE(nn, elu, ELU, 0x301A)
DEFINE_OPCODE(nn, selu, SELU, 0x301B)
DEFINE_OPCODE(nn, celu, CELU, 0x301C)
DEFINE_OPCODE(nn, leaky_relu, LeakyReLU, 0x301D)
DEFINE_OPCODE(nn, prelu, PReLU, 0x301E)
DEFINE_OPCODE(nn, rrelu, RReLU, 0x301F)
DEFINE_OPCODE(nn, glu, GLU, 0x3020)
DEFINE_OPCODE(nn, gelu, GELU, 0x3021)
DEFINE_OPCODE(nn, logsigmoid, LogSigmoid, 0x3022)
DEFINE_OPCODE(nn, hardshrink, Hardshrink, 0x3023)
DEFINE_OPCODE(nn, tanhshrink, Tanhshrink, 0x3024)
DEFINE_OPCODE(nn, softsign, SoftSign, 0x3025)
DEFINE_OPCODE(nn, softplus, Softplus, 0x3026)
DEFINE_OPCODE(nn, softmin, Softmin, 0x3027)
DEFINE_OPCODE(nn, softmax, Softmax, 0x3028)
DEFINE_OPCODE(nn, softshrink, Softshrink, 0x3029)
DEFINE_OPCODE(nn, gumbel_softmax, GumbelSoftmax, 0x302A)
DEFINE_OPCODE(nn, log_softmax, LogSoftmax, 0x302B)
DEFINE_OPCODE(nn, sigmoid, Sigmoid, 0x302C)
DEFINE_OPCODE(nn, hardsigmoid, Hardsigmoid, 0x302D)
DEFINE_OPCODE(nn, silu, SiLU, 0x302E)
DEFINE_OPCODE(nn, mish, Mish, 0x302F)
DEFINE_OPCODE(nn, batch_norm, BatchNorm, 0x3030)
DEFINE_OPCODE(nn, group_norm, GroupNorm, 0x3031)
DEFINE_OPCODE(nn, instance_norm, InstanceNorm, 0x3032)
DEFINE_OPCODE(nn, layer_norm, LayerNorm, 0x3033)
DEFINE_OPCODE(nn, local_response_norm, LocalResponseNorm, 0x3034)
DEFINE_OPCODE(nn, normalize, Normalize, 0x3035)
DEFINE_OPCODE(nn, linear, Linear, 0x3036)
DEFINE_OPCODE(nn, bilinear, Bilinear, 0x3037)
DEFINE_OPCODE(nn, embedding, Embedding, 0x3038)
DEFINE_OPCODE(nn, embedding_bag, EmbeddingBag, 0x3039)
DEFINE_OPCODE(nn, one_hot, OneHot, 0x303A)
// NOTE(review): ids jump from 0x303A to 0x3040 — presumably a reserved gap.
DEFINE_OPCODE(nn, pixel_shuffle, PixelShuffle, 0x3040)
DEFINE_OPCODE(nn, pad, Pad, 0x3041)
DEFINE_OPCODE(nn, interpolate, Interpolate, 0x3042)
DEFINE_OPCODE(nn, grid_sample, GridSample, 0x3043)
DEFINE_OPCODE(nn, affine_grid, AffineGrid, 0x3044)

View File

@ -1,25 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../../object.h"
namespace nncase::ir::nn {
// X-macro expansion of opcode.def: declares a constexpr object_kind constant
// op_nn_<id> for every nn-dialect operator.
#define DEFINE_OPCODE(dialect, id, name, value)                                \
    NNCASE_INLINE_VAR constexpr object_kind op_##dialect##_##id{value, #name};
#include "opcode.def"
#undef DEFINE_OPCODE
} // namespace nncase::ir::nn

92
include/nncase/ir/node.h Normal file
View File

@ -0,0 +1,92 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "connectors.h"
#include "opcode.h"
#include <list>
#include <span>
#include <unordered_map>
namespace nncase::ir
{
// Implements the runtime_opcode() override plus a static opcode() accessor
// for a concrete node class.
#define DEFINE_NODE_OPCODE(value) \
    static constexpr node_opcode opcode() noexcept { return value; } \
    const node_opcode &runtime_opcode() const noexcept override { return value; }

/** @brief Base class of all graph nodes; owns its input/output connectors */
class NNCASE_API node
{
public:
    node(std::string name = "");
    node(node &) = delete;
    node &operator=(node &) = delete;
    virtual ~node();

    /** @brief Node name */
    const std::string &name() const noexcept { return name_; }
    /** @brief Assign the name; arguments forward to std::string::assign */
    template <class TArg, class... TArgs>
    void name(TArg arg, TArgs... args) { name_.assign(std::forward<TArg>(arg), std::forward<TArgs>(args)...); }
    // presumably name() with unsafe characters escaped — confirm in node.cpp
    std::string escaped_name() const noexcept;

    const module_type_t &module_type() const noexcept { return module_type_; }
    void module_type(const module_type_t &type) noexcept { module_type_ = type; }

    std::span<input_connector *const> inputs() const noexcept { return input_connectors_; }
    std::span<output_connector *const> outputs() const noexcept { return output_connectors_; }
    // NOTE(review): at() throws std::out_of_range on a bad index.
    input_connector &input_at(size_t index) const { return *input_connectors_.at(index); }
    output_connector &output_at(size_t index) const { return *output_connectors_.at(index); }

    /** @brief Opcode identifying the concrete node type */
    virtual const node_opcode &runtime_opcode() const noexcept = 0;
    node_attributes attributes() const noexcept { return attributes_; }
    void attributes(node_attributes value) noexcept { attributes_ = value; }
    /** @brief Node equality; delegates type-specific parts to properties_equal */
    bool equals(node &other) const;

    /** @brief Remember which pre-quantization connector a quantized connector
     * replaced */
    void record_output_connectors_quant_map(output_connector &oc_after_quant, output_connector &oc_before_quant) noexcept { output_connectors_quant_map_.emplace(&oc_after_quant, &oc_before_quant); }
    // NOTE(review): returns the map by value — every call copies it.
    std::unordered_map<output_connector *, output_connector *> get_output_connectors_quant_map() const noexcept { return output_connectors_quant_map_; }
    void record_node_name_before_quant(std::string name) noexcept { node_name_before_quant_.assign(name); }
    std::string get_node_name_before_quant() const noexcept { return node_name_before_quant_; }

protected:
    /** @brief Create, own and register a new input connector */
    template <class TName, class TShape>
    input_connector &add_input(TName &&name, datatype_t type, TShape &&shape)
    {
        auto ptr = input_connectors_storage_.emplace_back(std::make_unique<input_connector>(*this, std::forward<TName>(name), type, std::forward<TShape>(shape))).get();
        input_connectors_.emplace_back(ptr);
        return *ptr;
    }

    /** @brief Create, own and register a new output connector */
    template <class TName, class TShape>
    output_connector &add_output(TName &&name, datatype_t type, TShape &&shape, memory_location_t memory_location = mem_data)
    {
        auto ptr = output_connectors_storage_.emplace_back(std::make_unique<output_connector>(*this, std::forward<TName>(name), type, std::forward<TShape>(shape), memory_location)).get();
        output_connectors_.emplace_back(ptr);
        return *ptr;
    }

    /** @brief Compare type-specific properties; used by equals() */
    virtual bool properties_equal(node &other) const = 0;

private:
    std::string name_;
    module_type_t module_type_;
    node_attributes attributes_ = node_attributes::node_attr_action;
    std::vector<input_connector *> input_connectors_;
    std::vector<output_connector *> output_connectors_;
    std::vector<std::unique_ptr<input_connector>> input_connectors_storage_;
    std::vector<std::unique_ptr<output_connector>> output_connectors_storage_;
    std::unordered_map<output_connector *, output_connector *> output_connectors_quant_map_;
    std::string node_name_before_quant_;
};
}

View File

@ -1,47 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "expr.h"
#include "ir_types.h"
#include "type_infer.h"
#include <utility>
namespace nncase::ir {
/** @brief Operator node
 *
 * Base class of all dialect operators; an operator is invoked through a call
 * expression whose arguments bind to the operator's parameters. */
class NNCASE_API op_node : public expr_node {
    DEFINE_OBJECT_KIND(expr_node, object_op)
  public:
    /** @brief Get the parameters of the operator */
    std::span<const connector_info> parameters() const noexcept {
        return parameters_;
    }
    /** @brief Get the parameter at the index
     *
     * NOTE(review): at() inside a noexcept member — an out-of-range index
     * terminates the process instead of throwing. */
    const connector_info &parameter_at(size_t index) const noexcept {
        return parameters_.at(index);
    }

    /** @brief Infer the invoke result type */
    virtual type infer_invoke_result_type(type_infer_context &context) = 0;

  protected:
    /** @brief Append a named parameter and return its connector info */
    connector_info &add_parameter(std::string name);

  private:
    std::vector<connector_info> parameters_;
};

/** @brief Operator expression */
using op = object_t<op_node>;
} // namespace nncase::ir

View File

@ -27,13 +27,20 @@ inline shape_t get_transposed_shape(const shape_t &input_shape, const axis_t &pe
return new_shape;
}
inline size_t get_windowed_output_size(int32_t size, int32_t filter, int32_t stride, int32_t dilation, bool same)
inline size_t get_windowed_output_size(int32_t size, int32_t filter, int32_t stride, int32_t dilation, bool same, bool ceil_mode = false)
{
auto effective_filter_size = (filter - 1) * dilation + 1;
if (same)
return (size_t(size) + stride - 1) / stride;
else
return (size_t(size) - effective_filter_size + stride) / stride;
{
if (!ceil_mode)
return (size_t(size) - effective_filter_size + stride) / stride;
else
{
return static_cast<int>(ceil(static_cast<float>(size_t(size) - effective_filter_size + stride) / stride));
}
}
}
inline padding get_windowed_padding(int32_t input_size, int32_t output_size, int32_t filter, int32_t stride, int32_t dilation)
@ -131,7 +138,7 @@ inline shape_t normalize_reshape(const shape_t &in_shape, const axis_t &new_shap
if (v == -1)
{
if (non_det_id)
throw std::runtime_error("Reshap can only have 1 non-determined dimension at most");
throw std::runtime_error("Reshape can only have 1 non-determined dimension at most");
non_det_id = i;
}
else
@ -230,7 +237,7 @@ inline shape_t get_padded_shape(const shape_t &in_shape, const xt::svector<paddi
{
auto new_shape = in_shape;
for (size_t i = 0; i < in_shape.size(); i++)
new_shape[i] = size_t(int32_t(new_shape[i]) + paddings[i].sum());
new_shape[i] = size_t(int32_t(new_shape[i]) + paddings[i].sum() + (new_shape[i] - 1) * paddings[i].interior);
return new_shape;
}
@ -293,6 +300,50 @@ inline shape_t get_strided_slice_output_shape(const axis_t &begin, const axis_t
return new_shape.size() ? new_shape : shape_t { 1 };
}
inline bool is_copy_slice(const axis_t &strides)
{
return std::all_of(strides.begin(), strides.end(), [](int32_t stride) { return stride == 1; });
}
/** @brief Check whether a strided slice is "simple": all strides are 1 and at
 * most one axis is actually sliced — and that axis must precede every
 * full-range axis whose extent is greater than 1. */
inline bool is_simple_slice(const axis_t &begin, const axis_t &end, const axis_t &strides, const shape_t &input_shape)
{
    if (!is_copy_slice(strides))
        return false;

    bool is_simple_slice = true;
    // Whether a partially-sliced axis may still be accepted.
    bool allow_not_equal = true;
    for (size_t i = 0; i < begin.size(); i++)
    {
        if (begin[i] != 0
            || end[i] != input_shape[i])
        {
            // This axis is partially sliced: only one such axis is allowed.
            if (allow_not_equal)
            {
                allow_not_equal = false;
            }
            else
            {
                is_simple_slice = false;
                break;
            }
        }
        else if (input_shape[i] != 1)
        {
            // Full-range axis with extent > 1: later axes must be full-range.
            allow_not_equal = false;
        }
    }

    return is_simple_slice;
}
inline bool is_axis0_squeeze_or_expand_dim_bitcast(const shape_t &in_shape, const shape_t &out_shape)
{
auto in_begin = std::find_if_not(in_shape.begin(), in_shape.end(), [](size_t dim) { return dim == 1; });
auto out_begin = std::find_if_not(out_shape.begin(), out_shape.end(), [](size_t dim) { return dim == 1; });
return std::distance(in_begin, in_shape.end()) == std::distance(out_begin, out_shape.end())
&& std::equal(in_begin, in_shape.end(), out_begin);
}
template <class U, class T>
std::span<U> as_span(const std::span<T> &src) noexcept
{

View File

@ -0,0 +1,48 @@
DEFINE_NEUTRAL_OPCODE(input_node, Input, 0x01)
DEFINE_NEUTRAL_OPCODE(output_node, Output, 0x02)
DEFINE_NEUTRAL_OPCODE(ignore_node, Ignore, 0x03)
DEFINE_NEUTRAL_OPCODE(constant, Constant, 0x04)
DEFINE_NEUTRAL_OPCODE(uninitialized, Uninitialized, 0x05)
DEFINE_NEUTRAL_OPCODE(call, Call, 0x06)
DEFINE_NEUTRAL_OPCODE(copy, Copy, 0x07)
DEFINE_NEUTRAL_OPCODE(conv2d, Conv2D, 0x100)
DEFINE_NEUTRAL_OPCODE(matmul, MatMul, 0x101)
DEFINE_NEUTRAL_OPCODE(transpose, Transpose, 0x102)
DEFINE_NEUTRAL_OPCODE(reduce, Reduce, 0x103)
DEFINE_NEUTRAL_OPCODE(reduce_window2d, ReduceWindow2D, 0x104)
DEFINE_NEUTRAL_OPCODE(binary, Binary, 0x105)
DEFINE_NEUTRAL_OPCODE(concat, Concat, 0x106)
DEFINE_NEUTRAL_OPCODE(unary, Unary, 0x107)
DEFINE_NEUTRAL_OPCODE(fused_unary, FusedUnary, 0x108)
DEFINE_NEUTRAL_OPCODE(quantize, Quantize, 0x109)
DEFINE_NEUTRAL_OPCODE(dequantize, Dequantize, 0x10A)
DEFINE_NEUTRAL_OPCODE(pad, Pad, 0x10B)
DEFINE_NEUTRAL_OPCODE(bitcast, Bitcast, 0x10C)
DEFINE_NEUTRAL_OPCODE(resize_image, ResizeImage, 0x10D)
DEFINE_NEUTRAL_OPCODE(slice, Slice, 0x10E)
DEFINE_NEUTRAL_OPCODE(table_lookup1d, TableLookup1D, 0x10F)
DEFINE_NEUTRAL_OPCODE(conv2d_transpose, Conv2DTranspose, 0x110)
DEFINE_NEUTRAL_OPCODE(clamp, Clamp, 0x111)
DEFINE_NEUTRAL_OPCODE(convert, Convert, 0x112)
DEFINE_NEUTRAL_OPCODE(broadcast, Broadcast, 0x113)
DEFINE_NEUTRAL_OPCODE(take, Take, 0x114)
DEFINE_NEUTRAL_OPCODE(space_to_batch, SpaceToBatch, 0x115)
DEFINE_NEUTRAL_OPCODE(batch_to_space, BatchToSpace, 0x116)
DEFINE_NEUTRAL_OPCODE(split, Split, 0x117)
DEFINE_NEUTRAL_OPCODE(gather, Gather, 0x118)
DEFINE_NEUTRAL_OPCODE(gather_nd, GatherND, 0x119)
DEFINE_NEUTRAL_OPCODE(onehot, OneHot, 0x11A)
DEFINE_NEUTRAL_OPCODE(lstm, LSTM, 0x11B)
DEFINE_NEUTRAL_OPCODE(reduce_arg, ReduceArg, 0x11C)
DEFINE_NEUTRAL_OPCODE(cumsum, CumSum, 0x11D)
DEFINE_NEUTRAL_OPCODE(hardmax, HardMax, 0x11E)
DEFINE_NEUTRAL_OPCODE(random_normal, RandomNormal, 0x11F)
DEFINE_NEUTRAL_OPCODE(random_uniform, RandomUniform, 0x120)
DEFINE_NEUTRAL_OPCODE(reduce_prod, ReduceProd, 0x121)
DEFINE_NEUTRAL_OPCODE(ternary, Ternary, 0x122)
DEFINE_NEUTRAL_OPCODE(topk, TopK, 0x123)
DEFINE_NEUTRAL_OPCODE(trilu, Trilu, 0x124)
DEFINE_NEUTRAL_OPCODE(sigmoid, Sigmoid, 0x125)
DEFINE_NEUTRAL_OPCODE(roi_align, RoiAlign, 0x126)
DEFINE_NEUTRAL_OPCODE(equal, Equal, 0x127)

View File

@ -0,0 +1,51 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <compare>
#include <nncase/runtime/datatypes.h>
#include <stdexcept>
#include <string_view>
#include <type_traits>

namespace nncase::ir
{
/** @brief Identifies a node type: a numeric id plus a human-readable name. */
struct node_opcode
{
    uint32_t id;
    std::string_view name;
};

// Two opcodes are equal iff their numeric ids match; the name is informational only.
constexpr inline bool operator==(const node_opcode &lhs, const node_opcode &rhs) noexcept { return lhs.id == rhs.id; }

// Expand opcode.def into one `op_<id>` (or `op_<target>_<id>`) constant per opcode.
#define DEFINE_NEUTRAL_OPCODE(id, name, value) NNCASE_INLINE_VAR constexpr node_opcode op_##id { value, #name };
#define DEFINE_OPCODE(target, id, name, value) NNCASE_INLINE_VAR constexpr node_opcode op_##target##_##id { value, #name };
#include "opcode.def"
#undef DEFINE_NEUTRAL_OPCODE
#undef DEFINE_OPCODE
}

namespace std
{
// Hash support so node_opcode can key unordered containers.
// Hashes the id only, consistent with operator== above.
template <>
struct hash<nncase::ir::node_opcode>
{
    [[nodiscard]] size_t operator()(const nncase::ir::node_opcode &opcode) const noexcept
    {
        return std::hash<uint32_t>()(opcode.id);
    }
};
}

View File

@ -0,0 +1,62 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xarray.hpp>

namespace nncase::ir
{
/** @brief BatchToSpace node: moves block_size_h x block_size_w blocks from the batch
 *  dimension back into the spatial dimensions, then applies the stored crops. */
class NNCASE_API batch_to_space : public node
{
public:
    DEFINE_NODE_OPCODE(op_batch_to_space);

    input_connector &input() { return input_at(0); }
    output_connector &output() { return output_at(0); }

    /** @brief Spatial block sizes taken out of the batch dimension. */
    int32_t block_size_h() const noexcept { return block_size_h_; }
    int32_t block_size_w() const noexcept { return block_size_w_; }

    // Strided-slice style parameters carried alongside the transform.
    // NOTE(review): the masks below have no setters here; presumably derived from
    // the constructor's begin/end/stride arguments — confirm in the .cpp.
    const axis_t &begin() const noexcept { return begin_; }
    const axis_t &end() const noexcept { return end_; }
    const axis_t &strides() const noexcept { return strides_; }
    int32_t begin_mask() const noexcept { return begin_mask_; }
    int32_t end_mask() const noexcept { return end_mask_; }
    int32_t ellipsis_mask() const noexcept { return ellipsis_mask_; }
    int32_t new_axis_mask() const noexcept { return new_axis_mask_; }
    int32_t shrink_axis_mask() const noexcept { return shrink_axis_mask_; }

    /** @brief {before, after} crop amounts for the H / W axes. */
    std::array<int32_t, 2> crop_h() const noexcept { return crop_h_; }
    std::array<int32_t, 2> crop_w() const noexcept { return crop_w_; }

    // NOTE(review): the last two parameters are named like members (trailing '_');
    // consider renaming to crop_h / crop_w in a follow-up.
    batch_to_space(datatype_t input_type, shape_t input_shape, int32_t block_shape_h, int32_t block_shape_w, axis_t stride, axis_t begin, axis_t end, std::array<int32_t, 2> crop_h_, std::array<int32_t, 2> crop_w_);

protected:
    bool properties_equal(node &other) const override;

private:
    int32_t block_size_h_;
    int32_t block_size_w_;
    axis_t begin_;
    axis_t end_;
    axis_t strides_;
    int32_t begin_mask_;
    int32_t end_mask_;
    int32_t ellipsis_mask_;
    int32_t new_axis_mask_;
    int32_t shrink_axis_mask_;
    std::array<int32_t, 2> crop_h_;
    std::array<int32_t, 2> crop_w_;
};
}

View File

@ -0,0 +1,42 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Element-wise binary operation node: two inputs, one output. */
class NNCASE_API binary : public node
{
public:
    DEFINE_NODE_OPCODE(op_binary);

    input_connector &input_a() { return input_at(0); }
    input_connector &input_b() { return input_at(1); }
    output_connector &output() { return output_at(0); }

    /** @brief The binary operation kind this node performs. */
    binary_op_t binary_op() const noexcept { return binary_op_; }
    /** @brief Output value range fused into this node (activation clamp). */
    value_range<float> fused_activation() const noexcept { return fused_activation_; }

    binary(binary_op_t binary_op, shape_t input_a_shape, shape_t input_b_shape, value_range<float> input_fused_activation);

protected:
    bool properties_equal(node &other) const override;

private:
    binary_op_t binary_op_;
    value_range<float> fused_activation_;
};
}

View File

@ -0,0 +1,46 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Bitcast node: reinterprets the input with a new element type and/or shape. */
class NNCASE_API bitcast : public node
{
public:
    DEFINE_NODE_OPCODE(op_bitcast);

    const input_connector &input() const { return input_at(0); }
    input_connector &input() { return input_at(0); }
    output_connector &output() { return output_at(0); }

    /** @brief True when the element type is unchanged, i.e. this is a pure reshape. */
    bool is_reshape() const noexcept { return new_type_ == input().type(); }

    datatype_t new_type() const noexcept { return new_type_; }
    const shape_t &new_shape() const noexcept { return new_shape_; }

    // Convenience overloads: omit new_type to keep the input's type; the shape
    // may be supplied as either axis_t or shape_t.
    bitcast(datatype_t input_type, shape_t input_shape, axis_t new_shape);
    bitcast(datatype_t input_type, shape_t input_shape, shape_t new_shape);
    bitcast(datatype_t input_type, shape_t input_shape, datatype_t new_type, axis_t new_shape);
    bitcast(datatype_t input_type, shape_t input_shape, datatype_t new_type, shape_t new_shape);

protected:
    bool properties_equal(node &other) const override;

private:
    datatype_t new_type_;
    shape_t new_shape_;
};
}

View File

@ -13,17 +13,28 @@
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Broadcast node: expands the input to new_shape. */
// Fix: the dump interleaved lines from the removed `nncase::ir::tensors::broadcast_node`
// version (op.h / DEFINE_OBJECT_KIND / object_t alias) with this header, leaving it
// invalid C++. Only the current-style definition is kept.
class NNCASE_API broadcast : public node
{
public:
    DEFINE_NODE_OPCODE(op_broadcast);

    const input_connector &input() const { return input_at(0); }
    input_connector &input() { return input_at(0); }
    output_connector &output() { return output_at(0); }

    /** @brief Target shape the input is broadcast to. */
    const shape_t &new_shape() const noexcept { return new_shape_; }

    broadcast(datatype_t input_type, shape_t input_shape, shape_t new_shape);

protected:
    bool properties_equal(node &other) const override;

private:
    shape_t new_shape_;
};
}

View File

@ -0,0 +1,41 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../graph.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Call node: invokes another graph from this one. */
class NNCASE_API call : public node
{
public:
    DEFINE_NODE_OPCODE(op_call);

    /** @brief The graph this node invokes (held by reference; must outlive this node). */
    graph &target() const noexcept { return target_; }

    call(graph &target);

    // Map an input/output of the target graph to the corresponding outer
    // connector on this call node.
    input_connector &outer_connector(input_node &target_input);
    input_connector &outer_connector(input_connector &target_input);
    output_connector &outer_connector(output_node &target_output);
    output_connector &outer_connector(output_connector &target_output);

protected:
    bool properties_equal(node &other) const override;

private:
    graph &target_;
};
}

View File

@ -0,0 +1,39 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Clamp node: bounds input(0) between the low(1) and high(2) inputs. */
class NNCASE_API clamp : public node
{
public:
    DEFINE_NODE_OPCODE(op_clamp);

    input_connector &input() { return input_at(0); }
    input_connector &input_low() { return input_at(1); }
    const input_connector &input_low() const { return input_at(1); }
    input_connector &input_high() { return input_at(2); }
    const input_connector &input_high() const { return input_at(2); }
    output_connector &output() { return output_at(0); }
    const output_connector &output() const { return output_at(0); }

    clamp(shape_t input_shape, shape_t input_low_shape, shape_t input_high_shape);

protected:
    // clamp carries no attributes beyond its connectors, so any two clamp
    // nodes compare property-equal.
    bool properties_equal([[maybe_unused]] node &other) const override { return true; }
};
}

View File

@ -13,25 +13,28 @@
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Concat node: joins all of its inputs along one axis. */
// Fix: the dump interleaved lines from the removed `nncase::ir::tensors::split_node`
// version (op.h includes, DEFINE_OBJECT_KIND, axis setter, stray `private:`) with this
// header, leaving it invalid C++. Only the current-style definition is kept.
class NNCASE_API concat : public node
{
public:
    DEFINE_NODE_OPCODE(op_concat);

    output_connector &output() { return output_at(0); }

    /** @brief Axis along which the inputs are concatenated. */
    int32_t axis() const noexcept { return axis_; }
    /** @brief Per-input extents along the concat axis. */
    std::span<const size_t> concat_dims() const noexcept { return concat_dims_; }

    concat(datatype_t type, std::span<shape_t> input_shapes, int32_t axis);

protected:
    bool properties_equal(node &other) const override;

private:
    int32_t axis_;
    std::vector<size_t> concat_dims_;
};
}

View File

@ -0,0 +1,141 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../debug.h"
#include "../node.h"
#include "../op_utils.h"
#include <nncase/runtime/debug.h>
#include <stdexcept>
#include <vector>

namespace nncase::ir
{
/** @brief Constant node: owns an immutable byte buffer emitted into the rdata section. */
class NNCASE_API constant : public node
{
public:
    DEFINE_NODE_OPCODE(op_constant);

    output_connector &output() { return output_at(0); }
    const output_connector &output() const { return output_at(0); }

    /** @brief Required buffer alignment in bytes (defaults to 8). */
    size_t alignment() const noexcept { return alignment_; }
    void alignment(size_t value) { alignment_ = value; }

    /** @brief Raw constant bytes. */
    std::span<const std::byte> data() const noexcept { return data_; }
    datatype_t data_type() { return datatype_; }

    // --- constructors -------------------------------------------------------
    // All overloads reduce the data argument to bytes and funnel into the
    // variadic constructor below, which validates size and adds the output.
    template <class TShape>
    constant(datatype_t type, TShape &&shape, std::span<const std::byte> data)
        : constant(type, std::forward<TShape>(shape), data.begin(), data.end())
    {
    }
    template <class TShape>
    constant(datatype_t type, TShape &&shape, gsl::span<const gsl::byte> data)
        : constant(type, std::forward<TShape>(shape), reinterpret_cast<const std::byte *>(data.begin()), reinterpret_cast<const std::byte *>(data.end()))
    {
    }
    template <class TShape, class T>
    constant(datatype_t type, TShape &&shape, std::span<const T> data)
        : constant(type, std::forward<TShape>(shape), std::as_bytes(data))
    {
    }
    template <class TShape, class T>
    constant(datatype_t type, TShape &&shape, gsl::span<const T> data)
        : constant(type, std::forward<TShape>(shape), gsl::as_bytes(data))
    {
    }
    template <class TShape, class T>
    constant(datatype_t type, TShape &&shape, std::span<T> data)
        : constant(type, std::forward<TShape>(shape), std::as_bytes(data))
    {
    }
    template <class TShape, class T>
    constant(datatype_t type, TShape &&shape, gsl::span<T> data)
        : constant(type, std::forward<TShape>(shape), gsl::as_bytes(data))
    {
    }
    template <class TShape, class T>
    constant(datatype_t type, TShape &&shape, const std::vector<T> &data)
        : constant(type, std::forward<TShape>(shape), std::as_bytes(std::span<const T>(data)))
    {
    }
    /** @brief Primary constructor: forwards data_args to the byte vector.
     *  @throws std::invalid_argument when the byte count does not match the shape. */
    template <class TShape, class... TDataArgs>
    constant(datatype_t type, TShape &&shape, TDataArgs... data_args)
        : data_(std::forward<TDataArgs>(data_args)...), datatype_(type)
    {
        if (ir::get_bytes(type, shape) != data_.size())
            throw std::invalid_argument("Shape and data size don't match");
        add_output("output", type, std::forward<TShape>(shape), mem_rdata)
            .attributes(cnctr_attr_no_layout_strides);
    }
    /** @brief Scalar constant: shape {1}, type deduced from TScalar. */
    template <class TScalar>
    constant(TScalar scalar)
        : constant(to_datatype<TScalar>(), shape_t { 1 }, std::span<const TScalar>(&scalar, 1))
    {
    }

    /** @brief Render single-element constants as their value; everything else as "[...]".
     *  @throws std::runtime_error for scalar dtypes without a to_string mapping. */
    std::string to_string() const
    {
        auto shape = this->output().shape();
        auto dtype = this->output().type();
        // size_t, not int: a large shape's element count would overflow int.
        size_t total_size = 1;
        for (auto i : shape)
        {
            total_size *= i;
        }
        if (total_size == 1)
        {
            switch (dtype)
            {
#define DT_TO_STRING_CASE(dt) \
    case dt:                  \
        return std::to_string(*(to_cpp_type_t<dt> *)data_.data());
                // dt_int8 / dt_uint8 folded into the macro (previously duplicated by hand).
                DT_TO_STRING_CASE(dt_int8);
                DT_TO_STRING_CASE(dt_uint8);
                DT_TO_STRING_CASE(dt_uint32);
                DT_TO_STRING_CASE(dt_float32);
                DT_TO_STRING_CASE(dt_bfloat16);
                DT_TO_STRING_CASE(dt_int32);
#undef DT_TO_STRING_CASE
            default:
                // Fix: previously threw a std::string (not derived from std::exception),
                // which generic catch (const std::exception &) handlers cannot catch.
                throw std::runtime_error("unsupported dtype to_string: " + std::string(nncase::datatype_names(dtype)));
            }
        }
        else
        {
            return "[...]";
        }
    }

protected:
    bool properties_equal(node &other) const override;

private:
    std::vector<std::byte> data_; // owned constant payload
    datatype_t datatype_;
    size_t alignment_ = 8;
};
}

View File

@ -0,0 +1,62 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xarray.hpp>

namespace nncase::ir
{
/** @brief 2D convolution node: input(0), weights(1), bias(2) -> output(0). */
class NNCASE_API conv2d : public node
{
public:
    DEFINE_NODE_OPCODE(op_conv2d);

    const input_connector &weights() const { return input_at(1); }
    input_connector &input() { return input_at(0); }
    input_connector &weights() { return input_at(1); }
    input_connector &bias() { return input_at(2); }
    output_connector &output() { return output_at(0); }

    // The accessors below assume weights laid out as
    // [out_channels, in_channels / groups, filter_h, filter_w].
    int32_t filter_h() const noexcept { return (int32_t)weights().shape()[2]; }
    int32_t filter_w() const noexcept { return (int32_t)weights().shape()[3]; }
    int32_t input_channels() const noexcept { return (int32_t)weights().shape()[1] * groups(); }
    int32_t output_channels() const noexcept { return (int32_t)weights().shape()[0]; }
    int32_t groups() const noexcept { return groups_; }

    /** @brief Depthwise convolution: one group per channel, with more than one group. */
    bool is_depthwise() const noexcept { return input_channels() == output_channels() && output_channels() == groups() && groups() != 1; }

    padding padding_h() const noexcept { return padding_h_; }
    padding padding_w() const noexcept { return padding_w_; }
    int32_t stride_h() const noexcept { return stride_h_; }
    int32_t stride_w() const noexcept { return stride_w_; }
    int32_t dilation_h() const noexcept { return dilation_h_; }
    int32_t dilation_w() const noexcept { return dilation_w_; }

    /** @brief Output value range fused into this node (activation clamp). */
    value_range<float> fused_activation() const noexcept { return fused_activation_; }

    conv2d(shape_t input_shape, shape_t weights_shape, int32_t groups, padding padding_h, padding padding_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w, value_range<float> fused_activation);

protected:
    bool properties_equal(node &other) const override;

private:
    int32_t groups_;
    padding padding_h_;
    padding padding_w_;
    int32_t stride_h_;
    int32_t stride_w_;
    int32_t dilation_h_;
    int32_t dilation_w_;
    value_range<float> fused_activation_;
};
}

View File

@ -0,0 +1,65 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Transposed (fractionally-strided) 2D convolution node:
 *  input(0), weights(1), bias(2) -> output(0). */
class NNCASE_API conv2d_transpose : public node
{
public:
    DEFINE_NODE_OPCODE(op_conv2d_transpose);

    const input_connector &weights() const { return input_at(1); }
    input_connector &input() { return input_at(0); }
    input_connector &weights() { return input_at(1); }
    input_connector &bias() { return input_at(2); }
    output_connector &output() { return output_at(0); }

    // The accessors below assume weights laid out as
    // [out_channels, in_channels / groups, filter_h, filter_w].
    int32_t filter_h() const noexcept { return (int32_t)weights().shape()[2]; }
    int32_t filter_w() const noexcept { return (int32_t)weights().shape()[3]; }
    int32_t input_channels() const noexcept { return (int32_t)weights().shape()[1] * groups(); }
    int32_t output_channels() const noexcept { return (int32_t)weights().shape()[0]; }
    int32_t groups() const noexcept { return groups_; }

    padding padding_h() const noexcept { return padding_h_; }
    padding padding_w() const noexcept { return padding_w_; }
    /** @brief Extra rows/cols added to one side of the output to disambiguate its size. */
    int32_t output_padding_h() const noexcept { return output_padding_h_; }
    int32_t output_padding_w() const noexcept { return output_padding_w_; }
    int32_t stride_h() const noexcept { return stride_h_; }
    int32_t stride_w() const noexcept { return stride_w_; }
    int32_t dilation_h() const noexcept { return dilation_h_; }
    int32_t dilation_w() const noexcept { return dilation_w_; }

    /** @brief Output value range fused into this node (activation clamp). */
    value_range<float> fused_activation() const noexcept { return fused_activation_; }

    // Takes an explicit output_shape because the output size of a transposed
    // convolution is not uniquely determined by the input parameters.
    conv2d_transpose(shape_t input_shape, shape_t weights_shape, shape_t output_shape, int32_t groups, padding padding_h, padding padding_w, int32_t output_padding_h, int32_t output_padding_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w, value_range<float> fused_activation);

protected:
    bool properties_equal(node &other) const override;

private:
    int32_t groups_;
    padding padding_h_;
    padding padding_w_;
    int32_t output_padding_h_;
    int32_t output_padding_w_;
    int32_t stride_h_;
    int32_t stride_w_;
    int32_t dilation_h_;
    int32_t dilation_w_;
    value_range<float> fused_activation_;
};
}

View File

@ -0,0 +1,39 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Convert node: casts the input's elements to new_type. */
class NNCASE_API convert : public node
{
public:
    DEFINE_NODE_OPCODE(op_convert);

    input_connector &input() { return input_at(0); }
    output_connector &output() { return output_at(0); }

    /** @brief Element type of the output. */
    datatype_t new_type() const noexcept { return new_type_; }

    convert(datatype_t input_type, shape_t input_shape, datatype_t new_type);

protected:
    bool properties_equal(node &other) const override;

private:
    datatype_t new_type_;
};
}

View File

@ -0,0 +1,35 @@
/* Copyright 2019-2020 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Copy node: passes the input through unchanged (same type and shape). */
class NNCASE_API copy : public node
{
public:
    DEFINE_NODE_OPCODE(op_copy);

    const input_connector &input() const { return input_at(0); }
    input_connector &input() { return input_at(0); }
    output_connector &output() { return output_at(0); }

    copy(datatype_t input_type, shape_t input_shape);

protected:
    bool properties_equal(node &other) const override;
};
}

View File

@ -0,0 +1,42 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"

namespace nncase::ir
{
/** @brief Cumulative sum of the input along one axis. */
class NNCASE_API cumsum : public node
{
public:
    DEFINE_NODE_OPCODE(op_cumsum);

    input_connector &input() { return input_at(0); }
    output_connector &output() { return output_at(0); }

    /** @brief Axis the running sum is taken along. */
    int32_t axis() const noexcept { return axis_; }
    // exclusive/reverse presumably follow ONNX CumSum semantics (exclusive scan,
    // reversed direction) — confirm against the kernel implementation.
    bool exclusive() const noexcept { return exclusive_; }
    bool reverse() const noexcept { return reverse_; }

    cumsum(datatype_t input_type, shape_t input_shape, int32_t axis, bool exclusive = false, bool reverse = false);

protected:
    bool properties_equal(node &other) const override;

private:
    int32_t axis_;
    bool exclusive_;
    bool reverse_;
};
}

View File

@ -0,0 +1,39 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Dequantize node: converts a quantized input to output_type using quant_param. */
class NNCASE_API dequantize : public node
{
public:
    DEFINE_NODE_OPCODE(op_dequantize);

    input_connector &input() { return input_at(0); }
    output_connector &output() { return output_at(0); }

    /** @brief Quantization parameters used for the conversion. */
    // Fix: was `const quant_param_t quant_param()` — a const-qualified by-value
    // return is meaningless and forces a copy; return by const reference like
    // the other node getters (e.g. bitcast::new_shape).
    const quant_param_t &quant_param() const noexcept { return quant_param_; }

    dequantize(datatype_t input_type, shape_t input_shape, datatype_t output_type, quant_param_t quant_param);

protected:
    bool properties_equal(node &other) const override;

private:
    quant_param_t quant_param_;
};
}

View File

@ -0,0 +1,37 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Element-wise equality comparison of two inputs. */
class NNCASE_API equal : public node
{
public:
    DEFINE_NODE_OPCODE(op_equal);

    input_connector &input_a() { return input_at(0); }
    input_connector &input_b() { return input_at(1); }
    output_connector &output() { return output_at(0); }

    equal(datatype_t input_type, shape_t input_a_shape, shape_t input_b_shape);

protected:
    bool properties_equal(node &other) const override;
    // Fix: removed a vestigial empty `private:` section (no members).
};
}

View File

@ -0,0 +1,150 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../graph.h"
#include "../node.h"
#include <nncase/codegen/nnil_builder.h>
#include <nncase/runtime/nnil.h>
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
// Instruction kinds of the small expression language a fused_unary node carries.
enum fused_unary_opcode
{
    fu_constant,
    fu_identity,
    fu_ldx,
    fu_unary,
    fu_binary,
    fu_clamp
};

// Reference to another op in the same subgraph, by its index (op_id).
struct fused_unary_arg
{
    size_t op_id;
};

// Payload: a literal float value.
struct fused_unary_constant
{
    float value;
};

// Payload: forwards its single input unchanged.
struct fused_unary_identity
{
    fused_unary_arg input;
};

// Payload: loads the node's input value ("x"); carries no data.
struct fused_unary_ldx
{
};

// Payload: a unary op applied to one subgraph value.
struct fused_unary_unary
{
    unary_op_t unary_op;
    fused_unary_arg input;
};

// Payload: a binary op applied to two subgraph values.
struct fused_unary_binary
{
    binary_op_t binary_op;
    fused_unary_arg input_a;
    fused_unary_arg input_b;
};

// Payload: clamps input between low and high (all subgraph values).
struct fused_unary_clamp
{
    fused_unary_arg input;
    fused_unary_arg low;
    fused_unary_arg high;
};

/** @brief One instruction of a fused_unary subgraph.
 *  `opcode` selects which union member is active. Use the make_* factories
 *  below rather than constructing by hand. */
struct fused_unary_op
{
    fused_unary_opcode opcode;
    union
    {
        fused_unary_constant constant;
        fused_unary_identity identity;
        fused_unary_ldx ldx;
        fused_unary_unary unary;
        fused_unary_binary binary;
        fused_unary_clamp clamp;
    };

    // Factories: value-initialize the union, then set the active member.
    static fused_unary_op make_ldx() noexcept
    {
        fused_unary_op op { fu_ldx, {} };
        return op;
    }
    static fused_unary_op make_constant(float value) noexcept
    {
        fused_unary_op op { fu_constant, {} };
        op.constant.value = value;
        return op;
    }
    static fused_unary_op make_unary(unary_op_t unary_op, fused_unary_arg input) noexcept
    {
        fused_unary_op op { fu_unary, {} };
        op.unary = { unary_op, input };
        return op;
    }
    static fused_unary_op make_binary(binary_op_t binary_op, fused_unary_arg input_a, fused_unary_arg input_b) noexcept
    {
        fused_unary_op op { fu_binary, {} };
        op.binary = { binary_op, input_a, input_b };
        return op;
    }
    static fused_unary_op make_clamp(fused_unary_arg input, fused_unary_arg low, fused_unary_arg high) noexcept
    {
        fused_unary_op op { fu_clamp, {} };
        op.clamp = { input, low, high };
        return op;
    }
    static fused_unary_op make_identity(fused_unary_arg input) noexcept
    {
        fused_unary_op op { fu_identity, {} };
        op.identity = { input };
        return op;
    }
};

// Concatenates two subgraphs into one sequence.
// NOTE(review): presumably the ops of src2 are re-indexed to follow src1 —
// confirm against the implementation.
NNCASE_API std::vector<fused_unary_op> concat_subgraph(const std::vector<fused_unary_op> &src1, const std::vector<fused_unary_op> &src2);

/** @brief Node holding a fused element-wise expression over one input,
 *  compiled to NNIL via compile_graph. */
class NNCASE_API fused_unary : public node
{
public:
    /** @brief Emit NNIL instructions for the given subgraph into builder. */
    static void compile_graph(const std::vector<fused_unary_op> &subgraph, codegen::nnil_builder &builder);

    DEFINE_NODE_OPCODE(op_fused_unary);

    input_connector &input() { return input_at(0); }
    output_connector &output() { return output_at(0); }

    /** @brief The expression, as a flat op sequence indexed by fused_unary_arg::op_id. */
    std::vector<fused_unary_op> &subgraph() noexcept { return subgraph_; }

    fused_unary(std::vector<fused_unary_op> subgraph, shape_t in_shape);

protected:
    bool properties_equal(node &other) const override;

private:
    std::vector<fused_unary_op> subgraph_;
};
}

View File

@ -0,0 +1,40 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief Gather node: selects slices of input(0) at indices(1) along axis. */
class NNCASE_API gather : public node
{
public:
    DEFINE_NODE_OPCODE(op_gather);

    input_connector &input() { return input_at(0); }
    input_connector &indices() { return input_at(1); }
    output_connector &output() { return output_at(0); }

    /** @brief Axis of the input that indices index into. */
    int32_t axis() const noexcept { return axis_; }

    gather(datatype_t in_type, shape_t input_shape, shape_t indices_shape, shape_t output_shape, int32_t axis);

protected:
    bool properties_equal(node &other) const override;

private:
    int32_t axis_;
};
}

View File

@ -0,0 +1,40 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
/** @brief GatherND node: gathers slices of input(0) addressed by the index tuples in indices(1). */
class NNCASE_API gather_nd : public node
{
public:
    DEFINE_NODE_OPCODE(op_gather_nd);

    input_connector &input() { return input_at(0); }
    input_connector &indices() { return input_at(1); }
    output_connector &output() { return output_at(0); }

    /** @brief Number of leading batch dimensions excluded from indexing. */
    int32_t batch_dims() const noexcept { return batch_dims_; }

    gather_nd(datatype_t type, shape_t input_shape, shape_t indices_shape, shape_t output_shape, int32_t batch_dims);

protected:
    bool properties_equal(node &other) const override;

private:
    int32_t batch_dims_;
};
}

View File

@ -0,0 +1,37 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node for hardmax: along the given axis, one-hot encodes the position
// of the maximum element (ONNX Hardmax semantics — TODO confirm against kernel).
class NNCASE_API hardmax : public node
{
public:
DEFINE_NODE_OPCODE(op_hardmax);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Axis along which the maximum is selected.
int32_t axis() const noexcept { return axis_; }
hardmax(datatype_t input_type, shape_t input_shape, int32_t axis);
protected:
bool properties_equal(node &other) const override;
private:
int32_t axis_;
};
}

View File

@ -0,0 +1,52 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for an LSTM layer. Inputs carry the gate weights/biases for the
// input path (w_xc/b_xc), the recurrent path (w_rc/b_rc), the initial hidden
// and cell states, and an optional static-path weight (w_static).
class NNCASE_API lstm : public node
{
public:
DEFINE_NODE_OPCODE(op_lstm);
input_connector &input() { return input_at(0); }
// Input-to-cell weights and bias.
input_connector &w_xc() { return input_at(1); }
input_connector &b_xc() { return input_at(2); }
// Recurrent (hidden-to-cell) weights and bias.
input_connector &w_rc() { return input_at(3); }
input_connector &b_rc() { return input_at(4); }
// Initial hidden / cell state tensors.
input_connector &initial_h() { return input_at(5); }
input_connector &initial_c() { return input_at(6); }
// Static-path weights; only meaningful when has_static() is true.
input_connector &w_static() { return input_at(7); }
output_connector &output() { return output_at(0); }
// Final hidden and cell state outputs.
output_connector &output_h() { return output_at(1); }
output_connector &output_c() { return output_at(2); }
// Number of output units of the layer.
int32_t num_output() const noexcept { return num_output_; }
// True when the optional w_static input participates.
bool has_static() const noexcept { return has_static_; }
// Source framework tag (e.g. importer name); affects lowering conventions
// — TODO confirm exact accepted values against the importers.
std::string framework() const noexcept { return framework_; }
lstm(shape_t input_shape, shape_t w_xc_shape, shape_t b_xc_shape, shape_t w_rc_shape, shape_t b_rc_shape, shape_t initial_h_shape, shape_t initial_c_shape, int32_t num_output, bool has_static, std::string framework);
protected:
bool properties_equal(node &other) const override;
private:
int32_t num_output_;
bool has_static_;
std::string framework_;
};
}

View File

@ -0,0 +1,41 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for matrix multiplication with bias and a fused activation clamp:
// output = clamp(input_a @ input_b + bias, fused_activation).
class NNCASE_API matmul : public node
{
public:
DEFINE_NODE_OPCODE(op_matmul);
input_connector &input_a() { return input_at(0); }
input_connector &input_b() { return input_at(1); }
input_connector &bias() { return input_at(2); }
output_connector &output() { return output_at(0); }
// Value range the result is clamped to (fused activation).
value_range<float> fused_activation() const noexcept { return fused_activation_; }
matmul(shape_t input_a_shape, shape_t input_b_shape, value_range<float> fused_activation);
protected:
bool properties_equal(node &other) const override;
private:
value_range<float> fused_activation_;
};
}

View File

@ -0,0 +1,44 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for one-hot encoding: expands `indices` along `axis` to `depth`
// entries, writing `on_value` at the indexed position and `off_value` elsewhere.
class NNCASE_API onehot : public node
{
public:
DEFINE_NODE_OPCODE(op_onehot);
input_connector &indices() { return input_at(0); }
// Scalar: size of the new one-hot dimension.
input_connector &depth() { return input_at(1); }
input_connector &on_value() { return input_at(2); }
input_connector &off_value() { return input_at(3); }
output_connector &output() { return output_at(0); }
// Axis at which the one-hot dimension is inserted.
int32_t axis() const noexcept { return axis_; }
// Encoding variant; defaults to onehot_normal.
onehot_mode_t mode() const noexcept { return mode_; }
onehot(datatype_t type, shape_t indices_shape, shape_t output_shape, int32_t axis, onehot_mode_t mode = onehot_mode_t::onehot_normal);
protected:
bool properties_equal(node &other) const override;
private:
int32_t axis_;
onehot_mode_t mode_;
};
}

View File

@ -0,0 +1,43 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for tensor padding: applies per-dimension before/after paddings
// with the given mode; `pad_value` is used for constant-mode padding.
class NNCASE_API pad : public node
{
public:
DEFINE_NODE_OPCODE(op_pad);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// One padding (before/after pair) per input dimension.
const xt::svector<padding> &paddings() const noexcept { return paddings_; }
// Padding mode (constant / reflect / etc. per pad_mode_t).
pad_mode_t pad_mode() const noexcept { return pad_mode_; }
// Fill value; relevant for constant padding.
const scalar &pad_value() const noexcept { return pad_value_; }
pad(datatype_t type, shape_t input_shape, xt::svector<padding> paddings, pad_mode_t pad_mode, scalar pad_value);
protected:
bool properties_equal(node &other) const override;
private:
xt::svector<padding> paddings_;
pad_mode_t pad_mode_;
scalar pad_value_;
};
}

View File

@ -0,0 +1,39 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node converting a tensor from `input_type` to a quantized `output_type`
// using the supplied quantization parameters (scale/zero-point).
class NNCASE_API quantize : public node
{
public:
DEFINE_NODE_OPCODE(op_quantize);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Quantization parameters applied to the input.
// NOTE(review): returns a const value (not a const reference) — the
// top-level const has no effect on a by-value return; consider
// `const quant_param_t &` for consistency with other getters.
const quant_param_t quant_param() const noexcept { return quant_param_; }
quantize(datatype_t input_type, shape_t input_shape, datatype_t output_type, quant_param_t quant_param);
protected:
bool properties_equal(node &other) const override;
private:
quant_param_t quant_param_;
};
}

View File

@ -0,0 +1,40 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node producing a tensor of pseudo-random values drawn from a normal
// distribution with the given mean and standard deviation (no inputs).
class NNCASE_API random_normal : public node
{
public:
DEFINE_NODE_OPCODE(op_random_normal);
output_connector &output() { return output_at(0); }
float mean() const noexcept { return mean_; }
float std() const noexcept { return std_; }
// RNG seed. NOTE(review): stored as float — presumably mirrors the ONNX
// RandomNormal `seed` attribute, which is a float; verify against importer.
float seed() const noexcept { return seed_; }
random_normal(datatype_t output_type, shape_t output_shape, float mean, float std, float seed);
protected:
bool properties_equal(node &other) const override;
private:
float mean_;
float std_;
float seed_;
};
}

View File

@ -0,0 +1,40 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node producing a tensor of pseudo-random values drawn uniformly from
// [low, high) (no inputs).
class NNCASE_API random_uniform : public node
{
public:
DEFINE_NODE_OPCODE(op_random_uniform);
output_connector &output() { return output_at(0); }
// Lower / upper bounds of the sampled interval.
float low() const noexcept { return low_; }
float high() const noexcept { return high_; }
// RNG seed. NOTE(review): float, matching the ONNX RandomUniform `seed`
// attribute type — verify against importer.
float seed() const noexcept { return seed_; }
random_uniform(datatype_t output_type, shape_t output_shape, float low, float high, float seed);
protected:
bool properties_equal(node &other) const override;
private:
float low_;
float high_;
float seed_;
};
}

View File

@ -0,0 +1,45 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for an axis-wise reduction (sum/mean/min/max per reduce_op_t),
// folding the listed axes starting from `init_value`.
class NNCASE_API reduce : public node
{
public:
DEFINE_NODE_OPCODE(op_reduce);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Which reduction to perform.
reduce_op_t reduce_op() const noexcept { return reduce_op_; }
// Axes that are reduced over.
const axis_t &axis() const noexcept { return axis_; }
// Accumulator's initial value.
float init_value() const noexcept { return init_value_; }
// When true, reduced axes are kept with size 1 instead of being dropped.
bool keep_dims() const noexcept { return keep_dims_; }
reduce(reduce_op_t reduce_op, shape_t input_shape, axis_t axis, float init_value, bool keep_dims);
protected:
bool properties_equal(node &other) const override;
private:
reduce_op_t reduce_op_;
axis_t axis_;
float init_value_;
bool keep_dims_;
};
}

View File

@ -0,0 +1,44 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node for argmin/argmax-style reductions: returns the index of the
// extreme element along `axis` (variant selected by reduce_arg_op_t).
class NNCASE_API reduce_arg : public node
{
public:
DEFINE_NODE_OPCODE(op_reduce_arg);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Which arg-reduction (e.g. argmin / argmax) to perform.
reduce_arg_op_t reduce_arg_op() const noexcept { return reduce_arg_op_; }
// Axis the indices are computed along.
int32_t axis() const noexcept { return axis_; }
// When true, the reduced axis is kept with size 1.
bool keep_dims() const noexcept { return keep_dims_; }
// When true, ties resolve to the last occurrence instead of the first.
bool select_last_index() const noexcept { return select_last_index_; }
reduce_arg(reduce_arg_op_t op, datatype_t input_type, shape_t input_shape, datatype_t output_type, int32_t axis, bool keep_dims = true, bool select_last_index = false);
protected:
bool properties_equal(node &other) const override;
private:
reduce_arg_op_t reduce_arg_op_;
int32_t axis_;
bool keep_dims_;
bool select_last_index_;
};
}

View File

@ -0,0 +1,40 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node computing the product of elements over the listed axes.
// Kept as a dedicated op (rather than a reduce variant) — presumably because
// product has no meaningful float init_value/fusion path; confirm with lowering.
class NNCASE_API reduce_prod : public node
{
public:
DEFINE_NODE_OPCODE(op_reduce_prod);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Axes that are multiplied over.
const axis_t &axis() const noexcept { return axis_; }
// When true, reduced axes are kept with size 1.
bool keep_dims() const noexcept { return keep_dims_; }
reduce_prod(shape_t input_shape, axis_t axis, bool keep_dims);
protected:
bool properties_equal(node &other) const override;
private:
axis_t axis_;
bool keep_dims_;
};
}

View File

@ -0,0 +1,67 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for 2-D windowed reductions (pooling): slides a filter_h x filter_w
// window over the input with the given strides/dilations/paddings and applies
// `reduce_op` within each window, then clamps with the fused activation.
class NNCASE_API reduce_window2d : public node
{
public:
DEFINE_NODE_OPCODE(op_reduce_window2d);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Reduction applied inside each window (e.g. max / mean).
reduce_op_t reduce_op() const noexcept { return reduce_op_; }
// Accumulator's initial value for each window.
float init_value() const noexcept { return init_value_; }
// Window extent.
int32_t filter_h() const noexcept { return filter_h_; }
int32_t filter_w() const noexcept { return filter_w_; }
// Per-dimension before/after paddings.
padding padding_h() const noexcept { return padding_h_; }
padding padding_w() const noexcept { return padding_w_; }
// Window step.
int32_t stride_h() const noexcept { return stride_h_; }
int32_t stride_w() const noexcept { return stride_w_; }
// Spacing between filter taps.
int32_t dilation_h() const noexcept { return dilation_h_; }
int32_t dilation_w() const noexcept { return dilation_w_; }
// Value range the result is clamped to.
value_range<float> fused_activation() const noexcept { return fused_activation_; }
// When true, output size is computed with ceiling division (TF/ONNX ceil_mode).
bool ceil_mode() const noexcept { return ceil_mode_; }
// When true, padded elements are counted in mean-style reductions.
bool count_include_pad() const noexcept { return count_include_pad_; }
// Extra trailing {h, w} padding applied after the regular paddings.
std::vector<int32_t> padding_h_w_after() const noexcept { return padding_h_w_after_; }
// When true, windows are constrained to lie strictly inside the input
// — TODO confirm exact semantics against the kernel implementation.
bool strict_inside_input() const noexcept { return strict_inside_input_; }
reduce_window2d(reduce_op_t reduce_op, shape_t input_shape, float init_value, int32_t filter_h, int32_t filter_w, padding padding_h, padding padding_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w, value_range<float> fused_activation, bool ceil_mode = false, bool count_include_pad = false, std::vector<int32_t> padding_h_w_after = { 0, 0 }, bool strict_inside_input = false);
protected:
bool properties_equal(node &other) const override;
private:
reduce_op_t reduce_op_;
float init_value_;
int32_t filter_h_;
int32_t filter_w_;
padding padding_h_;
padding padding_w_;
int32_t stride_h_;
int32_t stride_w_;
int32_t dilation_h_;
int32_t dilation_w_;
value_range<float> fused_activation_;
bool ceil_mode_;
bool count_include_pad_;
std::vector<int32_t> padding_h_w_after_;
bool strict_inside_input_;
};
}

View File

@ -0,0 +1,45 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node resizing the spatial dimensions of an image tensor to `new_size`
// using the interpolation given by `mode` (e.g. nearest / bilinear).
class NNCASE_API resize_image : public node
{
public:
DEFINE_NODE_OPCODE(op_resize_image);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Target spatial size; presumably {height, width} — confirm against kernel.
const std::array<int32_t, 2> &new_size() const noexcept { return new_size_; }
// Interpolation mode.
image_resize_mode_t mode() const noexcept { return mode_; }
// Coordinate-transformation flags (TF-style align_corners / half_pixel).
bool align_corners() const noexcept { return align_corners_; }
bool half_pixel_centers() const noexcept { return half_pixel_centers_; }
resize_image(datatype_t type, image_resize_mode_t mode, shape_t input_shape, std::array<int32_t, 2> new_size,
bool align_corners = false, bool half_pixel_centers = false);
protected:
bool properties_equal(node &other) const override;
private:
std::array<int32_t, 2> new_size_;
image_resize_mode_t mode_;
bool align_corners_;
bool half_pixel_centers_;
};
}

View File

@ -0,0 +1,46 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node for ROI Align (ONNX RoiAlign-style): pools fixed-size features
// from `input` for each region of interest in `rois`, with the batch element
// of each ROI given by `batch_indices`.
class NNCASE_API roi_align : public node
{
public:
DEFINE_NODE_OPCODE(op_roi_align);
input_connector &input() { return input_at(0); }
// Region-of-interest boxes.
input_connector &rois() { return input_at(1); }
// Batch index of each ROI.
input_connector &batch_indices() { return input_at(2); }
output_connector &output() { return output_at(0); }
// Pooling mode (e.g. avg / max per roi_align_mode_t).
roi_align_mode_t mode() const noexcept { return mode_; }
// Scale from ROI coordinates to input feature-map coordinates.
const float &spatial_scale() const noexcept { return spatial_scale_; }
// Sampling points per output bin; semantics follow ONNX RoiAlign
// (non-positive means adaptive) — TODO confirm against kernel.
const int64_t &sampling_ratio() const noexcept { return sampling_ratio_; }
roi_align(datatype_t input_type, shape_t input_shape, shape_t rois, shape_t batch_indices, roi_align_mode_t mode,
float spatial_scale, int64_t output_height, int64_t output_width, int64_t sampling_ratio);
protected:
bool properties_equal(node &other) const override;
private:
roi_align_mode_t mode_;
float spatial_scale_;
int64_t sampling_ratio_;
};
}

View File

@ -13,18 +13,23 @@
* limitations under the License.
*/
#pragma once
#include "egraph.h"
#include "../node.h"
namespace nncase::ir::transforms {
class NNCASE_API egraph_pattern {
public:
virtual bool match() = 0;
namespace nncase::ir
{
class NNCASE_API sigmoid : public node
{
public:
DEFINE_NODE_OPCODE(op_sigmoid);
private:
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
sigmoid(datatype_t input_type, shape_t input_shape);
protected:
bool properties_equal(node &other) const override;
private:
};
namespace patterns {
class NNCASE_API wildcard : public egraph_pattern {};
} // namespace patterns
} // namespace nncase::ir::transforms
}

View File

@ -0,0 +1,52 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for strided slicing (TF StridedSlice-style): extracts the region
// [begin, end) per axis with optional strides and TF-style bitmasks.
class NNCASE_API slice : public node
{
public:
DEFINE_NODE_OPCODE(op_slice);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Per-axis start / stop indices and step sizes.
const axis_t &begin() const noexcept { return begin_; }
const axis_t &end() const noexcept { return end_; }
const axis_t &strides() const noexcept { return strides_; }
// Bitmasks (one bit per axis) with TF StridedSlice semantics:
// ignore begin/end, expand ellipsis, insert new axes.
int32_t begin_mask() const noexcept { return begin_mask_; }
int32_t end_mask() const noexcept { return end_mask_; }
int32_t ellipsis_mask() const noexcept { return ellipsis_mask_; }
int32_t new_axis_mask() const noexcept { return new_axis_mask_; }
// Convenience overload: unit strides and no masks — see the .cpp definition.
slice(datatype_t type, shape_t input_shape, axis_t begin, axis_t end);
slice(datatype_t type, shape_t input_shape, axis_t begin, axis_t end, axis_t strides, int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, int32_t new_axis_mask);
protected:
bool properties_equal(node &other) const override;
private:
axis_t begin_;
axis_t end_;
axis_t strides_;
int32_t begin_mask_;
int32_t end_mask_;
int32_t ellipsis_mask_;
int32_t new_axis_mask_;
};
}

View File

@ -0,0 +1,47 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xarray.hpp>
namespace nncase::ir
{
// IR node for SpaceToBatch: pads the spatial dimensions, then rearranges
// block_size_h x block_size_w spatial blocks into the batch dimension.
class NNCASE_API space_to_batch : public node
{
public:
DEFINE_NODE_OPCODE(op_space_to_batch);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// Spatial block sizes moved into batch.
// NOTE(review): constructor parameters are named block_shape_h/_w while the
// members/getters say block_size — consider unifying the naming.
int32_t block_size_h() const noexcept { return block_size_h_; }
int32_t block_size_w() const noexcept { return block_size_w_; }
// Spatial paddings applied before the rearrangement.
padding padding_h() const noexcept { return padding_h_; }
padding padding_w() const noexcept { return padding_w_; }
// Fill value for the padded region.
const scalar &pad_value() const noexcept { return pad_value_; }
space_to_batch(datatype_t input_type, shape_t input_shape, int32_t block_shape_h, int32_t block_shape_w, padding padding_h, padding padding_w, scalar pad_value);
protected:
bool properties_equal(node &other) const override;
private:
int32_t block_size_h_;
int32_t block_size_w_;
padding padding_h_;
padding padding_w_;
scalar pad_value_;
};
}

View File

@ -0,0 +1,42 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node splitting a tensor along `axis` into multiple outputs, either at
// explicit indices or into sections (numpy.split-style, per is_indices()).
class NNCASE_API split : public node
{
public:
DEFINE_NODE_OPCODE(op_split);
input_connector &input() { return input_at(0); }
// NOTE(review): only output 0 is exposed here although the constructor
// takes a vector of output shapes; remaining outputs are presumably
// reached via output_at(i) — confirm with callers.
output_connector &output() { return output_at(0); }
// Split points (when is_indices()) or section sizes/count otherwise.
std::vector<size_t> indices_or_sections() const noexcept { return indices_or_sections_; }
// Axis the tensor is split along.
int32_t axis() const noexcept { return axis_; }
// Selects the interpretation of indices_or_sections().
bool is_indices() const noexcept { return is_indices_; }
split(datatype_t type, shape_t inputs_shape, std::vector<shape_t> outputs_shape, std::vector<size_t> indices_or_sections, int32_t axis, bool is_indices);
protected:
bool properties_equal(node &other) const override;
private:
std::vector<size_t> indices_or_sections_;
int32_t axis_;
bool is_indices_;
};
}

View File

@ -0,0 +1,35 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for a 1-D table lookup: each input element indexes into `table`
// (typical for quantized activation LUTs — TODO confirm intended use).
class NNCASE_API table_lookup1d : public node
{
public:
DEFINE_NODE_OPCODE(op_table_lookup1d);
input_connector &input() { return input_at(0); }
// Lookup table of `table_size` entries.
input_connector &table() { return input_at(1); }
output_connector &output() { return output_at(0); }
table_lookup1d(datatype_t type, shape_t input_shape, size_t table_size);
protected:
// No attributes beyond the connectors, so any two instances match.
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
};
}

View File

@ -0,0 +1,42 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// IR node for numpy.take-style gathering: selects elements of `input` along
// `axis` at positions given by `indices`, with out-of-range handling chosen
// by `mode` (presumably numpy's "raise"/"wrap"/"clip" — confirm with kernel).
class NNCASE_API take : public node
{
public:
DEFINE_NODE_OPCODE(op_take);
input_connector &input() { return input_at(0); }
input_connector &indices() { return input_at(1); }
output_connector &output() { return output_at(0); }
// Axis the elements are taken along.
int32_t axis() const noexcept { return axis_; }
// Out-of-range index policy, stored as a string.
const std::string &mode() const noexcept { return mode_; }
take(datatype_t type, shape_t input_shape, shape_t indices_shape, shape_t output_shape, int32_t axis, std::string mode);
protected:
bool properties_equal(node &other) const override;
private:
int32_t axis_;
std::string mode_;
};
}

View File

@ -13,24 +13,26 @@
* limitations under the License.
*/
#pragma once
#include "../call.h"
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir::nn {
/** @brief Sigmoid operator node */
class NNCASE_API sigmoid_node : public op_node {
DEFINE_OBJECT_KIND(op_node, op_nn_sigmoid)
public:
sigmoid_node();
namespace nncase::ir
{
class NNCASE_API ternary : public node
{
public:
DEFINE_NODE_OPCODE(op_ternary);
/** @brief Get the input the sigmoid expression */
const connector_info &input() const noexcept { return parameter_at(0); }
input_connector &input_a() { return input_at(0); }
input_connector &input_b() { return input_at(1); }
input_connector &input_c() { return input_at(2); }
output_connector &output() { return output_at(0); }
type infer_invoke_result_type(type_infer_context &context) override;
ternary(datatype_t input_a_type, datatype_t input_bc_type, shape_t input_a_shape, shape_t input_b_shape, shape_t input_c_shape);
protected:
bool properties_equal(node &other) const override;
private:
};
/** @brief Sigmoid expression */
using sigmoid = object_t<sigmoid_node>;
} // namespace nncase::ir::nn
}

View File

@ -0,0 +1,49 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node for TopK: returns the k largest (or smallest, per largest()) values
// along `axis` together with their indices.
class NNCASE_API topk : public node
{
public:
DEFINE_NODE_OPCODE(op_topk);
input_connector &input() { return input_at(0); }
// output largest values
output_connector &output_a() { return output_at(0); }
// output indices of largest values
output_connector &output_b() { return output_at(1); }
// Number of elements to select.
const int64_t &k() const noexcept { return k_; }
// Axis the selection runs along.
const int32_t &axis() const noexcept { return axis_; }
// Select the largest (true) or smallest (false) k elements.
bool largest() const noexcept { return largest_; }
// When true, results are returned in sorted order.
bool sorted() const noexcept { return sorted_; }
topk(datatype_t input_type, shape_t input_shape, int64_t k, int32_t axis, bool largest, bool sorted);
protected:
bool properties_equal(node &other) const override;
private:
int64_t k_;
int32_t axis_;
bool largest_;
bool sorted_;
};
}

View File

@ -13,17 +13,27 @@
* limitations under the License.
*/
#pragma once
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir::tensors {
/** @brief Transpose operator node */
class NNCASE_API transpose_node : public op_node {
DEFINE_OBJECT_KIND(op_node, op_tensors_transpose)
public:
transpose_node();
namespace nncase::ir
{
class NNCASE_API transpose : public node
{
public:
DEFINE_NODE_OPCODE(op_transpose);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
const axis_t &perm() const noexcept { return perm_; }
transpose(datatype_t type, shape_t input_shape, axis_t perm);
protected:
bool properties_equal(node &other) const override;
private:
axis_t perm_;
};
using transpose = object_t<transpose_node>;
} // namespace nncase::ir::tensors
}

View File

@ -0,0 +1,40 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
namespace nncase::ir
{
// IR node for Trilu (ONNX-style): keeps the upper or lower triangular part
// of the innermost 2-D matrices, with diagonal offset `k`.
class NNCASE_API trilu : public node
{
public:
DEFINE_NODE_OPCODE(op_trilu);
input_connector &input() { return input_at(0); }
output_connector &output() { return output_at(0); }
// True keeps the upper triangle; false keeps the lower triangle.
bool upper() const noexcept { return upper_; }
// Diagonal offset relative to the main diagonal.
const int64_t &k() const noexcept { return k_; }
trilu(datatype_t input_type, shape_t input_shape, bool upper, int64_t k);
protected:
bool properties_equal(node &other) const override;
private:
bool upper_;
int64_t k_;
};
}

View File

@ -0,0 +1,39 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
// Element-wise unary operation IR node (the concrete operation is selected
// by the unary_op_t tag, e.g. neg/abs/exp — see unary_op_t's definition).
class NNCASE_API unary : public node
{
public:
DEFINE_NODE_OPCODE(op_unary);
// Input tensor connector.
input_connector &input() { return input_at(0); }
// Output tensor connector.
output_connector &output() { return output_at(0); }
// Which unary operation this node performs.
unary_op_t unary_op() const noexcept { return unary_op_; }
unary(unary_op_t unary_op, shape_t input_shape);
protected:
// Equality is determined by the unary_op_ tag (see .cpp).
bool properties_equal(node &other) const override;
private:
unary_op_t unary_op_;
};
}

View File

@ -0,0 +1,90 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "node.h"
namespace nncase::ir
{
// Graph input placeholder: a source node with a single output and no inputs.
class NNCASE_API input_node : public node
{
public:
DEFINE_NODE_OPCODE(op_input_node);
// The connector downstream nodes read the graph input from.
output_connector &output() { return output_at(0); }
template <class TShape>
input_node(datatype_t type, TShape &&shape)
{
// Inputs live in the dedicated input memory region and must keep
// the default (dense) layout, hence cnctr_attr_no_layout_strides.
add_output("output", type, std::forward<TShape>(shape), mem_input)
.attributes(cnctr_attr_no_layout_strides);
}
protected:
// Input nodes carry no distinguishing properties beyond the connector
// type/shape, so any two are property-equal.
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
};
// Graph output placeholder: a sink node with a single input and no outputs.
class NNCASE_API output_node : public node
{
public:
DEFINE_NODE_OPCODE(op_output_node);
// The connector that receives the graph's result.
input_connector &input() { return input_at(0); }
template <class TShape>
output_node(datatype_t type, TShape &&shape)
{
// Must never be removed by constant folding — it defines the
// graph's observable result.
attributes(attributes() | node_attr_skip_constant_folding);
add_input("input", type, std::forward<TShape>(shape));
}
protected:
// No distinguishing properties; equality is decided by connectors.
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
};
// Sink node that consumes a value whose result is intentionally discarded
// (keeps the producer wired into the graph without exposing an output).
class NNCASE_API ignore_node : public node
{
public:
DEFINE_NODE_OPCODE(op_ignore_node);
~ignore_node() = default;
// The connector whose value is ignored.
input_connector &input() { return input_at(0); }
template <class TShape>
ignore_node(datatype_t type, TShape &&shape)
{
add_input("input", type, std::forward<TShape>(shape));
}
protected:
// No distinguishing properties; any two ignore nodes are property-equal.
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
};
// Produces a buffer of the given type/shape whose contents are left
// uninitialized, allocated in the requested memory location.
class NNCASE_API uninitialized : public node
{
public:
DEFINE_NODE_OPCODE(op_uninitialized);
// The connector exposing the uninitialized buffer.
output_connector &output() { return output_at(0); }
template <class TShape>
uninitialized(datatype_t type, TShape &&shape, memory_location_t memory_location = mem_data)
{
add_output("output", type, std::forward<TShape>(shape), memory_location);
}
protected:
// No distinguishing properties beyond the output connector.
bool properties_equal([[maybe_unused]] node &other) const override { return true; }
};
}

View File

@ -46,6 +46,7 @@ class NNCASE_API quantizer
void record(std::span<const float> data);
void record(std::span<const bfloat16> data);
void record(std::span<const half> data);
void finish();
value_range<float> optimal_range() const noexcept { return optimal_range_; }
@ -90,27 +91,30 @@ public:
}
else
{
if (range.min < -1e3)
range.min = -1e3;
if (range.max > 1e3)
range.max = 1e3;
if (range.max < 0)
range.max = 0;
if (range.min > 0)
range.min = 0;
auto r = range.max - range.min;
if (r == 0)
r = 0.1f;
else if (r < 0.01f)
r = 0.01f;
range.max = range.min + r;
if (range.max < 0)
range.max = 0;
if (range.min > 0)
range.min = 0;
}
return range;
}
static quant_param_t get_quant_param(value_range<float> range, int32_t bits);
// Quantization scheme selector for get_quant_param's overload taking a mode.
// NOTE(review): exact zero-point/scale derivation per mode is implemented in
// the .cpp — confirm there before relying on these descriptions.
enum class quant_mode
{
unsigned_mode,           // unsigned integer range (e.g. uint8)
signed_symmetric_mode,   // signed, symmetric around zero
signed_asymmetric_mode   // signed, with a non-zero zero-point
};
static quant_param_t get_quant_param(value_range<float> range, int32_t bits, quant_mode qm);
static fixed_mul get_fixed_mul(float value, int32_t max_bits, uint8_t max_shift, bool is_signed);
void record(ir::output_connector &connector, value_range<float> range);
@ -118,6 +122,13 @@ public:
bool has_record(ir::output_connector &connector) const;
void record(ir::output_connector &connector, std::span<const float> data);
void record(ir::output_connector &connector, std::span<const bfloat16> data);
void record(ir::output_connector &connector, std::span<const half> data);
void record_buffers(ir::output_connector &connector, std::span<const float> data);
void record_buffers(ir::output_connector &connector, std::span<const bfloat16> data);
void record_buffers(ir::output_connector &connector, std::span<const half> data);
void record_quant_buffers(ir::output_connector &connector, std::span<const float> data);
void record_quant_buffers(ir::output_connector &connector, std::span<const bfloat16> data);
void record_quant_buffers(ir::output_connector &connector, std::span<const half> data);
value_range<float> get(ir::output_connector &connector) const;
void broadcast_output(ir::graph &graph, const std::unordered_set<node_opcode> &ops);
void broadcast_output(ir::node &node, const value_range<float> &range, const std::unordered_set<node_opcode> &ops);
@ -125,6 +136,10 @@ public:
void end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress);
size_t histograms_count() const noexcept { return histograms_.size(); }
void end_sample() { has_record_.clear(); }
std::unordered_map<ir::output_connector *, std::vector<float>> output_buffers() const noexcept { return output_buffers_; }
std::vector<ir::output_connector *> quant_buffers_insert_order() const noexcept { return quant_buffers_insert_order_; }
std::unordered_map<ir::output_connector *, value_range<float>> ranges() const noexcept { return quant_ranges_; }
std::vector<ir::output_connector *> ranges_insert_order() const noexcept { return ranges_insert_order_; }
private:
calibrate_method cali_method_;
@ -133,5 +148,8 @@ private:
std::unordered_map<ir::output_connector *, value_range<float>> quant_ranges_;
std::unordered_map<ir::output_connector *, histogram> histograms_;
std::unordered_map<ir::output_connector *, bool> has_record_;
std::unordered_map<ir::output_connector *, std::vector<float>> output_buffers_;
std::vector<ir::output_connector *> quant_buffers_insert_order_;
std::vector<ir::output_connector *> ranges_insert_order_;
};
}

View File

@ -1,166 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../object.h"
#include <algorithm>
#include <nncase/runtime/datatypes.h>
#include <nncase/runtime/small_vector.hpp>
#include <optional>
#include <range/v3/algorithm/any_of.hpp>
#include <range/v3/range/concepts.hpp>
namespace nncase::ir {
struct unknown_dim_t {};
inline constexpr unknown_dim_t unknown_dim;
enum dim_kind_t { dim_fixed = 0, dim_unknown = 1 };
using dim_value_t = int64_t;
/** @brief Dimension: either a fixed (known) extent or an unknown extent. */
struct dim_t {
    /** @brief Initialize an unknown dim */
    constexpr dim_t(unknown_dim_t = unknown_dim) noexcept
        : kind(dim_unknown), value(0) {}

    /** @brief Initialize a fixed dim */
    constexpr dim_t(dim_value_t value) noexcept
        : kind(dim_fixed), value(value) {}

    /** @brief Is this a fixed dim */
    bool is_fixed() const noexcept { return kind == dim_fixed; }
    /** @brief Is this an unknown dim */
    bool is_unknown() const noexcept { return kind == dim_unknown; }

    /** @brief Get the fixed value; asserts (debug builds) if the dim is unknown */
    dim_value_t fixed_value() const {
        assert(is_fixed());
        return value;
    }

    dim_kind_t kind;   // fixed or unknown
    dim_value_t value; // extent when kind == dim_fixed; 0 otherwise
};
struct scalar_shape_t {};
inline constexpr scalar_shape_t scalar_shape;
struct unranked_shape_t {};
inline constexpr unranked_shape_t unranked_shape;
struct invalid_shape_t {};
inline constexpr invalid_shape_t invalid_shape;
/** @brief Shape type */
/** @brief Shape type: fixed, partially-unknown (ranked), unranked, or invalid.
 *
 * A scalar is represented as a fixed shape with zero dims. The kind is kept
 * in sync with dims_ by the mutators (dim/push_back/emplace/pop_back).
 */
class NNCASE_API shape_t {
    enum shape_kind_t {
        shape_kind_fixed,           // all dims known (includes scalars)
        shape_kind_has_unknown_dim, // ranked, but >= 1 unknown dim
        shape_kind_unranked,        // rank itself unknown
        shape_kind_invalid          // e.g. result of failed type inference
    };

  public:
    using value_type = dim_t;

    /** @brief Initialize a scalar shape */
    shape_t(scalar_shape_t) noexcept : kind_(shape_kind_fixed) {}

    /** @brief Initialize an unranked shape */
    shape_t(unranked_shape_t) noexcept : kind_(shape_kind_unranked) {}

    /** @brief Initialize an invalid shape */
    shape_t(invalid_shape_t) noexcept : kind_(shape_kind_invalid) {}

    /** @brief Initialize a ranked shape; kind is derived from the dims */
    template <ranges::range R>
    shape_t(R dims) : kind_(kind_of(dims)), dims_(dims.begin(), dims.end()) {}

    /** @brief Initialize a fixed shape from integer extents */
    shape_t(std::initializer_list<dim_value_t> dims) : kind_(shape_kind_fixed) {
        dims_.reserve(dims.size());
        std::transform(dims.begin(), dims.end(), std::back_inserter(dims_),
                       [](dim_value_t dim) -> dim_t { return dim; });
    }

    /** @brief Get kind */
    shape_kind_t kind() const noexcept { return kind_; }

    /** @brief Is this a fixed shape */
    bool is_fixed() const noexcept { return kind() == shape_kind_fixed; }

    /** @brief Is this a scalar (fixed with zero dims) */
    bool is_scalar() const noexcept {
        return kind() == shape_kind_fixed && dims_.empty();
    }

    /** @brief Is this an ranked shape */
    bool is_ranked() const noexcept { return is_fixed() || has_unknown_dim(); }

    /** @brief Is this an unranked shape */
    bool is_unranked() const noexcept { return kind() == shape_kind_unranked; }

    /** @brief Has at least one unknown dimension */
    bool has_unknown_dim() const noexcept {
        return kind() == shape_kind_has_unknown_dim;
    }

    /** @brief Is this an invalid shape */
    bool is_invalid() const noexcept { return kind() == shape_kind_invalid; }

    /** @brief Get dimensions */
    std::span<const dim_t> dims() const noexcept { return dims_; }

    /** @brief Get rank; nullopt when the shape is unranked or invalid */
    std::optional<size_t> rank() const noexcept {
        return is_ranked() ? std::make_optional(dims_.size()) : std::nullopt;
    }

    // Read-only iteration over dims.
    auto begin() const noexcept { return dims_.cbegin(); }
    auto end() const noexcept { return dims_.cend(); }

    // First/last dim; undefined behavior on an empty dims list.
    const dim_t &front() const { return dims_.front(); }
    const dim_t &back() const { return dims_.back(); }

    /** @brief Get dimension (bounds-checked via at()) */
    const dim_t &dim(size_t index) const { return dims_.at(index); }
    const dim_t &operator[](size_t index) const { return dim(index); }

    /** @brief Set dimension */
    void dim(size_t index, dim_t value);

    /** @brief Place a new dim at back */
    void push_back(dim_t value);
    const dim_t &emplace_back(dim_t value);

    /** @brief Place a new dim */
    const dim_t *emplace(const dim_t *position, dim_t value);

    /** @brief Remove the dim at back */
    void pop_back();

  private:
    // Derive the kind of a ranked dim list: has_unknown_dim if any dim is
    // unknown, otherwise fixed.
    template <ranges::range R> static shape_kind_t kind_of(R &&range) noexcept {
        return ranges::any_of(range,
                              [](const dim_t &dim) { return dim.is_unknown(); })
                   ? shape_kind_has_unknown_dim
                   : shape_kind_fixed;
    }

    // Re-derive kind_ after a mutation (defined in the .cpp).
    void update_kind(shape_kind_t before_kind,
                     dim_kind_t new_dim_kind) noexcept;

  private:
    shape_kind_t kind_;
    itlib::small_vector<dim_t, 4> dims_; // inline storage for rank <= 4
};
} // namespace nncase::ir

View File

@ -1,45 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
namespace nncase::ir::tensors {
/** @brief Cast operator node: converts its input to a new element datatype */
class NNCASE_API cast_node : public op_node {
    DEFINE_OBJECT_KIND(op_node, op_tensors_cast)
  public:
    cast_node(datatype_t new_type);

    /** @brief Get the new type of the cast expression */
    datatype_t new_type() const noexcept { return new_type_; }
    /** @brief Set the new type of the cast expression */
    void new_type(datatype_t value) noexcept { new_type_ = value; }

    /** @brief Get the input parameter of the cast expression */
    const connector_info &input() const noexcept { return parameter_at(0); }

    // Result type: input shape with element type replaced by new_type_
    // (implemented in the .cpp).
    type infer_invoke_result_type(type_infer_context &context) override;

  private:
    datatype_t new_type_;
};
// Handle (object wrapper) for cast_node; constructor defined in the .cpp.
class cast : public object_t<cast_node> {
  public:
    NNCASE_API cast(datatype_t new_type);
};
} // namespace nncase::ir::tensors

View File

@ -1,22 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../call.h"
#include "broadcast.h"
#include "cast.h"
namespace nncase::ir::F {
NNCASE_API call broadcast(expr input, expr shape);
} // namespace nncase::ir::F

View File

@ -1,29 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
namespace nncase::ir::tensors {
/** @brief Gather operator node (operands are supplied via call arguments;
 *  this node declares no extra properties) */
class NNCASE_API gather_node : public op_node {
    DEFINE_OBJECT_KIND(op_node, op_tensors_gather)
  public:
    gather_node();
};
// Handle alias for gather_node.
using gather = object_t<gather_node>;
} // namespace nncase::ir::tensors

View File

@ -1,29 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
namespace nncase::ir::tensors {
/** @brief GatherND operator node (operands are supplied via call arguments;
 *  this node declares no extra properties) */
class NNCASE_API gather_nd_node : public op_node {
    DEFINE_OBJECT_KIND(op_node, op_tensors_gather_nd)
  public:
    gather_nd_node();
};
// Handle alias for gather_nd_node.
using gather_nd = object_t<gather_nd_node>;
} // namespace nncase::ir::tensors

View File

@ -1,10 +0,0 @@
// Opcode table for the tensors dialect.
// DEFINE_OPCODE(dialect, id, Name, value) is defined by each including file
// (see opcode.h); ids map to stable numeric opcodes — never reuse a value.
DEFINE_OPCODE(tensors, broadcast, Broadcast, 0x012001)
DEFINE_OPCODE(tensors, concat, Concat, 0x012002)
DEFINE_OPCODE(tensors, cast, Cast, 0x012003)
DEFINE_OPCODE(tensors, copy, Copy, 0x012004)
DEFINE_OPCODE(tensors, gather, Gather, 0x012005)
DEFINE_OPCODE(tensors, gather_nd, GatherND, 0x012006)
DEFINE_OPCODE(tensors, reshape, Reshape, 0x012007)
DEFINE_OPCODE(tensors, slice, Slice, 0x012008)
DEFINE_OPCODE(tensors, split, Split, 0x012009)
DEFINE_OPCODE(tensors, transpose, Transpose, 0x01200A)

View File

@ -1,25 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../../object_kind.h"
namespace nncase::ir::tensors {
// Expand opcode.def once, declaring an inline-constexpr object_kind named
// op_<dialect>_<id> (e.g. op_tensors_cast) for every tensors-dialect opcode.
#define DEFINE_OPCODE(dialect, id, name, value)                                \
    NNCASE_INLINE_VAR constexpr object_kind op_##dialect##_##id{value, #name};
#include "opcode.def"
#undef DEFINE_OPCODE
} // namespace nncase::ir::tensors

View File

@ -1,29 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../op.h"
#include "nncase/runtime/datatypes.h"
#include "opcode.h"
namespace nncase::ir::tensors {
/** @brief Reshape operator node (new shape is supplied via call arguments;
 *  this node declares no extra properties) */
class NNCASE_API reshape_node : public op_node {
    DEFINE_OBJECT_KIND(op_node, op_tensors_reshape)
  public:
    reshape_node();
};
// Handle alias for reshape_node.
using reshape = object_t<reshape_node>;
} // namespace nncase::ir::tensors

Some files were not shown because too many files have changed in this diff Show More