Compare commits

...

4 Commits

Author SHA1 Message Date
zhangyang2057 5fd159c25d
Merge 587e1c35c4 into 17ad24801b 2024-05-08 12:13:39 +08:00
Curio Yang 17ad24801b
Feature/add xsgetn (#1199)
* Load multiple characters at once

* format

* fix build

* Apply code-format changes

---------

Co-authored-by: curioyang <curioyang@users.noreply.github.com>
Co-authored-by: sunnycase <sunnycase@live.cn>
2024-05-08 12:13:03 +08:00
zhangyang2057 587e1c35c4 Add customized data for x86_64/riscv64 roofline. 2024-05-06 16:44:09 +08:00
zhangyang2057 c6f7af48a9 Add benchmark test for ntt unary/binary. 2024-04-26 21:06:35 +08:00
43 changed files with 636 additions and 74 deletions

View File

@ -67,7 +67,7 @@ jobs:
run: |
cd build
ctest -C ${{matrix.config.buildType}} --test-dir tests/kernels --output-on-failure -j4
ctest -C ${{matrix.config.buildType}} --test-dir src/Native/test --output-on-failure -j4
ctest -C ${{matrix.config.buildType}} --test-dir src/Native/test/ctest --output-on-failure -j4
if: runner.os != 'Windows'
#- name: Benchmark
@ -144,7 +144,7 @@ jobs:
run: |
cd build
ctest --test-dir tests/kernels --output-on-failure -j4
ctest --test-dir src/Native/test --output-on-failure -j4
ctest --test-dir src/Native/test/ctest --output-on-failure -j4
- name: Upload nncaseruntime Build Artifact
uses: actions/upload-artifact@v3

View File

@ -129,11 +129,12 @@ if (BUILDING_RUNTIME)
# add_subdirectory(src/Native/src/runtime)
add_subdirectory(src/Native/src)
if(BUILD_TESTING)
add_subdirectory(src/Native/test)
add_subdirectory(src/Native/test/ctest)
endif()
# add_subdirectory(src/Native/src/functional)
if(BUILD_BENCHMARK)
# add_subdirectory(benchmark)
add_subdirectory(src/Native/test/benchmark_test)
endif()
# Python binding
@ -222,7 +223,7 @@ else()
add_subdirectory(src/Native/include/nncase)
add_subdirectory(src/Native/src)
if(BUILD_TESTING)
add_subdirectory(src/Native/test)
add_subdirectory(src/Native/test/ctest)
endif()
# Python binding
if(BUILD_PYTHON_BINDING)

View File

@ -17,7 +17,7 @@
#include <arm_neon.h>
#ifndef NTT_VLEN
#define NTT_VLEN (sizeof(int8x16_t))
#define NTT_VLEN (sizeof(int8x16_t) * 8)
#endif
NTT_DEFINE_NATIVE_TENSOR(int8_t, int8x16_t, 16)

View File

@ -17,7 +17,7 @@
#include <immintrin.h>
#ifndef NTT_VLEN
#define NTT_VLEN (sizeof(__m256i))
#define NTT_VLEN (sizeof(__m256i) * 8)
#endif
NTT_DEFINE_NATIVE_TENSOR(int8_t, __m256i, 32)

View File

@ -26,21 +26,21 @@ class char_array_buffer : public std::streambuf {
current_(data.data()) {}
private:
int_type underflow() {
int_type underflow() override {
if (current_ == end_)
return traits_type::eof();
return traits_type::to_int_type(*current_);
}
int_type uflow() {
int_type uflow() override {
if (current_ == end_)
return traits_type::eof();
return traits_type::to_int_type(*current_++);
}
int_type pbackfail(int_type ch) {
int_type pbackfail(int_type ch) override {
if (current_ == begin_ ||
(ch != traits_type::eof() && ch != current_[-1]))
return traits_type::eof();
@ -48,13 +48,14 @@ class char_array_buffer : public std::streambuf {
return traits_type::to_int_type(*--current_);
}
std::streamsize showmanyc() {
std::streamsize showmanyc() override {
assert(std::less_equal<const char *>()(current_, end_));
return end_ - current_;
}
std::streampos seekoff(std::streamoff off, std::ios_base::seekdir way,
[[maybe_unused]] std::ios_base::openmode which) {
std::streampos
seekoff(std::streamoff off, std::ios_base::seekdir way,
[[maybe_unused]] std::ios_base::openmode which) override {
if (way == std::ios_base::beg) {
current_ = begin_ + off;
} else if (way == std::ios_base::cur) {
@ -69,8 +70,9 @@ class char_array_buffer : public std::streambuf {
return current_ - begin_;
}
std::streampos seekpos(std::streampos sp,
[[maybe_unused]] std::ios_base::openmode which) {
std::streampos
seekpos(std::streampos sp,
[[maybe_unused]] std::ios_base::openmode which) override {
current_ = begin_ + sp;
if (current_ < begin_ || current_ > end_)
@ -79,6 +81,17 @@ class char_array_buffer : public std::streambuf {
return current_ - begin_;
}
/// Bulk read: copies up to `count` characters starting at the current read
/// position into `s` and advances the read position past them.
///
/// @param s     Destination buffer; must have room for `count` characters.
/// @param count Maximum number of characters to copy.
/// @return Number of characters actually copied. This is 0 when `count` is
///         not positive or when no characters remain before `end_`.
std::streamsize xsgetn(char_type *s, std::streamsize count) override {
    // A non-positive request reads nothing. Report 0 instead of handing a
    // negative count back to the caller as the "number of characters read".
    if (count <= 0)
        return 0;

    const std::streamsize available =
        static_cast<std::streamsize>(end_ - current_);
    const std::streamsize n = (count > available) ? available : count;
    if (n > 0) {
        traits_type::copy(s, current_, static_cast<size_t>(n));
        current_ += n;
    }
    return n;
}
const char *const begin_;
const char *const end_;
const char *current_;

View File

@ -0,0 +1,9 @@
# Builds one standalone executable per benchmark_*.cpp file in this directory.
# The parent directory is added to the include path for the sources below.
include_directories(${CMAKE_CURRENT_LIST_DIR}/..)

# CONFIGURE_DEPENDS re-runs the glob at build time so new benchmark sources
# are picked up without re-running CMake by hand.
file(GLOB TEST_NAMES CONFIGURE_DEPENDS benchmark_*.cpp)

foreach(test_name IN LISTS TEST_NAMES)
    # Target name = source file name without directory and extension.
    get_filename_component(tname ${test_name} NAME_WE)
    add_executable(${tname} ${tname}.cpp)
    target_link_libraries(${tname} PRIVATE nncaseruntime)
endforeach()

View File

@ -0,0 +1,295 @@
import argparse
import os
from pathlib import Path
import subprocess
import socket
import struct
import json
import time
from html import escape
def kpu_targets():
    """Return the comma-separated entries of the ``KPU_TARGETS`` variable.

    When the variable is unset the result is ``['']`` (one empty string),
    not an empty list.
    """
    raw_targets = os.environ.get('KPU_TARGETS', "")
    return raw_targets.split(',')
def nuc_ip():
    """Return the value of ``NUC_PROXY_IP``, or ``None`` when it is unset."""
    return os.environ.get('NUC_PROXY_IP')
def nuc_port():
    """Return the value of ``NUC_PROXY_PORT`` as a string, or ``None`` when unset."""
    return os.environ.get('NUC_PROXY_PORT')
def report_file(default: str):
    """Return the report path from ``BENCHMARK_NTT_REPORT_FILE``.

    Falls back to ``default`` when the variable is not set.
    """
    return os.environ.get('BENCHMARK_NTT_REPORT_FILE', default)
def generate_markdown(benchmark_list: list, md_file: str):
    """Render benchmark results as an HTML table and write it to ``md_file``.

    Args:
        benchmark_list: List of result dicts. Every dict must contain the
            keys ``'kind'`` and ``'op'``; the keys of the first entry (after
            sorting) become the table header.
        md_file: Path of the file to (over)write.

    Rows are sorted by ``(kind, op)``. All rows of one kind share a single
    ``kind`` cell that spans them via ``rowspan``. An empty
    ``benchmark_list`` produces an empty ``<table>`` instead of raising.
    Header and cell text is HTML-escaped.
    """
    # Sort first so that entries of the same kind end up adjacent.
    rows = sorted(benchmark_list, key=lambda d: (d['kind'], d['op']))
    rows_by_kind = {}
    for row in rows:
        rows_by_kind.setdefault(row['kind'], []).append(row)

    md = '<table>\n'

    # Table head: only emitted when there is at least one row to take the
    # column names from.
    if rows:
        md += '\t<tr>\n'
        for key in rows[0]:
            md += f'\t\t<th>{escape(str(key))}</th>\n'
        md += '\t</tr>\n'

    # Table rows: the 'kind' cell is written once per group, on its first
    # row, and spans the whole group.
    for group in rows_by_kind.values():
        span = len(group)
        for i, row in enumerate(group):
            md += '\t<tr>\n'
            for k, v in row.items():
                cell = escape(str(v))
                if k != 'kind':
                    md += f'\t\t<td>{cell}</td>\n'
                elif i == 0:
                    md += f"\t\t<td rowspan='{span}'>{cell}</td>\n"
            md += '\t</tr>\n'

    md += '</table>\n'

    with open(md_file, 'w') as f:
        f.write(md)
class BenchmarkNTT():
    """Base class that collects NTT benchmark results for one architecture.

    Subclasses are expected to set ``self.roofline_dict`` (a
    ``{kind: {op: value}}`` mapping read by ``parse_result``) and to
    override ``run``.
    """

    def __init__(self, arch: str, target: str, bin_path: str):
        """Store the configuration and discover the benchmark binaries.

        Args:
            arch: Architecture name; used as the prefix of the result keys.
            target: Name of the target the benchmarks run on.
            bin_path: Path used to locate the benchmark binaries.
        """
        self.arch = arch
        self.target = target
        self.bin_path = bin_path
        self.bin_prefix = 'benchmark_ntt_'
        self.bin_list = self.traverse_dir()
        self.benchmark_list = []

    def traverse_dir(self):
        """Return the paths whose names start with ``bin_prefix``.

        NOTE(review): the glob runs in ``os.path.dirname(self.bin_path)``,
        i.e. the parent of ``bin_path`` rather than ``bin_path`` itself --
        confirm this is intended.
        """
        search_dir = Path(os.path.dirname(self.bin_path))
        return list(search_dir.glob(f'{self.bin_prefix}*'))

    def parse_result(self, result: str):
        """Parse benchmark output and append one entry per result line.

        A result line starts with ``<bin_prefix><kind>_<op>`` and carries the
        measured value as its second-to-last space-separated field. Lines
        that do not have this shape (blank lines, unrelated output) are
        skipped.

        Raises:
            KeyError: if ``roofline_dict`` has no entry for a parsed kind/op.
        """
        for line in result.split('\n'):
            items = line.split(' ')
            name_parts = items[0].split(self.bin_prefix)
            # Need the prefix, a '<kind>_<op>' name and a value field.
            if len(name_parts) < 2 or '_' not in name_parts[1] or len(items) < 2:
                continue
            kind, op = name_parts[1].split('_', 1)
            entry = {}
            entry['kind'], entry['op'] = kind, op
            entry[f'{self.arch}_roofline'] = self.roofline_dict[kind][op]
            entry[f'{self.arch}_actual'] = items[-2]
            self.benchmark_list.append(entry)

    def run(self):
        """Run the benchmarks; the base implementation does nothing."""
        pass
class BenchmarkNTT_x86_64(BenchmarkNTT):
    """Runs the discovered benchmark binaries as local commands for x86_64."""

    def __init__(self, target: str, bin_path: str):
        super().__init__('x86_64', target, bin_path)
        # Roofline value per kind/op; every entry is currently 'N/A'.
        binary_roofline = {
            'add': 'N/A',
            'sub': 'N/A',
            'mul': 'N/A',
            'div': 'N/A',
            'max': 'N/A',
            'min': 'N/A',
            'floor_mod': 'N/A',
            'mod': 'N/A',
            'pow': 'N/A',
        }
        unary_roofline = {
            'abs': 'N/A',
            'acos': 'N/A',
            'acosh': 'N/A',
            'asin': 'N/A',
            'asinh': 'N/A',
            'ceil': 'N/A',
            'cos': 'N/A',
            'cosh': 'N/A',
            'exp': 'N/A',
            'floor': 'N/A',
            'log': 'N/A',
            'neg': 'N/A',
            'round': 'N/A',
            'rsqrt': 'N/A',
            'sign': 'N/A',
            'sin': 'N/A',
            'sinh': 'N/A',
            'sqrt': 'N/A',
            'square': 'N/A',
            'swish': 'N/A',
            'tanh': 'N/A',
        }
        self.roofline_dict = {'binary': binary_roofline, 'unary': unary_roofline}

    def run(self):
        """Execute each binary and parse its combined output."""
        for executable in self.bin_list:
            status, output = subprocess.getstatusoutput(f'{executable}')
            assert(status == 0)
            self.parse_result(output)
class BenchmarkNTT_riscv64(BenchmarkNTT):
    """Runs the benchmark binaries for riscv64 through a TCP proxy server.

    The proxy address comes from ``nuc_ip()`` / ``nuc_port()``. Messages in
    both directions are framed with a 4-byte big-endian length prefix.
    """

    def __init__(self, target: str, bin_path: str):
        BenchmarkNTT.__init__(self, 'riscv64', target, bin_path)
        # Roofline value per kind/op; every entry is currently 'N/A'.
        self.roofline_dict = {
            'binary': {
                'add': 'N/A',
                'sub': 'N/A',
                'mul': 'N/A',
                'div': 'N/A',
                'max': 'N/A',
                'min': 'N/A',
                'floor_mod': 'N/A',
                'mod': 'N/A',
                'pow': 'N/A',
            },
            'unary': {
                'abs': 'N/A',
                'acos': 'N/A',
                'acosh': 'N/A',
                'asin': 'N/A',
                'asinh': 'N/A',
                'ceil': 'N/A',
                'cos': 'N/A',
                'cosh': 'N/A',
                'exp': 'N/A',
                'floor': 'N/A',
                'log': 'N/A',
                'neg': 'N/A',
                'round': 'N/A',
                'rsqrt': 'N/A',
                'sign': 'N/A',
                'sin': 'N/A',
                'sinh': 'N/A',
                'sqrt': 'N/A',
                'square': 'N/A',
                'swish': 'N/A',
                'tanh': 'N/A',
            },
        }

    def send_msg(self, sock, msg):
        """Send ``msg`` (bytes) prefixed with its length."""
        # Prefix each message with a 4-byte length (network byte order)
        msg = struct.pack('>I', len(msg)) + msg
        sock.sendall(msg)

    def recv_msg(self, sock):
        """Receive one length-prefixed message; return ``None`` on EOF."""
        # Read message length and unpack it into an integer
        raw_msglen = self.recvall(sock, 4)
        if not raw_msglen:
            return None
        msglen = struct.unpack('>I', raw_msglen)[0]
        # Read the message data
        return self.recvall(sock, msglen)

    def recvall(self, sock, n):
        """Receive exactly ``n`` bytes, or return ``None`` if EOF is hit."""
        data = bytearray()
        while len(data) < n:
            packet = sock.recv(n - len(data))
            if not packet:
                return None
            data.extend(packet)
        return data

    def run_evb(self, bin):
        """Upload one benchmark binary to the proxy and return its result.

        Args:
            bin: Path of the benchmark binary to send.

        Returns:
            ``(0, msg[0])`` where ``msg`` is the ``'msg'`` field of the
            server's final reply.

        Raises:
            ConnectionError: the connection closed before a reply arrived.
            Exception: the reply's ``'type'`` does not contain ``'finish'``;
                the exception carries the reply's ``'msg'``.
        """
        ip = nuc_ip()
        port = nuc_port()

        # The context manager closes the socket on every exit path,
        # including exceptions raised mid-conversation.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket:
            client_socket.connect((ip, int(port)))

            # Before each stage one message is read from the server and
            # discarded (presumably a ready/ack prompt -- confirm against
            # the proxy implementation).

            # send target
            self.recv_msg(client_socket)
            target_dict = {'target': self.target}
            self.send_msg(client_socket, json.dumps(target_dict).encode())

            # send header
            self.recv_msg(client_socket)
            header_dict = {
                'case': os.path.basename(bin),
                'app': 1,
                'kmodel': 0,
                'inputs': 0,
                'description': 0,
                'outputs': 0,
                'cfg_cmds': [],
            }
            self.send_msg(client_socket, json.dumps(header_dict).encode())

            # send app: file metadata first, then the raw file content
            self.recv_msg(client_socket)
            file_dict = {
                'file_name': os.path.basename(bin),
                'file_size': os.path.getsize(bin),
            }
            self.send_msg(client_socket, json.dumps(file_dict).encode())
            self.recv_msg(client_socket)
            with open(bin, 'rb') as f:
                client_socket.sendall(f.read())

            # get result
            ret = self.recv_msg(client_socket)
            if ret is None:
                raise ConnectionError(
                    'connection closed before the benchmark result was received')
            reply = json.loads(ret.decode())
            msg = reply['msg']
            if reply['type'].find('finish') != -1:
                self.send_msg(client_socket, "recved msg".encode())
                return 0, msg[0]
            raise Exception(msg)

    def run(self):
        """Run every binary on the board if ``target`` is in ``kpu_targets()``."""
        if self.target not in kpu_targets():
            return
        for binary in self.bin_list:
            cmd_status, cmd_result = self.run_evb(binary)
            assert(cmd_status == 0)
            # The first and last '\r\n'-separated lines of the reply are
            # dropped before parsing.
            lines = cmd_result.split('\r\n')
            new_cmd_result = '\n'.join(lines[1:-1])
            self.parse_result(new_cmd_result)
if __name__ == '__main__':
    # parse
    parser = argparse.ArgumentParser(prog="benchmark_ntt")
    parser.add_argument("--x86_64_target", help='x86_64 target to run on',
                        type=str, default='local')
    parser.add_argument("--x86_64_path", help='bin path for x86_64',
                        type=str, default='x86_64_build/bin')
    parser.add_argument("--riscv64_target", help='riscv64 target to run on',
                        type=str, default='k230')
    parser.add_argument("--riscv64_path", help='bin path for riscv64',
                        type=str, default='riscv64_build/bin')
    args = parser.parse_args()

    # x86_64
    ntt_x86_64 = BenchmarkNTT_x86_64(args.x86_64_target, args.x86_64_path)
    ntt_x86_64.run()

    # riscv64
    ntt_riscv64 = BenchmarkNTT_riscv64(args.riscv64_target, args.riscv64_path)
    ntt_riscv64.run()

    # merge benchmark list
    # Pair results by (kind, op) instead of by list position: the riscv64
    # list is empty when its run is skipped, and the two lists are not
    # guaranteed to be in the same order. An x86_64 entry with no riscv64
    # counterpart keeps only its own columns.
    riscv64_by_op = {(e['kind'], e['op']): e for e in ntt_riscv64.benchmark_list}
    benchmark_list = []
    for x86_64_item in ntt_x86_64.benchmark_list:
        op_key = (x86_64_item['kind'], x86_64_item['op'])
        benchmark_list.append({**x86_64_item, **riscv64_by_op.get(op_key, {})})

    # generate md
    md_file = report_file('benchmark_ntt.md')
    generate_markdown(benchmark_list, md_file)

View File

@ -0,0 +1,69 @@
/* Copyright 2019-2024 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ntt_test.h"
#include <iomanip>
#include <nncase/ntt/ntt.h>
using namespace nncase;
/**
 * BENCHMARK_NTT_BINARY(op) defines the function template
 * benchmark_ntt_binary_<op><T, N>(lhs_low, lhs_high, rhs_low, rhs_high).
 *
 * The generated function fills two fixed-shape tensors of `size`
 * ntt::vector<T, N> elements via NttTest::init_tensor with the given
 * low/high arguments, calls ntt::<op>(lhs, rhs) `size` times between two
 * NttTest::get_cpu_cycle() readings, and prints
 *   "<function name> took <cycles> cycles"
 * where <cycles> is the elapsed count divided by `size` twice (once for the
 * iteration count, once for the tensor length), with one decimal place.
 *
 * NOTE(review): the result of ntt::<op> is discarded; confirm the call
 * cannot be optimized away in release builds.
 */
#define BENCHMARK_NTT_BINARY(op)                                               \
    template <typename T, size_t N>                                            \
    void benchmark_ntt_binary_##op(T lhs_low, T lhs_high, T rhs_low,           \
                                   T rhs_high) {                               \
        constexpr size_t size = 2000;                                          \
        ntt::tensor<ntt::vector<T, N>, ntt::fixed_shape<size>> ntt_lhs,        \
            ntt_rhs;                                                           \
        NttTest::init_tensor(ntt_lhs, lhs_low, lhs_high);                      \
        NttTest::init_tensor(ntt_rhs, rhs_low, rhs_high);                      \
                                                                               \
        auto t1 = NttTest::get_cpu_cycle();                                    \
        for (size_t i = 0; i < size; i++)                                      \
            ntt::op(ntt_lhs, ntt_rhs);                                         \
        auto t2 = NttTest::get_cpu_cycle();                                    \
        std::cout << __FUNCTION__ << " took " << std::setprecision(1)          \
                  << std::fixed << static_cast<float>(t2 - t1) / size / size   \
                  << " cycles" << std::endl;                                   \
    }

/** Instantiates one benchmark function template per supported binary op. */
#define REGISTER_NTT_BINARY                                                    \
    BENCHMARK_NTT_BINARY(add)                                                  \
    BENCHMARK_NTT_BINARY(sub)                                                  \
    BENCHMARK_NTT_BINARY(mul)                                                  \
    BENCHMARK_NTT_BINARY(div)                                                  \
    BENCHMARK_NTT_BINARY(max)                                                  \
    BENCHMARK_NTT_BINARY(min)                                                  \
    BENCHMARK_NTT_BINARY(floor_mod)                                            \
    BENCHMARK_NTT_BINARY(mod)                                                  \
    BENCHMARK_NTT_BINARY(pow)

REGISTER_NTT_BINARY
/**
 * RUN_NTT_BINARY(N) expands to one call of every benchmark_ntt_binary_<op>
 * with N as the vector lane count, in the order listed below.
 *
 * All ops use float operands drawn with low/high arguments of -10 and 10,
 * except:
 *  - div and mod pass 1..10 for the right-hand side (presumably to keep the
 *    divisor away from zero);
 *  - floor_mod runs on int32_t, also with 1..10 for the right-hand side;
 *  - pow passes 0..3 for both operands.
 *
 * Each expanded call already ends with ';', so the macro is used without a
 * trailing semicolon.
 */
#define RUN_NTT_BINARY(N) \
benchmark_ntt_binary_add<float, N>(-10.f, 10.f, -10.f, 10.f); \
benchmark_ntt_binary_sub<float, N>(-10.f, 10.f, -10.f, 10.f); \
benchmark_ntt_binary_mul<float, N>(-10.f, 10.f, -10.f, 10.f); \
benchmark_ntt_binary_div<float, N>(-10.f, 10.f, 1.f, 10.f); \
benchmark_ntt_binary_max<float, N>(-10.f, 10.f, -10.f, 10.f); \
benchmark_ntt_binary_min<float, N>(-10.f, 10.f, -10.f, 10.f); \
benchmark_ntt_binary_floor_mod<int32_t, N>(-10, 10, 1, 10); \
benchmark_ntt_binary_mod<float, N>(-10.f, 10.f, 1.f, 10.f); \
benchmark_ntt_binary_pow<float, N>(0.f, 3.f, 0.f, 3.f);
/// Runs every binary-op benchmark with NTT_VLEN / (sizeof(float) * 8) lanes
/// per vector. Command-line arguments are ignored.
int main([[maybe_unused]] int argc, [[maybe_unused]] char *argv[]) {
    RUN_NTT_BINARY(NTT_VLEN / (sizeof(float) * 8))
}

View File

@ -0,0 +1,90 @@
/* Copyright 2019-2024 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ntt_test.h"
#include <iomanip>
#include <nncase/ntt/ntt.h>
using namespace nncase;
/**
 * BENCHMARK_NTT_UNARY(op) defines the function template
 * benchmark_ntt_unary_<op><T, N>(low, high).
 *
 * The generated function fills a fixed-shape tensor of `size`
 * ntt::vector<T, N> elements via NttTest::init_tensor with the given
 * low/high arguments, calls ntt::<op>(input) `size` times between two
 * NttTest::get_cpu_cycle() readings, and prints
 *   "<function name> took <cycles> cycles"
 * where <cycles> is the elapsed count divided by `size` twice (once for the
 * iteration count, once for the tensor length), with one decimal place.
 *
 * NOTE(review): the result of ntt::<op> is discarded; confirm the call
 * cannot be optimized away in release builds.
 */
#define BENCHMARK_NTT_UNARY(op)                                                \
    template <typename T, size_t N>                                            \
    void benchmark_ntt_unary_##op(T low, T high) {                             \
        constexpr size_t size = 2000;                                          \
        ntt::tensor<ntt::vector<T, N>, ntt::fixed_shape<size>> ntt_input;      \
        NttTest::init_tensor(ntt_input, low, high);                            \
                                                                               \
        auto t1 = NttTest::get_cpu_cycle();                                    \
        for (size_t i = 0; i < size; i++)                                      \
            ntt::op(ntt_input);                                                \
        auto t2 = NttTest::get_cpu_cycle();                                    \
        std::cout << __FUNCTION__ << " took " << std::setprecision(1)          \
                  << std::fixed << static_cast<float>(t2 - t1) / size / size   \
                  << " cycles" << std::endl;                                   \
    }

/** Instantiates one benchmark function template per supported unary op. */
#define REGISTER_NTT_UNARY                                                     \
    BENCHMARK_NTT_UNARY(abs)                                                   \
    BENCHMARK_NTT_UNARY(acos)                                                  \
    BENCHMARK_NTT_UNARY(acosh)                                                 \
    BENCHMARK_NTT_UNARY(asin)                                                  \
    BENCHMARK_NTT_UNARY(asinh)                                                 \
    BENCHMARK_NTT_UNARY(ceil)                                                  \
    BENCHMARK_NTT_UNARY(cos)                                                   \
    BENCHMARK_NTT_UNARY(cosh)                                                  \
    BENCHMARK_NTT_UNARY(exp)                                                   \
    BENCHMARK_NTT_UNARY(floor)                                                 \
    BENCHMARK_NTT_UNARY(log)                                                   \
    BENCHMARK_NTT_UNARY(neg)                                                   \
    BENCHMARK_NTT_UNARY(round)                                                 \
    BENCHMARK_NTT_UNARY(rsqrt)                                                 \
    BENCHMARK_NTT_UNARY(sign)                                                  \
    BENCHMARK_NTT_UNARY(sin)                                                   \
    BENCHMARK_NTT_UNARY(sinh)                                                  \
    BENCHMARK_NTT_UNARY(sqrt)                                                  \
    BENCHMARK_NTT_UNARY(square)                                                \
    BENCHMARK_NTT_UNARY(swish)                                                 \
    BENCHMARK_NTT_UNARY(tanh)

REGISTER_NTT_UNARY
/**
 * RUN_NTT_UNARY(N) expands to one call of every benchmark_ntt_unary_<op>
 * on float with N as the vector lane count, in the order listed below.
 *
 * The low/high arguments are -10 and 10 unless the op is only defined on
 * part of that range: acos/asin use -1..1, and acosh/log/rsqrt/sqrt use
 * 1..10. log previously used -10..10, which puts negative inputs outside
 * its domain (assuming init_tensor draws values between low and high --
 * confirm), unlike the other domain-restricted ops here.
 *
 * Each expanded call already ends with ';', so the macro is used without a
 * trailing semicolon.
 */
#define RUN_NTT_UNARY(N)                                                       \
    benchmark_ntt_unary_abs<float, N>(-10.f, 10.f);                            \
    benchmark_ntt_unary_acos<float, N>(-1.f, 1.f);                             \
    benchmark_ntt_unary_acosh<float, N>(1.f, 10.f);                            \
    benchmark_ntt_unary_asin<float, N>(-1.f, 1.f);                             \
    benchmark_ntt_unary_asinh<float, N>(-10.f, 10.f);                          \
    benchmark_ntt_unary_ceil<float, N>(-10.f, 10.f);                           \
    benchmark_ntt_unary_cos<float, N>(-10.f, 10.f);                            \
    benchmark_ntt_unary_cosh<float, N>(-10.f, 10.f);                           \
    benchmark_ntt_unary_exp<float, N>(-10.f, 10.f);                            \
    benchmark_ntt_unary_floor<float, N>(-10.f, 10.f);                          \
    benchmark_ntt_unary_log<float, N>(1.f, 10.f);                              \
    benchmark_ntt_unary_neg<float, N>(-10.f, 10.f);                            \
    benchmark_ntt_unary_round<float, N>(-10.f, 10.f);                          \
    benchmark_ntt_unary_rsqrt<float, N>(1.f, 10.f);                            \
    benchmark_ntt_unary_sign<float, N>(-10.f, 10.f);                           \
    benchmark_ntt_unary_sin<float, N>(-10.f, 10.f);                            \
    benchmark_ntt_unary_sinh<float, N>(-10.f, 10.f);                           \
    benchmark_ntt_unary_sqrt<float, N>(1.f, 10.f);                             \
    benchmark_ntt_unary_square<float, N>(-10.f, 10.f);                         \
    benchmark_ntt_unary_swish<float, N>(-10.f, 10.f);                          \
    benchmark_ntt_unary_tanh<float, N>(-10.f, 10.f);
/// Runs every unary-op benchmark with NTT_VLEN / (sizeof(float) * 8) lanes
/// per vector. Command-line arguments are ignored.
int main([[maybe_unused]] int argc, [[maybe_unused]] char *argv[]) {
    RUN_NTT_UNARY(NTT_VLEN / (sizeof(float) * 8))
}

View File

@ -0,0 +1,19 @@
enable_testing()

# Every test target below links both ortki::ortki and GTest::gtest_main, so
# require both packages here and fail at configure time with a clear message
# instead of a missing-target error later.
find_package(ortki REQUIRED)
find_package(GTest REQUIRED)

message(STATUS "CMAKE_CURRENT_LIST_DIR=${CMAKE_CURRENT_LIST_DIR}")
include_directories(${CMAKE_CURRENT_LIST_DIR}/..)

# add_test_exec(<name>): builds <name>.cpp into an executable and registers
# it with CTest. The test is launched through toolchains/run_test.cmake,
# which receives the executable path in TEST_EXECUTABLE.
macro(add_test_exec name)
    add_executable(${name} ${name}.cpp)
    target_link_libraries(${name} PRIVATE GTest::gtest_main nncaseruntime ortki::ortki)
    add_test(NAME ${name} COMMAND ${CMAKE_COMMAND} -DTEST_EXECUTABLE=$<TARGET_FILE:${name}> -P ${CMAKE_CURRENT_SOURCE_DIR}/../../../../toolchains/run_test.cmake)
endmacro()

# One test executable per test_*.cpp source in this directory.
file(GLOB TEST_NAMES CONFIGURE_DEPENDS test_*.cpp)

foreach(test_name ${TEST_NAMES})
    get_filename_component(tname ${test_name} NAME_WE)
    add_test_exec(${tname})
endforeach()

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>

View File

@ -18,7 +18,6 @@
#include "nncase/ntt/shape.h"
#include <assert.h>
#include <iostream>
#include <ortki/c_api.h>
#include <random>
#include <string>
@ -120,67 +119,24 @@ void init_tensor(ntt::tensor<T, Shape, Stride> &tensor,
}
}
template <typename T, typename Shape,
typename Stride = ntt::default_strides_t<Shape>>
void init_tensor(ntt::tensor<ntt::vector<T, 8>, Shape, Stride> &tensor,
T start = static_cast<T>(0), T stop = static_cast<T>(1)) {
ntt::apply(tensor.shape(),
[&](auto &index) { init_tensor(tensor(index), start, stop); });
}
template <typename T> ortki::DataType primitive_type2ort_type() {
ortki::DataType ort_type = ortki::DataType_FLOAT;
if (std::is_same_v<T, int8_t>)
ort_type = ortki::DataType_INT8;
else if (std::is_same_v<T, int16_t>)
ort_type = ortki::DataType_INT16;
else if (std::is_same_v<T, int32_t>)
ort_type = ortki::DataType_INT32;
else if (std::is_same_v<T, int64_t>)
ort_type = ortki::DataType_INT64;
else if (std::is_same_v<T, uint8_t>)
ort_type = ortki::DataType_UINT8;
else if (std::is_same_v<T, uint16_t>)
ort_type = ortki::DataType_UINT16;
else if (std::is_same_v<T, uint32_t>)
ort_type = ortki::DataType_UINT32;
else if (std::is_same_v<T, uint64_t>)
ort_type = ortki::DataType_UINT64;
else if (std::is_same_v<T, float>)
ort_type = ortki::DataType_FLOAT;
else if (std::is_same_v<T, double>)
ort_type = ortki::DataType_DOUBLE;
else {
std::cerr << "unsupported data type" << std::endl;
std::abort();
#define _INIT_TENSOR(N) \
template <typename T, typename Shape, \
typename Stride = ntt::default_strides_t<Shape>> \
void init_tensor(ntt::tensor<ntt::vector<T, N>, Shape, Stride> &tensor, \
T start = static_cast<T>(0), \
T stop = static_cast<T>(1)) { \
ntt::apply(tensor.shape(), [&](auto &index) { \
init_tensor(tensor(index), start, stop); \
}); \
}
return ort_type;
}
#define INIT_TENSOR(T) \
_INIT_TENSOR((NTT_VLEN) / (sizeof(T) * 8) * 1) \
_INIT_TENSOR((NTT_VLEN) / (sizeof(T) * 8) * 2) \
_INIT_TENSOR((NTT_VLEN) / (sizeof(T) * 8) * 4) \
_INIT_TENSOR((NTT_VLEN) / (sizeof(T) * 8) * 8)
template <typename T, typename Shape,
typename Stride = ntt::default_strides_t<Shape>>
ortki::OrtKITensor *ntt2ort(ntt::tensor<T, Shape, Stride> &tensor) {
void *buffer = reinterpret_cast<void *>(tensor.elements().data());
auto ort_type = primitive_type2ort_type<T>();
auto rank = tensor.shape().rank();
std::vector<size_t> v(rank);
for (size_t i = 0; i < rank; i++)
v[i] = tensor.shape()[i];
const int64_t *shape = reinterpret_cast<const int64_t *>(v.data());
return make_tensor(buffer, ort_type, shape, rank);
}
template <typename T, typename Shape,
typename Stride = ntt::default_strides_t<Shape>>
void ort2ntt(ortki::OrtKITensor *ort_tensor,
ntt::tensor<T, Shape, Stride> &ntt_tensor) {
size_t size = 0;
void *ort_ptr = tensor_buffer(ort_tensor, &size);
assert(tensor_length(ort_tensor) == ntt_tensor.shape().length());
memcpy((void *)ntt_tensor.elements().data(), ort_ptr, size);
}
INIT_TENSOR(float)
template <typename T, typename Shape,
typename Stride = ntt::default_strides_t<Shape>>

View File

@ -0,0 +1,81 @@
/* Copyright 2019-2024 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "nncase/ntt/apply.h"
#include "nncase/ntt/ntt.h"
#include "nncase/ntt/shape.h"
#include <assert.h>
#include <iostream>
#include <ortki/c_api.h>
#include <string>
namespace nncase {
namespace NttTest {
/// Maps the primitive element type T to the matching ortki::DataType.
///
/// Supported types are the signed/unsigned 8/16/32/64-bit integers, float
/// and double. For any other T the function prints "unsupported data type"
/// to std::cerr and aborts the process.
template <typename T> ortki::DataType primitive_type2ort_type() {
    ortki::DataType mapped = ortki::DataType_FLOAT;

    if constexpr (std::is_same_v<T, int8_t>) {
        mapped = ortki::DataType_INT8;
    } else if constexpr (std::is_same_v<T, int16_t>) {
        mapped = ortki::DataType_INT16;
    } else if constexpr (std::is_same_v<T, int32_t>) {
        mapped = ortki::DataType_INT32;
    } else if constexpr (std::is_same_v<T, int64_t>) {
        mapped = ortki::DataType_INT64;
    } else if constexpr (std::is_same_v<T, uint8_t>) {
        mapped = ortki::DataType_UINT8;
    } else if constexpr (std::is_same_v<T, uint16_t>) {
        mapped = ortki::DataType_UINT16;
    } else if constexpr (std::is_same_v<T, uint32_t>) {
        mapped = ortki::DataType_UINT32;
    } else if constexpr (std::is_same_v<T, uint64_t>) {
        mapped = ortki::DataType_UINT64;
    } else if constexpr (std::is_same_v<T, float>) {
        mapped = ortki::DataType_FLOAT;
    } else if constexpr (std::is_same_v<T, double>) {
        mapped = ortki::DataType_DOUBLE;
    } else {
        std::cerr << "unsupported data type" << std::endl;
        std::abort();
    }

    return mapped;
}
/// Creates an ortki tensor from an ntt tensor.
///
/// The element buffer of `tensor` and its per-dimension extents are passed
/// to make_tensor together with the ortki data type derived from T.
///
/// @param tensor Source ntt tensor.
/// @return The tensor returned by make_tensor.
///
/// NOTE(review): whether make_tensor copies or references `buffer` is not
/// visible here -- confirm before relying on the lifetime of `tensor`.
template <typename T, typename Shape,
          typename Stride = ntt::default_strides_t<Shape>>
ortki::OrtKITensor *ntt2ort(ntt::tensor<T, Shape, Stride> &tensor) {
    void *buffer = reinterpret_cast<void *>(tensor.elements().data());
    auto ort_type = primitive_type2ort_type<T>();
    auto rank = tensor.shape().rank();

    // Store the extents as int64_t, the element type the shape pointer is
    // passed as, rather than reinterpreting a size_t array as int64_t.
    std::vector<int64_t> dims(rank);
    for (size_t i = 0; i < rank; i++)
        dims[i] = static_cast<int64_t>(tensor.shape()[i]);

    const int64_t *shape = dims.data();
    return make_tensor(buffer, ort_type, shape, rank);
}
/// Copies the contents of an ortki tensor into an ntt tensor.
///
/// The number of bytes copied is the size reported by tensor_buffer. In
/// builds with assertions enabled, the element counts of both tensors are
/// checked to be equal before copying.
///
/// @param ort_tensor Source ortki tensor.
/// @param ntt_tensor Destination ntt tensor; its element buffer is
///                   overwritten.
template <typename T, typename Shape,
          typename Stride = ntt::default_strides_t<Shape>>
void ort2ntt(ortki::OrtKITensor *ort_tensor,
             ntt::tensor<T, Shape, Stride> &ntt_tensor) {
    size_t byte_count = 0;
    void *source = tensor_buffer(ort_tensor, &byte_count);
    assert(tensor_length(ort_tensor) == ntt_tensor.shape().length());

    auto *destination = ntt_tensor.elements().data();
    memcpy((void *)destination, source, byte_count);
}
} // namespace NttTest
} // namespace nncase

View File

@ -22,7 +22,6 @@ set(ENABLE_OPENMP OFF)
set(ENABLE_HALIDE OFF)
set(DEFAULT_BUILTIN_RUNTIMES OFF)
set(DEFAULT_SHARED_RUNTIME_TENSOR_PLATFORM_IMPL ON)
set(BUILD_BENCHMARK OFF)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=rv64imafdcv -mabi=lp64d -mcmodel=medany")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64imafdcv -mabi=lp64d -mcmodel=medany")