Split runtime function from runtime module (#337)

* Update

* Update

* Update

* Pass unary

* Merge from master

* Fix build errors

* Fix build errors

* Fix build errors

* Fix schedule

* Fix build errors

* Update benchmark models
pull/340/head
sunnycase 2021-08-18 18:56:39 +08:00 committed by GitHub
parent 7ad9920682
commit 3bfbd61ccf
123 changed files with 3776 additions and 2519 deletions

@@ -24,6 +24,7 @@ import nncase
import requests
import onnxsim
import onnx
from io import BytesIO
TEMP_DIR = "tmp"
MODEL_DIR = "models"
@@ -45,7 +46,7 @@ def _download(url, name, in_shapes):
if not os.path.exists(filename):
req = requests.get(url)
onnx_model, check = onnxsim.simplify(
onnx.load(req.content), check_n=3, input_shapes=in_shapes)
onnx.load_model(BytesIO(req.content)), check_n=3, input_shapes=in_shapes)
assert check, "Simplified ONNX model could not be validated"
onnx.save(onnx_model, filename)
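Note: requests' req.content is a raw bytes object, while onnx.load expects a file path or a readable file-like object, so the old call failed on the downloaded model; wrapping the payload in BytesIO (hence the new import) hands the loader a proper stream.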

Binary file not shown.

@@ -25,7 +25,7 @@ struct build_model_result
class NNCASE_API model_builder
{
public:
model_builder(target &target, const schedule::schedule_result &sched);
model_builder(target &target, const schedule::model_schedule_result &sched);
model_builder(model_builder &) = delete;
model_builder(model_builder &&) = delete;
@@ -36,7 +36,7 @@ public:
private:
target &target_;
const schedule::schedule_result &sched_;
const schedule::model_schedule_result &sched_;
std::filesystem::path dump_dir_;
bool dump_asm_;
};

@@ -35,10 +35,16 @@ public:
struct module_builder_params
{
const schedule::schedule_result &sched;
const schedule::model_schedule_result &model_sched;
const schedule::module_schedule_result &module_sched;
};
struct function_call_id
{
size_t module_id;
size_t function_id;
};
class NNCASE_API module_builder
{
private:
@@ -80,22 +86,27 @@ public:
section_writer &writer(std::string_view section_name);
virtual module_type_t module_type() const noexcept = 0;
virtual uint32_t module_version() const noexcept = 0;
virtual std::unique_ptr<section_decompiler> create_decompiler(std::string_view section_name);
protected:
void merge_to_rdata_section(std::string_view from);
size_t module_id(ir::graph *graph);
function_call_id function_id(ir::graph *graph);
void set_current_entry_point(std::streampos pos);
void set_current_function_text_end(std::streampos pos);
virtual void begin_emit() { }
virtual void begin_emit_module();
virtual void begin_emit_function(const schedule::function_schedule_result &function);
virtual void end_emit_function(const schedule::function_schedule_result &function);
virtual void emit(ir::node &node);
virtual void end_emit() { }
virtual void end_emit_module();
protected:
std::filesystem::path dump_dir_;
bool dump_asm_;
private:
std::vector<nncase::ir::node *> generate_runtime_ops();
std::vector<nncase::ir::node *> generate_current_runtime_ops();
void compile();
void decompile(std::string_view stage, std::string_view section_name, std::span<const uint8_t> input, std::span<const symbol> symbols);
@@ -105,6 +116,7 @@ private:
void write_symbol_refs();
void link();
void write_binary(binary_writer &writer);
void write_function_binary(binary_writer &writer, const schedule::function_schedule_result &function_sched);
private:
uint32_t alignment_;
@@ -113,5 +125,9 @@ private:
std::map<std::string, section, std::less<>> section_writer_;
std::map<std::string, rdata_merge_info, std::less<>> rdata_section_merges_;
std::unordered_map<std::string_view, std::pair<size_t, std::string_view>> symbol_offsets_;
const schedule::function_schedule_result *current_function_;
std::unordered_map<const schedule::function_schedule_result *, std::streampos> entry_points_;
std::unordered_map<const schedule::function_schedule_result *, std::streampos> function_text_end_;
};
}
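The new begin_emit_function/end_emit_function hooks above are the surface a backend overrides to lay functions out back-to-back in its .text section. A minimal sketch of a derived builder (hypothetical class and namespace qualification; it mirrors the k210 builder changes further down in this commit):

class my_module_builder : public nncase::codegen::module_builder
{
public:
    using module_builder::module_builder;

protected:
    void begin_emit_function([[maybe_unused]] const nncase::schedule::function_schedule_result &function) override
    {
        // Record where this function's instructions begin in .text.
        set_current_entry_point(writer(".text").position());
    }

    void end_emit_function([[maybe_unused]] const nncase::schedule::function_schedule_result &function) override
    {
        // Record one-past-the-end of this function's code.
        set_current_function_text_end(writer(".text").position());
    }

    // module_type(), module_version() and emit(node) are still pure virtual
    // and omitted here.
};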

@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:48 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -944,6 +944,7 @@ struct op_writer<nncase::runtime::stackvm::tensor_call_op_t>
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(op.function_id);
writer.write(op.module_id);
writer.write(op.num_src);
writer.write(op.num_dst);
@@ -1322,7 +1323,7 @@ public:
void tensor_batch_to_space_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_block, uint8_t rpad_crops);
void tensor_broadcast_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_dest, uint8_t rstride_dest);
void tensor_binary_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rstride_dest, binary_op_t binary_op, float fused_clamp_low, float fused_clamp_high);
void tensor_call_(uint32_t module_id, uint8_t num_src, uint8_t num_dst);
void tensor_call_(uint32_t function_id, uint16_t module_id, uint8_t num_src, uint8_t num_dst);
void tensor_conv2d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_kernel, uint8_t rstride_kernel, uint8_t rstride_bias, uint8_t rstride_dest, uint16_t groups, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
void tensor_copy_(datatype_t datatype, uint8_t rshape, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_convert_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
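For reference, the rewritten TENSOR.CALL encoding, assuming the operands are packed back-to-back in exactly the write order shown in op_writer above:

// offset 0  : uint8_t  opcode       (opcode_t::TENSOR)
// offset 1  : uint16_t funct        (tensor_function_t::CALL)
// offset 3  : uint32_t function_id  (new field)
// offset 7  : uint16_t module_id    (narrowed from uint32_t)
// offset 9  : uint8_t  num_src
// offset 10 : uint8_t  num_dst      => 11 bytes total, up from 9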

@@ -116,6 +116,7 @@ protected:
virtual void process(const std::vector<uint8_t> &src, float *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
virtual void process(const std::vector<uint8_t> &src, uint8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
virtual void process(const std::vector<uint8_t> &src, int8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
virtual bool do_normalize() const noexcept { return true; }
private:
template <class T>
@@ -170,5 +171,6 @@ protected:
void process(const std::vector<uint8_t> &src, float *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
void process(const std::vector<uint8_t> &src, uint8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
void process(const std::vector<uint8_t> &src, int8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
bool do_normalize() const noexcept override { return false; }
};
}

@@ -0,0 +1,138 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "evaluate_types.h"
#include "quantizer.h"
#include <nncase/schedule/schedule_types.h>
namespace nncase
{
class target;
}
namespace nncase::ir
{
class module_evaluate_context;
class model_evaluate_context;
class NNCASE_API function_evaluate_context
{
public:
function_evaluate_context(const schedule::function_schedule_result &sched, module_evaluate_context &mod_eval);
function_evaluate_context(const function_evaluate_context &) = delete;
function_evaluate_context(function_evaluate_context &&) = default;
evaluate_tensor memory_at(const output_connector &conn);
evaluate_tensor memory_at(const input_connector &conn)
{
return memory_at(*conn.connection());
}
evaluate_tensor input_at(size_t index)
{
return memory_at(*inputs_[index]);
}
evaluate_tensor output_at(size_t index)
{
return memory_at(*outputs_[index]);
}
module_evaluate_context &module() const noexcept { return mod_eval_; }
void evaluate();
private:
const schedule::function_schedule_result &sched_;
module_evaluate_context &mod_eval_;
std::unique_ptr<std::byte[]> input_pool_;
std::unique_ptr<std::byte[]> output_pool_;
std::vector<output_connector *> inputs_;
std::vector<input_connector *> outputs_;
};
class NNCASE_API module_evaluate_context
{
public:
module_evaluate_context(const schedule::module_schedule_result &sched, model_evaluate_context &model_eval);
module_evaluate_context(module_evaluate_context &) = delete;
module_evaluate_context(module_evaluate_context &&) = default;
const schedule::module_schedule_result &sched() const noexcept { return sched_; }
std::byte *memory_pool(memory_location_t location) const;
ir::quantizer *quantizer() noexcept { return quantizer_.get(); }
function_evaluate_context &function(ir::graph &function);
model_evaluate_context &model() const noexcept { return model_eval_; }
void enable_ptq(target &target, ir::calibrate_method calib_method);
void begin_collect_distribution();
void end_sample();
void end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress);
private:
const schedule::module_schedule_result &sched_;
model_evaluate_context &model_eval_;
std::unordered_map<memory_location_t, std::unique_ptr<std::byte[]>> memory_pools_;
std::vector<output_connector *> inputs_;
std::vector<input_connector *> outputs_;
std::unique_ptr<ir::quantizer> quantizer_;
std::unordered_map<ir::graph *, function_evaluate_context> functions_;
};
class NNCASE_API model_evaluate_context
{
public:
model_evaluate_context(const schedule::model_schedule_result &sched);
model_evaluate_context(const model_evaluate_context &) = delete;
model_evaluate_context(model_evaluate_context &&) = default;
function_evaluate_context &entrypoint();
module_evaluate_context &module(const module_type_t &module_type);
evaluate_tensor memory_at(const output_connector &conn)
{
return entrypoint().memory_at(conn);
}
evaluate_tensor memory_at(const input_connector &conn)
{
return memory_at(*conn.connection());
}
evaluate_tensor input_at(size_t index)
{
return entrypoint().input_at(index);
}
evaluate_tensor output_at(size_t index)
{
return entrypoint().output_at(index);
}
void enable_ptq(nncase::target &target, ir::calibrate_method calib_method);
void begin_collect_distribution();
void end_sample();
void end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress);
void evaluate();
private:
const schedule::model_schedule_result &sched_;
std::unordered_map<module_type_t, module_evaluate_context> module_ctxs_;
};
}
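Taken together, evaluation now follows the schedule tree: one model context per model, one module context per module type, one function context per graph. A hypothetical driver, assuming a finished schedule::model_schedule_result named sched:

nncase::ir::model_evaluate_context model_eval(sched);
auto &entry = model_eval.entrypoint(); // context of the entry function
auto input = entry.input_at(0);
// ... fill input.buffer() with test data ...
model_eval.evaluate();
auto output = entry.output_at(0);
// ... read results from output.buffer() ...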

@@ -0,0 +1,36 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nncase/runtime/datatypes.h>
namespace nncase::ir
{
class NNCASE_API evaluate_tensor
{
public:
evaluate_tensor(datatype_t datatype, runtime_shape_t shape, runtime_shape_t strides, gsl::span<gsl::byte> buffer);
datatype_t datatype() const noexcept { return datatype_; }
const runtime_shape_t &shape() const noexcept { return shape_; }
const runtime_shape_t &strides() const noexcept { return strides_; }
gsl::span<gsl::byte> buffer() const noexcept { return buffer_; }
private:
datatype_t datatype_;
runtime_shape_t shape_;
runtime_shape_t strides_;
gsl::span<gsl::byte> buffer_;
};
}

@@ -13,114 +13,35 @@
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <nncase/ir/graph.h>
#include <nncase/ir/op_utils.h>
#include <nncase/ir/quantizer.h>
#include <nncase/kernels/kernel_context.h>
#include <nncase/runtime/compiler_defs.h>
#include <nncase/schedule/scheduler.h>
#include <unordered_map>
#include "evaluate_context.h"
#include "evaluate_types.h"
namespace nncase::ir
{
class quantizer;
class NNCASE_API evaluate_tensor
{
public:
evaluate_tensor(datatype_t datatype, runtime_shape_t shape, runtime_shape_t strides, gsl::span<gsl::byte> buffer);
datatype_t datatype() const noexcept { return datatype_; }
const runtime_shape_t &shape() const noexcept { return shape_; }
const runtime_shape_t &strides() const noexcept { return strides_; }
gsl::span<gsl::byte> buffer() const noexcept { return buffer_; }
private:
datatype_t datatype_;
runtime_shape_t shape_;
runtime_shape_t strides_;
gsl::span<gsl::byte> buffer_;
};
class NNCASE_API module_evaluate_context
{
public:
module_evaluate_context(const schedule::module_schedule_result &sched);
module_evaluate_context(module_evaluate_context &) = delete;
module_evaluate_context(module_evaluate_context &&) = default;
evaluate_tensor memory_at(const output_connector &conn);
evaluate_tensor memory_at(const input_connector &conn)
{
return memory_at(*conn.connection());
}
evaluate_tensor input_at(size_t index)
{
return memory_at(*inputs_[index]);
}
evaluate_tensor output_at(size_t index)
{
return memory_at(*outputs_[index]);
}
ir::quantizer *quantizer() noexcept { return quantizer_.get(); }
void enable_ptq(target &target, ir::calibrate_method calib_method);
void evaluate();
void begin_collect_distribution();
void end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress);
private:
const schedule::module_schedule_result &sched_;
std::unordered_map<memory_location_t, std::unique_ptr<std::byte[]>> memory_pools_;
std::vector<output_connector *> inputs_;
std::vector<input_connector *> outputs_;
std::unique_ptr<ir::quantizer> quantizer_;
};
class NNCASE_API evaluator
{
public:
evaluator(const schedule::schedule_result &sched);
evaluator(const schedule::model_schedule_result &sched);
evaluator(evaluator &) = delete;
evaluator(evaluator &&) = default;
module_evaluate_context &module_context(ir::graph &graph);
module_evaluate_context &main_module_context();
void enable_ptq(target &target, ir::calibrate_method calib_method);
void evaluate();
ir::quantizer *quantizer(const module_type_t &module_type);
void begin_collect_distribution();
void end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress);
void end_sample();
void end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress);
evaluate_tensor memory_at(const output_connector &conn);
evaluate_tensor memory_at(const input_connector &conn);
evaluate_tensor memory_at(const input_connector &conn)
{
return memory_at(*conn.connection());
}
evaluate_tensor input_at(size_t index)
{
return main_module_context().input_at(index);
}
evaluate_tensor output_at(size_t index)
{
return main_module_context().output_at(index);
}
evaluate_tensor input_at(size_t index);
evaluate_tensor output_at(size_t index);
private:
const schedule::schedule_result &sched_;
std::unordered_map<ir::graph *, module_evaluate_context> module_ctxs_;
model_evaluate_context model_eval_;
};
NNCASE_API void register_evaluator(ir::node_opcode opcode, std::function<void(ir::node &, module_evaluate_context &)> evaluator);
NNCASE_API void register_evaluator(ir::node_opcode opcode, std::function<void(ir::node &, function_evaluate_context &)> evaluator);
}

@@ -124,7 +124,7 @@ public:
void begin_collect_distribution();
void end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress);
size_t histograms_count() const noexcept { return histograms_.size(); }
void reset_record() { has_record_.clear(); }
void end_sample() { has_record_.clear(); }
private:
calibrate_method cali_method_;

@@ -325,6 +325,8 @@ NNCASE_INLINE_VAR constexpr memory_location_t mem_input = 0;
NNCASE_INLINE_VAR constexpr memory_location_t mem_output = 1;
NNCASE_INLINE_VAR constexpr memory_location_t mem_rdata = 2;
NNCASE_INLINE_VAR constexpr memory_location_t mem_data = 3;
NNCASE_INLINE_VAR constexpr memory_location_t mem_shared_data = 4;
NNCASE_INLINE_VAR constexpr memory_location_t mem_private_base = 64;
using runtime_shape_t = itlib::small_vector<size_t, 4>;
using runtime_axis_t = itlib::small_vector<int32_t, 4>;
@@ -402,7 +404,7 @@ struct memory_range
{
memory_location_t memory_location;
datatype_t datatype;
uint16_t reserved0;
uint16_t shared_module;
uint32_t start;
uint32_t size;
};
@@ -463,3 +465,16 @@ inline bool operator!=(const scalar &lhs, const scalar &rhs) noexcept
return lhs.type != rhs.type || memcmp(&lhs.storage, &rhs.storage, valid_bytes);
}
}
template <>
struct std::hash<nncase::module_type_t>
{
auto operator()(const nncase::module_type_t &key) const noexcept
{
size_t result = 0;
const size_t prime = 31;
for (auto c : key)
result = c + (result * prime);
return result;
}
};
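The specialization is a plain base-31 polynomial over the characters of the module type, equivalent to:

size_t h = 0;
for (auto c : key) // key : nncase::module_type_t
    h = h * 31 + c;

It exists so module_type_t can key the unordered maps this commit introduces, such as the per-type module contexts in the evaluator and scheduler and shared_allocator_map_t.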

@@ -74,7 +74,7 @@ public:
private:
std::vector<std::unique_ptr<runtime_module>> modules_;
runtime_module *main_module_;
runtime_function *entry_function_;
options_dict options_;
};

@@ -24,26 +24,49 @@ struct model_header
{
uint32_t identifier;
uint32_t version;
uint32_t header_size;
uint32_t flags;
uint32_t alignment;
uint32_t modules;
uint32_t main_module;
uint32_t entry_module;
uint32_t entry_function;
};
struct function_header
{
uint32_t header_size;
uint32_t size;
uint32_t input_pool_size;
uint32_t output_pool_size;
uint32_t inputs;
uint32_t outputs;
uint32_t entrypoint;
uint32_t text_size;
};
struct module_header
{
module_type_t type;
uint32_t version;
uint32_t header_size;
uint32_t size;
uint32_t mempools;
uint32_t inputs;
uint32_t outputs;
uint32_t shared_mempools;
uint32_t sections;
uint32_t functions;
uint32_t reserved0;
};
struct mempool_desc
{
memory_location_t location;
uint8_t reserved0[3];
uint32_t size;
};
struct shared_mempool_desc
{
uint32_t module;
uint32_t size;
};
@@ -51,8 +74,8 @@ struct section_header
{
char name[MAX_SECTION_NAME_LENGTH];
uint32_t flags;
uint32_t start;
uint32_t size;
uint32_t body_start;
uint32_t body_size;
uint32_t reserved0;
};
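With function_header inserted between the module and its sections, a serialized model now nests three levels. The record order sketched below is inferred from the header fields (header_size and size let a reader skip whole records); this diff does not spell it out:

model_header                     (entry_module, entry_function)
  module_header[0]               (counts: mempools, shared_mempools, sections, functions)
    mempool_desc        x mempools
    shared_mempool_desc x shared_mempools
    function_header     x functions   (entrypoint + text_size locate code in .text)
    section_header      x sections    (body_start, body_size)
    section bodies ...
  module_header[1] ...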

@@ -0,0 +1,86 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "model.h"
#include "result.h"
#include "runtime_tensor.h"
BEGIN_NS_NNCASE_RUNTIME
class interpreter;
class runtime_module;
struct runtime_module_init_context;
struct NNCASE_API runtime_function_init_context
{
virtual runtime_module_init_context &module_init_context() noexcept = 0;
virtual const function_header &header() noexcept = 0;
virtual gsl::span<const gsl::byte> body() noexcept = 0;
};
class NNCASE_API runtime_function
{
private:
struct inout_tensor_info
{
runtime_shape_t shape;
runtime_shape_t strides;
memory_range range;
runtime_tensor bind_tensor;
runtime_tensor staging_tensor;
runtime_tensor device_tensor;
};
public:
runtime_function(runtime_module &rt_module);
runtime_function(const runtime_function &) = delete;
virtual ~runtime_function() = default;
runtime_function &operator=(const runtime_function &) = delete;
result<void> initialize(gsl::span<const gsl::byte> payload, runtime_module_init_context &module_init_context) noexcept;
runtime_module &module() const noexcept;
uint32_t inputs_size() const noexcept;
const runtime_shape_t &input_shape(size_t index) const noexcept;
const memory_range &input_desc(size_t index) const noexcept;
result<runtime_tensor> input_tensor(size_t index) noexcept;
result<void> input_tensor(size_t index, runtime_tensor tensor) noexcept;
uint32_t outputs_size() const noexcept;
const runtime_shape_t &output_shape(size_t index) const noexcept;
const memory_range &output_desc(size_t index) const noexcept;
result<runtime_tensor> output_tensor(size_t index) noexcept;
result<void> output_tensor(size_t index, runtime_tensor tensor) noexcept;
result<void> invoke() noexcept;
protected:
virtual result<void> initialize_core(runtime_function_init_context &context) noexcept = 0;
virtual result<runtime_tensor> allocate_input_tensor(size_t index) noexcept = 0;
virtual result<runtime_tensor> allocate_output_tensor(size_t index) noexcept = 0;
virtual result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
virtual result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
result<runtime_tensor> device_input_tensor(size_t index) noexcept;
result<runtime_tensor> device_output_tensor(size_t index) noexcept;
virtual result<void> invoke_core() noexcept = 0;
private:
function_header header_;
std::vector<inout_tensor_info> input_tensors_;
std::vector<inout_tensor_info> output_tensors_;
runtime_module &rt_module_;
};
END_NS_NNCASE_RUNTIME
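Client code now binds tensors to, and invokes, a function rather than a whole module. A hypothetical caller sketch; entry_function() stands in for however interpreter exposes its new entry_function_ member, which this diff does not show:

result<void> run(interpreter &interp, runtime_tensor input) noexcept
{
    runtime_function *entry = interp.entry_function(); // assumed accessor
    try_(entry->input_tensor(0, input));
    try_(entry->invoke());
    try_var(output, entry->output_tensor(0));
    // ... consume output ...
    return ok();
}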

@@ -15,6 +15,7 @@
#pragma once
#include "model.h"
#include "result.h"
#include "runtime_function.h"
#include "runtime_tensor.h"
BEGIN_NS_NNCASE_RUNTIME
@@ -31,26 +32,15 @@ struct NNCASE_API runtime_module_init_context
class NNCASE_API runtime_module
{
private:
struct inout_tensor_info
{
runtime_shape_t shape;
runtime_shape_t strides;
memory_range range;
runtime_tensor bind_tensor;
runtime_tensor staging_tensor;
runtime_tensor device_tensor;
};
public:
static result<std::unique_ptr<runtime_module>> create(const module_type_t &type);
runtime_module() = default;
runtime_module(runtime_module &) = delete;
runtime_module(const runtime_module &) = delete;
virtual ~runtime_module() = default;
runtime_module &operator=(const runtime_module &) = delete;
result<void> initialize(const module_header &header, interpreter &interp) noexcept;
virtual result<void> initialize_inter_modules(interpreter &interp) noexcept;
result<void> initialize(gsl::span<const gsl::byte> payload, interpreter &interp) noexcept;
const module_type_t &type() const noexcept;
interpreter &interp() const noexcept { return *interp_; }
@@ -59,35 +49,20 @@ public:
const mempool_desc &mempool(size_t index) const noexcept;
mempool_desc mempool(memory_location_t location) const noexcept;
uint32_t inputs_size() const noexcept;
const runtime_shape_t &input_shape(size_t index) const noexcept;
const memory_range &input_desc(size_t index) const noexcept;
result<runtime_tensor> input_tensor(size_t index) noexcept;
result<void> input_tensor(size_t index, runtime_tensor tensor) noexcept;
uint32_t outputs_size() const noexcept;
const runtime_shape_t &output_shape(size_t index) const noexcept;
const memory_range &output_desc(size_t index) const noexcept;
result<runtime_tensor> output_tensor(size_t index) noexcept;
result<void> output_tensor(size_t index, runtime_tensor tensor) noexcept;
result<void> run() noexcept;
result<runtime_function *> find_function_by_id(size_t index) noexcept;
protected:
virtual result<void> initialize_core(runtime_module_init_context &context) noexcept = 0;
virtual result<runtime_tensor> allocate_input_tensor(size_t index) noexcept = 0;
virtual result<runtime_tensor> allocate_output_tensor(size_t index) noexcept = 0;
virtual result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
virtual result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
result<runtime_tensor> device_input_tensor(size_t index) noexcept;
result<runtime_tensor> device_output_tensor(size_t index) noexcept;
virtual result<void> run_core() noexcept = 0;
virtual result<void> initialize_before_functions(runtime_module_init_context &context) noexcept;
virtual result<void> initialize_after_functions(runtime_module_init_context &context) noexcept;
virtual result<std::unique_ptr<runtime_function>> create_function() noexcept = 0;
gsl::span<std::unique_ptr<runtime_function>> functions() noexcept { return functions_; }
private:
module_header header_;
std::vector<mempool_desc> mempools_;
std::vector<inout_tensor_info> input_tensors_;
std::vector<inout_tensor_info> output_tensors_;
std::vector<mempool_desc> shared_mempools_;
std::vector<std::unique_ptr<runtime_function>> functions_;
interpreter *interp_ = nullptr;
};
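runtime_module is thus reduced to shared state: it keeps the memory pools (now including shared ones), its sections, and a table of functions built through the create_function() factory, presumably while initialize() parses the payload; all per-call I/O binding, validation, and execution (formerly run()) move to runtime_function, and cross-module callers resolve a callee with find_function_by_id().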

@@ -23,17 +23,20 @@ inline constexpr size_t get_bytes(datatype_t type)
return nncase::detail::datatype_bytes(type);
}
inline size_t compute_size(const runtime_shape_t &shape)
template <class TShape>
inline size_t compute_size(const TShape &shape)
{
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
}
inline size_t get_bytes(datatype_t type, const runtime_shape_t &shape)
template <class TShape>
inline size_t get_bytes(datatype_t type, const TShape &shape)
{
return compute_size(shape) * get_bytes(type);
}
inline size_t compute_size(const runtime_shape_t &shape, const runtime_shape_t &strides)
template <class TShape>
inline size_t compute_size(const TShape &shape, const TShape &strides)
{
size_t max_stride = 0, max_shape = 0;
for (size_t i = 0; i < shape.size(); i++)
@@ -48,7 +51,8 @@ inline size_t compute_size(const runtime_shape_t &shape, const runtime_shape_t &
return size ? size : 1;
}
inline size_t get_bytes(datatype_t type, const runtime_shape_t &shape, const runtime_shape_t &strides)
template <class TShape>
inline size_t get_bytes(datatype_t type, const TShape &shape, const TShape &strides)
{
return compute_size(shape, strides) * get_bytes(type);
}
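A quick worked example of the (unchanged) arithmetic behind the now-templated helpers:

// shape = {1, 3, 224, 224}, type = dt_float32 (4 bytes per element)
// compute_size(shape)          == 1 * 3 * 224 * 224 == 150528 elements
// get_bytes(dt_float32, shape) == 150528 * 4        == 602112 bytes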

@@ -88,18 +88,16 @@ public:
}
template <class T>
T peek()
T peek_with_offset(size_t offset)
{
auto value = *reinterpret_cast<const T *>(span_.data());
auto value = *reinterpret_cast<const T *>(span_.data() + offset);
return value;
}
template <class T>
T peek_unaligned()
T peek()
{
T value;
std::memcpy(&value, span_.data(), sizeof(T));
return value;
return peek_with_offset<T>(0);
}
template <class T>
@@ -110,6 +108,12 @@ public:
return value;
}
template <class T>
T peek_unaligned()
{
return peek_unaligned_with_offset<T>(0);
}
template <class T>
const T *get_ref()
{

@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:48 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -1141,7 +1141,8 @@ struct op_reader<tensor_call_op_t>
tensor_call_op_t op(default_init);
op.opcode = static_cast<opcode_t>(reader.read_unaligned<uint8_t>());
op.funct = static_cast<tensor_function_t>(reader.read_unaligned<uint16_t>());
op.module_id = reader.read_unaligned<uint32_t>();
op.function_id = reader.read_unaligned<uint32_t>();
op.module_id = reader.read_unaligned<uint16_t>();
op.num_src = reader.read_unaligned<uint8_t>();
op.num_dst = reader.read_unaligned<uint8_t>();
return op;

@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:48 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -1265,13 +1265,14 @@ struct tensor_call_op_t
{
opcode_t opcode;
tensor_function_t funct;
uint32_t module_id;
uint32_t function_id;
uint16_t module_id;
uint8_t num_src;
uint8_t num_dst;
tensor_call_op_t(default_init_t) noexcept { }
explicit tensor_call_op_t(uint32_t module_id, uint8_t num_src, uint8_t num_dst) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::CALL), module_id(module_id), num_src(num_src), num_dst(num_dst)
explicit tensor_call_op_t(uint32_t function_id, uint16_t module_id, uint8_t num_src, uint8_t num_dst) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::CALL), function_id(function_id), module_id(module_id), num_src(num_src), num_dst(num_dst)
{
}
};

@@ -18,6 +18,7 @@
BEGIN_NS_NNCASE_RT_MODULE(stackvm)
NNCASE_INLINE_VAR constexpr module_type_t stackvm_module_type = to_module_type("stackvm");
NNCASE_INLINE_VAR constexpr uint32_t stackvm_module_version = 1;
NNCASE_API result<std::unique_ptr<runtime_module>> create_stackvm_runtime_module();

@@ -73,4 +73,5 @@ private:
};
using allocator_map_t = std::unordered_map<memory_location_t, buffer_allocator *>;
using shared_allocator_map_t = std::unordered_map<module_type_t, buffer_allocator *>;
}

@@ -75,6 +75,9 @@ public:
memory_location_t memory_location() const noexcept { return memory_location_; }
memory_location_t &memory_location() noexcept { return memory_location_; }
module_type_t shared_module() const noexcept { return shared_module_; }
void shared_module(const module_type_t &type) noexcept { shared_module_ = type; }
bool no_action_concat_with_strides() const noexcept { return no_action_concat_with_strides_; }
bool &no_action_concat_with_strides() noexcept { return no_action_concat_with_strides_; }
@@ -82,6 +85,7 @@ private:
size_t id_;
ir::output_connector &owner_;
memory_location_t memory_location_;
module_type_t shared_module_;
std::optional<sub_buffer_desc> parent_;
ir::shape_t strides_shape_;
buffer_lifetime lifetime_ {};

@@ -0,0 +1,38 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "schedule_types.h"
namespace nncase::schedule
{
class lifetime_recorder
{
public:
lifetime_recorder(std::list<logical_buffer> &buffers, std::unordered_map<const ir::output_connector *, logical_buffer *> &buffer_map);
size_t current_age() const noexcept { return cnt_age_; }
void current_age(size_t age);
void allocate(ir::output_connector &conn, memory_location_t location);
void release(ir::output_connector &conn);
void grow_age();
private:
size_t next_buffer_id_ = 0;
size_t cnt_age_ = 0;
std::list<logical_buffer> &buffers_;
std::unordered_map<const ir::output_connector *, logical_buffer *> &buffer_map_;
};
}
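A sketch of the intended call pattern (illustration only, with conn standing for the ir::output_connector a node produces and buffers/buffer_map for the scheduler's containers):

lifetime_recorder lr(buffers, buffer_map);
lr.allocate(conn, mem_data); // buffer born at the current age
lr.grow_age();               // one schedule step passes
lr.release(conn);            // last consumer seen; lifetime closed

The current_age() getter/setter pair is what lets a caller splice a callee's lifetimes into its own clock, which caller_context in the next header relies on.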

@@ -0,0 +1,135 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "buffer_allocator.h"
#include "liveness_analysis.h"
#include "schedule_types.h"
#include <filesystem>
namespace nncase
{
class target;
}
namespace nncase::schedule
{
class module_schedule_context;
class model_schedule_context;
struct caller_context
{
lifetime_recorder &lifetime;
};
class function_schedule_context : public function_schedule_result
{
public:
function_schedule_context(ir::graph &graph, module_schedule_context &mod_sched);
function_schedule_context(const function_schedule_context &) = delete;
function_schedule_context(function_schedule_context &&) = default;
function_schedule_context &operator=(const function_schedule_context &) = delete;
const module_type_t &module_type() const noexcept { return graph->module_type(); }
std::span<ir::output_node *> outputs() const noexcept { return outputs_; }
std::unordered_map<const ir::output_connector *, logical_buffer *> &logical_buffer_map() noexcept { return logical_buffer_map_; }
std::list<logical_buffer> &logical_buffers() noexcept { return logical_buffers_; }
std::vector<physical_buffer> &physical_buffers() noexcept { return physical_buffers_; }
void visit_function(caller_context &caller_ctx);
void end_schedule();
private:
void create_allocators();
void generate_compute_sequence();
void make_logical_buffers(caller_context &caller_ctx);
void analyze_buffer_alias();
void update_offset();
void fix_lifetime();
void make_physical_buffers();
void allocate_physical_buffers();
void assign_allocations();
void dump(const std::filesystem::path &dump_dir);
private:
module_schedule_context &mod_sched_;
std::span<ir::output_node *> outputs_;
allocator_map_t allocators_;
std::vector<std::shared_ptr<buffer_allocator>> allocator_holder_;
std::unordered_map<const ir::output_connector *, logical_buffer *> logical_buffer_map_;
std::list<logical_buffer> logical_buffers_;
std::vector<physical_buffer> physical_buffers_;
};
class module_schedule_context
{
public:
module_schedule_context(module_schedule_result &result, model_schedule_context &model_sched, module_type_t type);
module_schedule_context(const module_schedule_context &) = delete;
module_schedule_context(module_schedule_context &&) = default;
module_schedule_context &operator=(const module_schedule_context &) = delete;
module_schedule_result &module_result() const noexcept { return result_; }
model_schedule_context &model_sched() const noexcept { return model_sched_; }
allocator_map_t &allocators() noexcept { return allocators_; }
buffer_allocator &shared_allocator(const module_type_t &type);
void visit_function(ir::graph &graph, caller_context &caller_ctx);
void end_schedule();
private:
module_schedule_result &result_;
model_schedule_context &model_sched_;
module_type_t type_;
allocator_map_t allocators_;
std::vector<std::shared_ptr<buffer_allocator>> allocator_holder_;
shared_allocator_map_t shared_allocators_;
std::vector<function_schedule_context> functions_;
std::filesystem::path dump_dir_;
};
class model_schedule_context
{
public:
model_schedule_context(model_schedule_result &result, nncase::target &target, bool skip_buffer_alias);
model_schedule_context(const model_schedule_context &) = delete;
model_schedule_context(model_schedule_context &&) = default;
model_schedule_context &operator=(const model_schedule_context &) = delete;
nncase::target &target() const noexcept { return target_; }
bool skip_buffer_alias() const noexcept { return skip_buffer_alias_; }
void config_dump(std::filesystem::path dump_dir);
const std::filesystem::path &dump_dir() const noexcept { return dump_dir_; }
model_schedule_result &model_result() const noexcept { return result_; }
void schedule(ir::graph &entry_function);
void visit_function(ir::graph &graph, caller_context &caller_ctx);
private:
void end_schedule();
private:
model_schedule_result &result_;
nncase::target &target_;
bool skip_buffer_alias_;
std::filesystem::path dump_dir_;
module_schedule_context *entry_module_;
ir::graph *entry_function_;
std::unordered_map<module_type_t, module_schedule_context> module_contexts_;
};
}
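These three context classes parallel the new codegen and evaluator stacks: model_schedule_context drives scheduling from the entry function, module_schedule_context owns the per-module allocators (including the shared allocators that presumably back mem_shared_data), and function_schedule_context derives from function_schedule_result so each function fills in its own result in place. caller_context threads the caller's lifetime_recorder into visit_function, so a callee's buffers are aged on the caller's clock and stay live across the call site.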

@@ -0,0 +1,75 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "buffer_allocator.h"
#include "buffers.h"
#include <nncase/ir/graph.h>
#include <unordered_map>
#include <vector>
namespace nncase::schedule
{
struct buffer_allocation
{
memory_location_t memory_location;
datatype_t type;
size_t shared_module;
size_t start;
size_t size;
ir::shape_t shape;
ir::shape_t strides;
ir::shape_t strides_shape;
size_t linear_end() const noexcept { return start + size; }
bool overlap(const buffer_allocation &rhs) const noexcept
{
return size != 0 && rhs.size != 0 && memory_location == rhs.memory_location && (start < rhs.linear_end() && linear_end() > rhs.start);
}
memory_range runtime_type() const
{
return { .memory_location = memory_location, .datatype = type, .shared_module = (uint16_t)shared_module, .start = (uint32_t)start, .size = (uint32_t)size };
}
};
using allocation_map_t = std::unordered_map<const ir::output_connector *, buffer_allocation>;
struct module_schedule_result;
struct function_schedule_result
{
ir::graph *graph;
module_schedule_result *module;
std::vector<ir::node *> compute_sequence;
size_t input_pool_size;
size_t output_pool_size;
};
struct module_schedule_result
{
module_type_t type;
std::vector<function_schedule_result> functions;
std::unordered_map<ir::graph *, function_schedule_result *> functions_map;
allocation_map_t allocations;
std::unordered_map<memory_location_t, size_t> max_usages;
std::unordered_map<module_type_t, size_t> shared_max_usages;
};
struct model_schedule_result
{
std::vector<module_schedule_result> modules;
function_schedule_result *entry_function;
};
}

@@ -13,14 +13,10 @@
* limitations under the License.
*/
#pragma once
#include "buffer_allocator.h"
#include "buffers.h"
#include "schedule_context.h"
#include "schedule_types.h"
#include <filesystem>
#include <functional>
#include <nncase/ir/graph.h>
#include <span>
#include <unordered_map>
#include <vector>
namespace nncase
{
@@ -28,78 +24,14 @@ class target;
namespace schedule
{
struct buffer_allocation
{
memory_location_t memory_location;
datatype_t type;
size_t start;
size_t size;
ir::shape_t shape;
ir::shape_t strides;
ir::shape_t strides_shape;
size_t linear_end() const noexcept { return start + size; }
bool overlap(const buffer_allocation &rhs) const noexcept
{
return size != 0 && rhs.size != 0 && memory_location == rhs.memory_location && (start < rhs.linear_end() && linear_end() > rhs.start);
}
memory_range runtime_type() const
{
return { .memory_location = memory_location, .datatype = type, .start = (uint32_t)start, .size = (uint32_t)size };
}
};
using allocation_map_t = std::unordered_map<const ir::output_connector *, buffer_allocation>;
struct module_schedule_result
{
ir::graph *graph;
std::vector<ir::node *> compute_sequence;
std::unordered_map<memory_location_t, size_t> max_usages;
allocation_map_t allocations;
};
struct schedule_result
{
std::vector<ir::graph *> graph_orders;
std::unordered_map<ir::graph *, module_schedule_result> modules;
ir::graph *main_module;
};
struct schedule_context : module_schedule_result
{
bool skip_buffer_alias = false;
nncase::target *target;
module_type_t module_type;
std::span<ir::output_node *> outputs;
std::unordered_map<const ir::output_connector *, logical_buffer *> logical_buffer_map;
std::list<logical_buffer> logical_buffers;
std::vector<physical_buffer> physical_buffers;
void generate_compute_sequence();
void make_logical_buffers();
void analyze_buffer_alias();
void update_offset();
void fix_lifetime();
void make_physical_buffers();
void allocate_physical_buffers();
void assign_allocations();
};
class NNCASE_API scheduler
{
public:
scheduler(target &target, ir::graph &main_graph, std::span<ir::output_node *> outputs)
: target_(target), main_graph_(main_graph), outputs_(outputs) { }
scheduler(target &target, ir::graph &main_graph, std::span<ir::output_node *> outputs);
schedule_result schedule(bool skip_buffer_alias = false);
model_schedule_result schedule(bool skip_buffer_alias = false);
void config_dump(std::filesystem::path dump_dir);
private:
void dump_schedule(const schedule_context &context);
private:
target &target_;
ir::graph &main_graph_;

@@ -28,7 +28,7 @@ namespace nncase::ir::transforms
struct run_pass_options
{
ir::quantizer *quantizer;
schedule::schedule_context *schedule_context;
schedule::function_schedule_context *schedule_context;
std::optional<std::filesystem::path> dump_dir;
};
@@ -105,14 +105,14 @@ public:
void dump_dir(const std::filesystem::path &dir);
void quantizer(ir::quantizer *q);
void schedule_context(schedule::schedule_context *c);
void schedule_context(schedule::function_schedule_context *c);
private:
std::vector<std::unique_ptr<pass>> passes_;
graph &graph_;
nncase::target &target_;
ir::quantizer *quantizer_;
schedule::schedule_context *schedule_context_;
schedule::function_schedule_context *schedule_context_;
std::optional<std::filesystem::path> dump_dir_;
};
}

@@ -24,7 +24,7 @@ class target;
namespace schedule
{
struct schedule_context;
class function_schedule_context;
}
namespace ir
@@ -45,7 +45,7 @@ namespace ir
ir::graph &graph;
nncase::target &target;
ir::quantizer *quantizer;
schedule::schedule_context *schedule_context;
schedule::function_schedule_context *schedule_context;
std::optional<std::filesystem::path> dump_dir;
std::vector<node *> matched_nodes;
std::vector<input_connector *> inputs;

@@ -25,18 +25,6 @@
#define NNCASE_MODULES_K210_API __attribute__((visibility("default")))
#endif
#define BEGIN_NS_NNCASE_RT_K210 \
namespace nncase \
{ \
namespace runtime \
{ \
namespace k210 \
{
#define END_NS_NNCASE_RT_K210 \
} \
} \
}
#define BEGIN_NS_NNCASE_KERNELS_K210 \
namespace nncase \
{ \

@@ -16,7 +16,7 @@
#include "compiler_defs.h"
#include <nncase/runtime/error.h>
BEGIN_NS_NNCASE_RT_K210
BEGIN_NS_NNCASE_RT_MODULE(k210)
enum class nncase_k210_errc
{
@@ -26,7 +26,7 @@ enum class nncase_k210_errc
NNCASE_MODULES_K210_API const std::error_category &nncase_k210_category() noexcept;
NNCASE_MODULES_K210_API std::error_condition make_error_condition(nncase_k210_errc code);
END_NS_NNCASE_RT_K210
END_NS_NNCASE_RT_MODULE
namespace std
{

@@ -17,7 +17,7 @@
#include <nncase/runtime/result.h>
#include <nncase/runtime/span_reader.h>
BEGIN_NS_NNCASE_RT_K210
BEGIN_NS_NNCASE_RT_MODULE(k210)
class NNCASE_MODULES_K210_API op_visitor
{
@@ -44,4 +44,4 @@ private:
result<void> next() noexcept;
};
END_NS_NNCASE_RT_K210
END_NS_NNCASE_RT_MODULE

@@ -16,10 +16,11 @@
#include "compiler_defs.h"
#include <nncase/runtime/runtime_module.h>
BEGIN_NS_NNCASE_RT_K210
BEGIN_NS_NNCASE_RT_MODULE(k210)
NNCASE_INLINE_VAR constexpr module_type_t k210_module_type = to_module_type("k210");
NNCASE_INLINE_VAR constexpr uint32_t k210_module_version = 1;
NNCASE_MODULES_K210_API result<std::unique_ptr<runtime_module>> create_k210_runtime_module();
END_NS_NNCASE_RT_K210
END_NS_NNCASE_RT_MODULE

@@ -15,7 +15,7 @@
#pragma once
#include "runtime_types.h"
BEGIN_NS_NNCASE_RT_K210
BEGIN_NS_NNCASE_RT_MODULE(k210)
struct kpu_layout
{
@@ -184,4 +184,4 @@ inline std::array<int32_t, 2> get_kpu_select_pool_offset(kpu_pool_type_t pool_ty
}
}
END_NS_NNCASE_RT_K210
END_NS_NNCASE_RT_MODULE

@@ -16,9 +16,9 @@
#include "compiler_defs.h"
#include <nncase/runtime/datatypes.h>
BEGIN_NS_NNCASE_RT_K210
BEGIN_NS_NNCASE_RT_MODULE(k210)
NNCASE_INLINE_VAR constexpr memory_location_t mem_kpu = 4;
NNCASE_INLINE_VAR constexpr memory_location_t mem_kpu = mem_private_base + 0;
NNCASE_INLINE_VAR constexpr size_t KPU_RAM_SIZE = 2 * 1024 * 1024; // 2MB
typedef struct
@@ -341,4 +341,4 @@ struct copy_options
kpu_shape_t out_strides;
};
END_NS_NNCASE_RT_K210
END_NS_NNCASE_RT_MODULE

@@ -38,11 +38,26 @@ module_type_t k210_module_builder::module_type() const noexcept
return k210_module_type;
}
uint32_t k210_module_builder::module_version() const noexcept
{
return k210_module_version;
}
section_writer &k210_module_builder::text_writer()
{
return writer(".text");
}
void k210_module_builder::begin_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
{
set_current_entry_point(text_writer().position());
}
void k210_module_builder::end_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
{
set_current_function_text_end(text_writer().position());
}
void k210_module_builder::emit(ir::node &node)
{
#define DEFINE_OP(op) \

@@ -28,10 +28,13 @@ public:
k210_module_builder(std::string_view module_name, const module_builder_params &params);
module_type_t module_type() const noexcept override;
uint32_t module_version() const noexcept override;
protected:
section_writer &text_writer();
void begin_emit_function(const schedule::function_schedule_result &function) override;
void end_emit_function(const schedule::function_schedule_result &function) override;
void emit(ir::node &node) override;
private:

@@ -27,7 +27,7 @@ using namespace nncase::runtime;
void ir::k210::register_k210_evaluators()
{
register_evaluator(op_k210_fake_kpu_conv2d, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_k210_fake_kpu_conv2d, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<fake_kpu_conv2d &>(node);
assert(rnode.input().type() == dt_float32);
@@ -42,13 +42,13 @@ void ir::k210::register_k210_evaluators()
auto in_shape = input.shape();
shape_t conv_out_shape { in_shape[0], (size_t)rnode.output_channels(), in_shape[2], in_shape[3] };
auto conv_out_fmap_size = xt::compute_size(conv_out_shape);
auto conv_out_fmap_size = runtime::compute_size(conv_out_shape);
auto conv_output_tmp = std::make_unique<float[]>(conv_out_fmap_size);
auto batch = in_shape[0];
auto in_size_per_batch = xt::compute_size(in_shape) / batch;
auto in_size_per_batch = runtime::compute_size(in_shape) / batch;
auto conv_output_tmp_size_per_batch = conv_out_fmap_size / batch;
auto out_size_per_batch = xt::compute_size(rnode.output().shape()) / batch;
auto out_size_per_batch = runtime::compute_size(rnode.output().shape()) / batch;
auto p_input = input_mem.data();
auto p_conv_ouput_tmp = conv_output_tmp.get();
auto p_output = output_mem.data();

@@ -3,6 +3,7 @@
add_subdirectory(kendryte)
set(SRCS runtime_module.cpp
runtime_function.cpp
op_reader.cpp
error.cpp
ops/kpu_conv2d.cpp

@@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/k210/k210_kernels.h>
#include <nncase/kernels/tensor_compute.h>
@@ -20,7 +20,7 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::k210;
result<void> k210_runtime_module::visit(const copy_options &op) noexcept
result<void> k210_runtime_function::visit(const copy_options &op) noexcept
{
try_var(input, memory_at(op.input));
try_var(output, memory_at(op.output));

@@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/k210/k210_kernels.h>
#include <nncase/runtime/dbg.h>
#ifndef NNCASE_SIMULATOR
@@ -71,7 +71,7 @@ void kpu_conv2d_normal(runtime::k210::kpu_layer_argument_t &layer, plic_irq_call
#endif
}
result<void> k210_runtime_module::visit(const kpu_conv2d_options &op) noexcept
result<void> k210_runtime_function::visit(const kpu_conv2d_options &op) noexcept
{
auto layer = op.layer;

@@ -12,14 +12,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/k210/k210_kernels.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::k210;
result<void> k210_runtime_module::visit(const kpu_download_options &op) noexcept
result<void> k210_runtime_function::visit(const kpu_download_options &op) noexcept
{
try_var(input, memory_at(op.input));
try_var(output, memory_at(op.output));

@@ -12,14 +12,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/k210/k210_kernels.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::k210;
result<void> k210_runtime_module::visit(const kpu_upload_options &op) noexcept
result<void> k210_runtime_function::visit(const kpu_upload_options &op) noexcept
{
try_var(input, memory_at(op.input));
try_var(output, memory_at(op.output));
@@ -30,7 +30,7 @@ result<void> k210_runtime_module::visit(const kpu_upload_options &op) noexcept
#ifdef NNCASE_SIMULATOR
0
#else
dma_ch_
module().dma_ch()
#endif
);
}

@@ -0,0 +1,145 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime_function.h"
#include <nncase/runtime/host_runtime_tensor.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/k210/error.h>
#include <nncase/runtime/k210/runtime_types.h>
#include <nncase/runtime/runtime_loader.h>
#ifndef NNCASE_SIMULATOR
#include <kpu.h>
#endif
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::detail;
using namespace nncase::runtime::k210;
k210_runtime_module &k210_runtime_function::module() const noexcept
{
return static_cast<k210_runtime_module &>(runtime_function::module());
}
result<void> k210_runtime_function::initialize_core(runtime_function_init_context &context) noexcept
{
text_ = context.module_init_context().section(".text").subspan(context.header().entrypoint, context.header().text_size);
return ok();
}
result<runtime_tensor> k210_runtime_function::allocate_input_tensor(size_t index) noexcept
{
return hrt::create(input_desc(index).datatype, input_shape(index), hrt::pool_shared);
}
result<runtime_tensor> k210_runtime_function::allocate_output_tensor(size_t index) noexcept
{
return hrt::create(output_desc(index).datatype, output_shape(index), hrt::pool_shared);
}
result<void> k210_runtime_function::validate_input_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
{
if (tensor.is_host()
&& hrt::memory_pool(tensor).unwrap() == hrt::pool_shared)
return ok();
return err(std::errc::invalid_argument);
}
result<void> k210_runtime_function::validate_output_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
{
if (tensor.is_host())
return ok();
return err(std::errc::invalid_argument);
}
result<void> k210_runtime_function::invoke_core() noexcept
{
for (size_t i = 0; i < inputs_size(); i++)
{
try_var(input, device_input_tensor(i));
try_(hrt::sync(input, hrt::sync_write_back));
}
try_(visit(text_));
return ok();
}
result<gsl::span<gsl::byte>> k210_runtime_function::memory_at(const memory_range &mrange) noexcept
{
#define ID_NOT_FOUND ((size_t)-1)
gsl::byte *base;
switch (mrange.memory_location)
{
case mem_input:
{
size_t id = ID_NOT_FOUND;
for (size_t i = 0; i < inputs_size(); i++)
{
if (mrange.start == input_desc(i).start)
{
id = i;
break;
}
}
if (id != ID_NOT_FOUND)
{
try_var(tensor, device_input_tensor(id));
base = reinterpret_cast<gsl::byte *>(static_cast<host_runtime_tensor_impl *>(tensor.impl())->memory_block().virtual_address - mrange.start);
}
else
{
return err(std::errc::invalid_argument);
}
break;
}
case mem_output:
{
size_t id = ID_NOT_FOUND;
for (size_t i = 0; i < outputs_size(); i++)
{
if (mrange.start == output_desc(i).start)
{
id = i;
break;
}
}
if (id != ID_NOT_FOUND)
{
try_var(tensor, device_output_tensor(id));
try_var(tensor_map, hrt::map(tensor, hrt::map_read_write));
base = tensor_map.buffer().data() - mrange.start;
}
else
{
return err(std::errc::invalid_argument);
}
break;
}
case mem_rdata:
base = const_cast<gsl::byte *>(module().rdata().data());
break;
case mem_data:
base = module().data().data();
break;
case mem_kpu:
base = module().kpu_ram().data();
break;
default:
return err(nncase_errc::invalid_memory_location);
}
return ok(gsl::make_span(base + mrange.start, mrange.size));
}
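Note the base-pointer convention in memory_at: for mem_input and mem_output the code biases base by -mrange.start (the bound tensor's address minus the range's start), so the closing gsl::make_span(base + mrange.start, mrange.size) lands exactly on that tensor's storage wherever the range sits in its pool.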

@@ -0,0 +1,51 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "runtime_module.h"
#include <nncase/runtime/k210/op_reader.h>
#include <nncase/runtime/k210/runtime_types.h>
#include <nncase/runtime/runtime_function.h>
BEGIN_NS_NNCASE_RT_MODULE(k210)
class k210_runtime_function : public runtime_function, private op_visitor
{
public:
using runtime_function::runtime_function;
k210_runtime_module &module() const noexcept;
protected:
result<void> initialize_core(runtime_function_init_context &context) noexcept override;
result<runtime_tensor> allocate_input_tensor(size_t index) noexcept override;
result<runtime_tensor> allocate_output_tensor(size_t index) noexcept override;
result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept override;
result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept override;
result<void> invoke_core() noexcept override;
using op_visitor::visit;
result<void> visit(const kpu_conv2d_options &op) noexcept override;
result<void> visit(const kpu_download_options &op) noexcept override;
result<void> visit(const kpu_upload_options &op) noexcept override;
result<void> visit(const copy_options &op) noexcept override;
private:
result<gsl::span<gsl::byte>> memory_at(const memory_range &mrange) noexcept;
private:
gsl::span<const gsl::byte> text_;
};
END_NS_NNCASE_RT_MODULE

@@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "runtime_module.h"
#include "runtime_function.h"
#include <nncase/runtime/host_runtime_tensor.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/k210/error.h>
@@ -27,7 +28,7 @@ using namespace nncase::runtime;
using namespace nncase::runtime::detail;
using namespace nncase::runtime::k210;
result<void> k210_runtime_module::initialize_core(runtime_module_init_context &context) noexcept
result<void> k210_runtime_module::initialize_before_functions(runtime_module_init_context &context) noexcept
{
#ifndef NNCASE_SIMULATOR
kpu->interrupt_clear.reg = 7;
@@ -56,127 +57,33 @@ result<void> k210_runtime_module::initialize_core(runtime_module_init_context &c
return ok();
}
result<runtime_tensor> k210_runtime_module::allocate_input_tensor(size_t index) noexcept
gsl::span<gsl::byte> k210_runtime_module::data() const noexcept
{
return hrt::create(input_desc(index).datatype, input_shape(index), hrt::pool_shared);
return { data_.get(), mempool(mem_data).size };
}
result<runtime_tensor> k210_runtime_module::allocate_output_tensor(size_t index) noexcept
gsl::span<gsl::byte> k210_runtime_module::kpu_ram() noexcept
{
return hrt::create(output_desc(index).datatype, output_shape(index), hrt::pool_shared);
}
result<void> k210_runtime_module::validate_input_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
{
if (tensor.is_host()
&& hrt::memory_pool(tensor).unwrap() == hrt::pool_shared)
return ok();
return err(std::errc::invalid_argument);
}
result<void> k210_runtime_module::validate_output_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
{
if (tensor.is_host())
return ok();
return err(std::errc::invalid_argument);
}
result<void> k210_runtime_module::run_core() noexcept
{
for (size_t i = 0; i < inputs_size(); i++)
{
try_var(input, device_input_tensor(i));
try_(hrt::sync(input, hrt::sync_write_back));
}
#ifndef NNCASE_SIMULATOR
auto dma_ch = interp().options().get<uint32_t>("dma_ch");
if (dma_ch.is_ok())
{
dma_ch_ = dma_ch.unwrap();
}
else
{
printf("[WARN] KPU DMA channel not set, default to DMAC_CHANNEL5.\n");
dma_ch_ = 5;
}
#endif
try_(visit(text_));
return ok();
}
result<gsl::span<gsl::byte>> k210_runtime_module::memory_at(const memory_range &mrange) noexcept
{
#define ID_NOT_FOUND ((size_t)-1)
gsl::byte *base;
switch (mrange.memory_location)
{
case mem_input:
{
size_t id = ID_NOT_FOUND;
for (size_t i = 0; i < inputs_size(); i++)
{
if (mrange.start == input_desc(i).start)
{
id = i;
break;
}
}
if (id != ID_NOT_FOUND)
{
try_var(tensor, device_input_tensor(id));
base = reinterpret_cast<gsl::byte *>(static_cast<host_runtime_tensor_impl *>(tensor.impl())->memory_block().virtual_address - mrange.start);
}
else
{
return err(std::errc::invalid_argument);
}
break;
}
case mem_output:
{
size_t id = ID_NOT_FOUND;
for (size_t i = 0; i < outputs_size(); i++)
{
if (mrange.start == output_desc(i).start)
{
id = i;
break;
}
}
if (id != ID_NOT_FOUND)
{
try_var(tensor, device_output_tensor(id));
try_var(tensor_map, hrt::map(tensor, hrt::map_read_write));
base = tensor_map.buffer().data() - mrange.start;
}
else
{
return err(std::errc::invalid_argument);
}
break;
}
case mem_rdata:
base = const_cast<gsl::byte *>(rdata_.data());
break;
case mem_data:
base = data_.get();
break;
case mem_kpu:
#ifdef NNCASE_SIMULATOR
base = kpu_ram_.data();
#else
base = reinterpret_cast<gsl::byte *>(AI_IO_BASE_ADDR);
#endif
break;
default:
return err(nncase_errc::invalid_memory_location);
}
return { base, KPU_RAM_SIZE };
}
return ok(gsl::make_span(base + mrange.start, mrange.size));
gsl::span<const gsl::byte> k210_runtime_module::rdata() const noexcept
{
return rdata_;
}
result<std::unique_ptr<runtime_function>> k210_runtime_module::create_function() noexcept
{
std::unique_ptr<runtime_function> mod(new (std::nothrow) k210_runtime_function(*this));
if (mod)
return ok(std::move(mod));
return err(std::errc::not_enough_memory);
}
result<std::unique_ptr<runtime_module>> k210::create_k210_runtime_module()
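
memory_at() resolves a memory_range {location, start, size} to a base pointer chosen by location, then returns base + start. A self-contained sketch of the same resolution pattern, with plain enums and byte vectors standing in for the runtime types (the tensor-backed mem_input/mem_output branches are omitted):

#include <cstddef>
#include <cstdint>
#include <optional>
#include <span>
#include <vector>

enum memory_location_t : uint8_t { mem_input, mem_output, mem_rdata, mem_data, mem_kpu };

struct memory_range
{
    memory_location_t memory_location;
    uint32_t start;
    uint32_t size;
};

std::optional<std::span<std::byte>> memory_at(const memory_range &mrange,
    std::span<std::byte> data_pool, std::span<std::byte> kpu_ram)
{
    std::byte *base = nullptr;
    switch (mrange.memory_location)
    {
    case mem_data:
        base = data_pool.data();
        break;
    case mem_kpu:
        base = kpu_ram.data();
        break;
    default: // mem_input/mem_output resolve through device tensors in the real code
        return std::nullopt;
    }
    return std::span<std::byte>(base + mrange.start, mrange.size);
}

int main()
{
    std::vector<std::byte> data(1024), kpu(2 * 1024 * 1024);
    auto span = memory_at({ mem_data, 16, 64 }, data, kpu);
    return span ? 0 : 1;
}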

View File

@ -13,31 +13,28 @@
* limitations under the License.
*/
#pragma once
#include <nncase/runtime/k210/op_reader.h>
#include <nncase/runtime/k210/runtime_module.h>
#include <nncase/runtime/k210/runtime_types.h>
BEGIN_NS_NNCASE_RT_K210
BEGIN_NS_NNCASE_RT_MODULE(k210)
class k210_runtime_module : public runtime_module, private op_visitor
class k210_runtime_module : public runtime_module
{
public:
gsl::span<gsl::byte> data() const noexcept;
gsl::span<const gsl::byte> rdata() const noexcept;
gsl::span<gsl::byte> kpu_ram() noexcept;
#if !NNCASE_SIMULATOR
uint32_t dma_ch() const noexcept
{
return dma_ch_;
}
#endif
protected:
result<void> initialize_core(runtime_module_init_context &context) noexcept override;
result<runtime_tensor> allocate_input_tensor(size_t index) noexcept override;
result<runtime_tensor> allocate_output_tensor(size_t index) noexcept override;
result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept override;
result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept override;
result<void> run_core() noexcept override;
using op_visitor::visit;
result<void> visit(const kpu_conv2d_options &op) noexcept override;
result<void> visit(const kpu_download_options &op) noexcept override;
result<void> visit(const kpu_upload_options &op) noexcept override;
result<void> visit(const copy_options &op) noexcept override;
private:
result<gsl::span<gsl::byte>> memory_at(const memory_range &mrange) noexcept;
result<void> initialize_before_functions(runtime_module_init_context &context) noexcept override;
result<std::unique_ptr<runtime_function>> create_function() noexcept override;
private:
std::unique_ptr<gsl::byte[]> data_;
@ -50,4 +47,4 @@ private:
#endif
};
END_NS_NNCASE_RT_K210
END_NS_NNCASE_RT_MODULE

View File

@ -22,6 +22,7 @@ NNCASE_INLINE_VAR constexpr char SHADER_SECTION_NAME[] = ".shader";
NNCASE_INLINE_VAR constexpr char DESCRIPTORS_SECTION_NAME[] = ".descriptors";
NNCASE_INLINE_VAR constexpr module_type_t vulkan_module_type = to_module_type("vulkan");
NNCASE_INLINE_VAR constexpr uint32_t vulkan_module_version = 1;
NNCASE_MODULES_VULKAN_API result<std::unique_ptr<runtime_module>> create_vulkan_runtime_module();

View File

@ -39,6 +39,11 @@ module_type_t vulkan_module_builder::module_type() const noexcept
return vulkan_module_type;
}
uint32_t vulkan_module_builder::module_version() const noexcept
{
return vulkan_module_version;
}
section_writer &vulkan_module_builder::text_writer()
{
return writer(".text");
@ -104,17 +109,27 @@ void vulkan_module_builder::ldpipeline(ir::node &node, size_t shader_index, ldpi
tw.write(op);
}
void vulkan_module_builder::begin_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
{
set_current_entry_point(text_writer().position());
}
void vulkan_module_builder::end_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
{
set_current_function_text_end(text_writer().position());
}
void vulkan_module_builder::emit(ir::node &node)
{
#define DEFINE_OP(op) \
if (node.runtime_opcode() == ir::op::opcode()) \
return emit(static_cast<ir::op &>(node));
return emit(static_cast<op &>(node));
#include "ops.def"
#undef DEFINE_OP
module_builder::emit(node);
}
void vulkan_module_builder::end_emit()
void vulkan_module_builder::end_emit_module()
{
auto &sw = writer(DESCRIPTORS_SECTION_NAME);
sw.write(descriptor_sets_);

View File

@ -29,13 +29,16 @@ public:
vulkan_module_builder(std::string_view module_name, const module_builder_params &params);
module_type_t module_type() const noexcept override;
uint32_t module_version() const noexcept override;
protected:
section_writer &text_writer();
section_writer &shader_writer();
void begin_emit_function(const schedule::function_schedule_result &function) override;
void end_emit_function(const schedule::function_schedule_result &function) override;
void emit(ir::node &node) override;
void end_emit() override;
void end_emit_module() override;
private:
std::vector<uint32_t> compile_shader(ir::node &node, const std::string &template_name, const nlohmann::json &context);
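
begin_emit_function()/end_emit_function() bracket each function's emission and record the .text stream positions; the difference later becomes the function header's entrypoint and text_size. A small sketch of that bookkeeping, with a plain counter standing in for section_writer::position():

#include <cstddef>
#include <cstdio>
#include <map>
#include <string>

struct builder_t
{
    size_t text_pos = 0; // stand-in for text_writer().position()
    std::map<std::string, size_t> entry_points;
    std::map<std::string, size_t> text_ends;

    void begin_emit_function(const std::string &f) { entry_points[f] = text_pos; }
    void emit_some_ops(size_t bytes) { text_pos += bytes; } // hypothetical emission
    void end_emit_function(const std::string &f) { text_ends[f] = text_pos; }
};

int main()
{
    builder_t b;
    b.begin_emit_function("main");
    b.emit_some_ops(40);
    b.end_emit_function("main");
    std::printf("entrypoint=%zu text_size=%zu\n",
        b.entry_points["main"], b.text_ends["main"] - b.entry_points["main"]);
}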

View File

@ -1,6 +1,7 @@
cmake_minimum_required (VERSION 3.13)
set(SRCS runtime_module.cpp
runtime_function.cpp
vulkan_error.cpp
op_reader.cpp
ops/ldbuf.cpp

View File

@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include "../vulkan_error.h"
#include <nncase/runtime/dbg.h>
#include <nncase/runtime/error.h>
@ -21,7 +21,7 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::vulkan;
result<void> vulkan_runtime_module::visit(const barrier_op_t &op) noexcept
result<void> vulkan_runtime_function::visit(const barrier_op_t &op) noexcept
{
CHECK_WITH_ERR(op.memory_barriers == 0, std::errc::not_supported);

View File

@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include "../vulkan_error.h"
#include <nncase/runtime/error.h>
@ -20,7 +20,7 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::vulkan;
result<void> vulkan_runtime_module::visit(const copybuf_op_t &op) noexcept
result<void> vulkan_runtime_function::visit(const copybuf_op_t &op) noexcept
{
try_var(output, pop_buffer_ref());
try_var(input, pop_buffer_ref());

View File

@ -12,13 +12,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::vulkan;
result<void> vulkan_runtime_module::visit(const dispatch_op_t &op) noexcept
result<void> vulkan_runtime_function::visit(const dispatch_op_t &op) noexcept
{
cmd_buffer_.dispatch(op.x, op.y, op.z);
return ok();

View File

@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include "../vulkan_error.h"
#include <nncase/runtime/error.h>
@ -20,24 +20,24 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::vulkan;
result<void> vulkan_runtime_module::visit(const ldbuf_op_t &op) noexcept
result<void> vulkan_runtime_function::visit(const ldbuf_op_t &op) noexcept
{
vk::Buffer *dev_buf;
vk::Buffer dev_buf;
switch (op.memory.memory_location)
{
case mem_input:
dev_buf = &input_buffer_;
dev_buf = input_buffer_;
break;
case mem_output:
dev_buf = &output_buffer_;
dev_buf = output_buffer_;
break;
case mem_data:
dev_buf = &data_buffer_;
dev_buf = module().data();
break;
default:
return err(nncase_errc::invalid_memory_location);
}
buffer_refs_.emplace_back(buffer_ref { *dev_buf, op.memory.start, op.memory.size });
buffer_refs_.emplace_back(buffer_ref { dev_buf, op.memory.start, op.memory.size });
return ok();
}
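
The ldbuf/copybuf ops above follow a stack protocol: load ops push buffer references and consuming ops pop them, output first because it was pushed last. A standalone sketch of that protocol, with an int standing in for vk::Buffer:

#include <cstddef>
#include <cstdio>
#include <vector>

struct buffer_ref
{
    int buffer; // stand-in for vk::Buffer
    size_t start;
    size_t size;
};

struct op_state
{
    std::vector<buffer_ref> buffer_refs;

    void ldbuf(int buf, size_t start, size_t size)
    {
        buffer_refs.push_back({ buf, start, size });
    }

    bool copybuf()
    {
        if (buffer_refs.size() < 2)
            return false; // result_out_of_range in the real code
        auto output = buffer_refs.back(); buffer_refs.pop_back();
        auto input = buffer_refs.back(); buffer_refs.pop_back();
        std::printf("copy %zu bytes: buf%d+%zu -> buf%d+%zu\n",
            input.size, input.buffer, input.start, output.buffer, output.start);
        return true;
    }
};

int main()
{
    op_state s;
    s.ldbuf(0, 0, 128); // input
    s.ldbuf(1, 0, 128); // output (pushed last, popped first)
    return s.copybuf() ? 0 : 1;
}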

View File

@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include "../vulkan_error.h"
#include <nncase/runtime/error.h>
@ -20,26 +20,26 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::vulkan;
result<void> vulkan_runtime_module::visit(const ldbufbarrier_op_t &op) noexcept
result<void> vulkan_runtime_function::visit(const ldbufbarrier_op_t &op) noexcept
{
vk::Buffer *dev_buf;
vk::Buffer dev_buf;
switch (op.memory.memory_location)
{
case mem_input:
dev_buf = &input_buffer_;
dev_buf = input_buffer_;
break;
case mem_output:
dev_buf = &output_buffer_;
dev_buf = output_buffer_;
break;
case mem_data:
dev_buf = &data_buffer_;
dev_buf = module().data();
break;
default:
return err(nncase_errc::invalid_memory_location);
}
buffer_barriers_.emplace_back(vk::BufferMemoryBarrier((vk::AccessFlagBits)op.src_access_mask,
(vk::AccessFlagBits)op.dest_access_mask, compute_queue_index_, compute_queue_index_, *dev_buf,
(vk::DeviceSize)op.memory.start, (vk::DeviceSize)op.memory.size));
(vk::AccessFlagBits)op.dest_access_mask, module().compute_queue_index(), module().compute_queue_index(),
dev_buf, (vk::DeviceSize)op.memory.start, (vk::DeviceSize)op.memory.size));
return ok();
}

View File

@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include "../vulkan_error.h"
#include <nncase/runtime/error.h>
@ -20,7 +20,7 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::vulkan;
result<void> vulkan_runtime_module::visit(const ldbufcopy_op_t &op) noexcept
result<void> vulkan_runtime_function::visit(const ldbufcopy_op_t &op) noexcept
{
buffer_copies_.emplace_back(vk::BufferCopy((vk::DeviceSize)op.src, (vk::DeviceSize)op.dest, (vk::DeviceSize)op.size));
return ok();

View File

@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include "../vulkan_error.h"
#include <vulkan/vulkan.h>
@ -20,11 +20,11 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::vulkan;
result<void> vulkan_runtime_module::visit(const ldpipeline_op_t &op) noexcept
result<void> vulkan_runtime_function::visit(const ldpipeline_op_t &op) noexcept
{
auto code = shader_.subspan(op.shader_start, op.shader_size).as_span<const uint32_t>();
auto code = module().shader().subspan(op.shader_start, op.shader_size).as_span<const uint32_t>();
vk::ShaderModuleCreateInfo shader_cinfo({}, op.shader_size, code.data());
try_var(shader, vk::to_result(device_.createShaderModule(shader_cinfo)));
try_var(shader, vk::to_result(module().device().createShaderModule(shader_cinfo)));
std::vector<vk::DescriptorSetLayoutBinding> layout_bindings((size_t)op.buffers);
for (int32_t i = (int32_t)op.buffers - 1; i >= 0; i--)
@ -37,20 +37,20 @@ result<void> vulkan_runtime_module::visit(const ldpipeline_op_t &op) noexcept
}
vk::DescriptorSetLayoutCreateInfo desc_layout_cinfo({}, layout_bindings);
try_var(desc_layout, vk::to_result(device_.createDescriptorSetLayout(desc_layout_cinfo)));
try_var(desc_layout, vk::to_result(module().device().createDescriptorSetLayout(desc_layout_cinfo)));
vk::PipelineLayoutCreateInfo ppl_layout_cinfo({}, desc_layout);
try_var(ppl_layout, vk::to_result(device_.createPipelineLayout(ppl_layout_cinfo)));
try_var(ppl_layout, vk::to_result(module().device().createPipelineLayout(ppl_layout_cinfo)));
if (op.shader_type != shader_type_t::compute)
return err(std::errc::not_supported);
vk::ComputePipelineCreateInfo comp_ppl_cinfo({}, { {}, vk::ShaderStageFlagBits::eCompute, shader, "main" }, ppl_layout);
try_var(pipeline, vk::to_result(device_.createComputePipeline({}, comp_ppl_cinfo)));
pipelines_owner_.emplace_back(pipeline);
try_var(pipeline, vk::to_result(module().device().createComputePipeline({}, comp_ppl_cinfo)));
try_(module().add_pipeline(pipeline));
vk::DescriptorSetAllocateInfo desc_alloc_info(buffer_desc_pool_, desc_layout);
try_var(desc_sets, vk::to_result(device_.allocateDescriptorSets(desc_alloc_info)));
vk::DescriptorSetAllocateInfo desc_alloc_info(module().buffer_desc_pool(), desc_layout);
try_var(desc_sets, vk::to_result(module().device().allocateDescriptorSets(desc_alloc_info)));
std::vector<vk::DescriptorBufferInfo> buffer_infos((size_t)op.buffers);
std::vector<vk::WriteDescriptorSet> write_descs(buffer_infos.size());
@ -72,7 +72,7 @@ result<void> vulkan_runtime_module::visit(const ldpipeline_op_t &op) noexcept
write_desc.setDstSet(desc_sets[0]);
}
device_.updateDescriptorSets(write_descs, {});
module().device().updateDescriptorSets(write_descs, {});
cmd_buffer_.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline);
cmd_buffer_.bindDescriptorSets(vk::PipelineBindPoint::eCompute, ppl_layout, 0, desc_sets, {});

View File

@ -0,0 +1,177 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime_function.h"
#include "vulkan_error.h"
#include <nncase/runtime/dbg.h>
#include <nncase/runtime/host_runtime_tensor.h>
#include <nncase/runtime/runtime_loader.h>
#include <nncase/runtime/runtime_op_utility.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::vulkan;
vulkan_runtime_function::~vulkan_runtime_function()
{
free_vulkan_resources();
}
vulkan_runtime_module &vulkan_runtime_function::module() const noexcept
{
return static_cast<vulkan_runtime_module &>(runtime_function::module());
}
result<void> vulkan_runtime_function::initialize_core(runtime_function_init_context &context) noexcept
{
input_pool_size_ = context.header().input_pool_size;
output_pool_size_ = context.header().output_pool_size;
text_ = context.module_init_context().section(".text").subspan(context.header().entrypoint, context.header().text_size);
try_(initialize_vulkan_device());
try_(initialize_vulkan_memory());
try_(initialize_vulkan_commands());
return ok();
}
result<runtime_tensor> vulkan_runtime_function::allocate_input_tensor(size_t index) noexcept
{
return host_runtime_tensor::create(input_desc(index).datatype, input_shape(index));
}
result<runtime_tensor> vulkan_runtime_function::allocate_output_tensor(size_t index) noexcept
{
return host_runtime_tensor::create(output_desc(index).datatype, output_shape(index));
}
result<void> vulkan_runtime_function::validate_input_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
{
if (tensor.is_host() && tensor.is_contiguous())
return ok();
return err(std::errc::invalid_argument);
}
result<void> vulkan_runtime_function::validate_output_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
{
if (tensor.is_host() && tensor.is_contiguous())
return ok();
return err(std::errc::invalid_argument);
}
result<void> vulkan_runtime_function::initialize_vulkan_device() noexcept
{
return ok();
}
result<void> vulkan_runtime_function::initialize_vulkan_memory() noexcept
{
if (input_pool_size_)
{
try_set(input_buffer_, module().allocate_vulkan_buffer(input_pool_size_));
try_set(input_mem_, module().allocate_vulkan_memory({ vk::MemoryPropertyFlagBits::eHostVisible, vk::MemoryPropertyFlagBits::eHostCached, {} }, input_buffer_));
try_(module().bind_vulkan_buffer(input_buffer_, input_mem_));
}
if (output_pool_size_)
{
try_set(output_buffer_, module().allocate_vulkan_buffer(output_pool_size_));
try_set(output_mem_, module().allocate_vulkan_memory({ vk::MemoryPropertyFlagBits::eHostVisible, vk::MemoryPropertyFlagBits::eHostCached, {} }, output_buffer_));
try_(module().bind_vulkan_buffer(output_buffer_, output_mem_));
}
return ok();
}
result<void> vulkan_runtime_function::initialize_vulkan_commands() noexcept
{
vk::CommandBufferAllocateInfo cmdb_cinfo(module().command_pool(), vk::CommandBufferLevel::ePrimary, 1);
try_var(cmdbs, vk::to_result(module().device().allocateCommandBuffers(cmdb_cinfo)));
cmd_buffer_ = cmdbs[0];
vk::CommandBufferBeginInfo cmdb_info;
try_(vk::to_result(cmd_buffer_.begin(cmdb_info)));
try_(visit(text_));
try_(vk::to_result(cmd_buffer_.end()));
return ok();
}
result<void> vulkan_runtime_function::preprocess_inputs() noexcept
{
try_var(dest, vk::to_result(module().device().mapMemory(input_mem_, 0, VK_WHOLE_SIZE, {})));
for (size_t i = 0; i < inputs_size(); i++)
{
try_var(src_tensor, device_input_tensor(i));
try_var(src_map, hrt::map(src_tensor, hrt::map_read));
auto &desc = input_desc(i);
memcpy((uint8_t *)dest + desc.start, src_map.buffer().data(), desc.size);
}
vk::MappedMemoryRange range(input_mem_, 0, VK_WHOLE_SIZE);
try_(vk::to_result(module().device().flushMappedMemoryRanges(range)));
module().device().unmapMemory(input_mem_);
return ok();
}
result<void> vulkan_runtime_function::invoke_core() noexcept
{
try_(preprocess_inputs());
vk::SubmitInfo si({}, {}, cmd_buffer_, {});
try_(vk::to_result(module().compute_queue().submit(si)));
try_(vk::to_result(module().compute_queue().waitIdle()));
try_(postprocess_outputs());
return ok();
}
result<void> vulkan_runtime_function::postprocess_outputs() noexcept
{
try_var(src, vk::to_result(module().device().mapMemory(output_mem_, 0, VK_WHOLE_SIZE, {})));
vk::MappedMemoryRange range(output_mem_, 0, VK_WHOLE_SIZE);
try_(vk::to_result(module().device().invalidateMappedMemoryRanges(range)));
for (size_t i = 0; i < outputs_size(); i++)
{
try_var(dest_tensor, device_output_tensor(i));
try_var(dest_map, hrt::map(dest_tensor, hrt::map_write));
auto &desc = output_desc(i);
memcpy(dest_map.buffer().data(), (const uint8_t *)src + desc.start, desc.size);
}
module().device().unmapMemory(output_mem_);
return ok();
}
result<vulkan_runtime_function::buffer_ref> vulkan_runtime_function::pop_buffer_ref() noexcept
{
if (buffer_refs_.empty())
return err(std::errc::result_out_of_range);
auto buffer_ref = std::move(buffer_refs_.back());
buffer_refs_.pop_back();
return ok(std::move(buffer_ref));
}
void vulkan_runtime_function::free_vulkan_resources() noexcept
{
if (auto device = module().device())
{
if (module().command_pool())
device.freeCommandBuffers(module().command_pool(), cmd_buffer_);
device.destroyBuffer(input_buffer_);
device.destroyBuffer(output_buffer_);
device.freeMemory(input_mem_);
device.freeMemory(output_mem_);
}
}
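
invoke_core() stages host tensors into mapped input memory, replays the command buffer recorded once in initialize_vulkan_commands(), waits for the queue, then copies the mapped output memory back. A sketch of that staging flow under simplifying assumptions: byte vectors in place of mapped vk::DeviceMemory, and a stub in place of the submitted commands:

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

struct fake_device
{
    std::vector<std::byte> input_mem;  // stands in for the mapped input_mem_
    std::vector<std::byte> output_mem; // stands in for the mapped output_mem_

    void submit_and_wait()
    {
        // the real code submits the prerecorded command buffer and waits idle;
        // this stub just copies input to output
        std::memcpy(output_mem.data(), input_mem.data(),
            std::min(input_mem.size(), output_mem.size()));
    }
};

int main()
{
    fake_device dev { std::vector<std::byte>(64), std::vector<std::byte>(64) };
    std::vector<std::byte> host_in(64, std::byte { 0x5a }), host_out(64);

    // preprocess_inputs(): copy host tensors into the mapped input pool
    std::memcpy(dev.input_mem.data(), host_in.data(), host_in.size());
    dev.submit_and_wait(); // submit + waitIdle
    // postprocess_outputs(): copy the mapped output pool back to host tensors
    std::memcpy(host_out.data(), dev.output_mem.data(), host_out.size());
    return host_out[0] == std::byte { 0x5a } ? 0 : 1;
}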

View File

@ -0,0 +1,81 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "runtime_module.h"
#include <nncase/kernels/kernel_context.h>
#include <nncase/runtime/runtime_function.h>
#include <nncase/runtime/vulkan/op_reader.h>
#include <vulkan/vulkan.hpp>
BEGIN_NS_NNCASE_RT_MODULE(vulkan)
class vulkan_runtime_function : public runtime_function, private op_visitor
{
struct buffer_ref
{
vk::Buffer buffer;
size_t start;
size_t size;
};
public:
using runtime_function::runtime_function;
virtual ~vulkan_runtime_function();
vulkan_runtime_module &module() const noexcept;
protected:
result<void> initialize_core(runtime_function_init_context &context) noexcept override;
result<runtime_tensor> allocate_input_tensor(size_t index) noexcept override;
result<runtime_tensor> allocate_output_tensor(size_t index) noexcept override;
result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept override;
result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept override;
result<void> invoke_core() noexcept override;
using op_visitor::visit;
result<void> visit(const ldbuf_op_t &op) noexcept override;
result<void> visit(const ldbufbarrier_op_t &op) noexcept override;
result<void> visit(const ldbufcopy_op_t &op) noexcept override;
result<void> visit(const copybuf_op_t &op) noexcept override;
result<void> visit(const ldpipeline_op_t &op) noexcept override;
result<void> visit(const dispatch_op_t &op) noexcept override;
result<void> visit(const barrier_op_t &op) noexcept override;
private:
result<void> initialize_vulkan_device() noexcept;
result<void> initialize_vulkan_memory() noexcept;
result<void> initialize_vulkan_commands() noexcept;
result<buffer_ref> pop_buffer_ref() noexcept;
result<void> preprocess_inputs() noexcept;
result<void> postprocess_outputs() noexcept;
void free_vulkan_resources() noexcept;
private:
uint32_t input_pool_size_;
uint32_t output_pool_size_;
gsl::span<const gsl::byte> text_;
vk::Buffer input_buffer_;
vk::Buffer output_buffer_;
vk::DeviceMemory input_mem_;
vk::DeviceMemory output_mem_;
std::vector<buffer_ref> buffer_refs_;
vk::CommandBuffer cmd_buffer_;
std::vector<vk::BufferMemoryBarrier> buffer_barriers_;
std::vector<vk::BufferCopy> buffer_copies_;
};
END_NS_NNCASE_RT_MODULE

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include "runtime_module.h"
#include "runtime_function.h"
#include "vulkan_error.h"
#include <nncase/runtime/dbg.h>
#include <nncase/runtime/host_runtime_tensor.h>
@ -28,59 +29,26 @@ vulkan_runtime_module::~vulkan_runtime_module()
free_vulkan_resources();
}
result<void> vulkan_runtime_module::initialize_core(runtime_module_init_context &context) noexcept
result<void> vulkan_runtime_module::initialize_before_functions(runtime_module_init_context &context) noexcept
{
assert(context.is_section_pinned());
auto descs = context.section(DESCRIPTORS_SECTION_NAME).as_span<const uint32_t>();
descriptor_sets_ = descs[0];
descriptors_ = descs[1];
rdata_ = context.section(".rdata");
text_ = context.section(".text");
shader_ = context.section(".shader");
try_(initialize_vulkan());
// TODO: Load rdata
//rdata_ = context.section(".rdata");
return ok();
}
result<runtime_tensor> vulkan_runtime_module::allocate_input_tensor(size_t index) noexcept
{
return host_runtime_tensor::create(input_desc(index).datatype, input_shape(index));
}
result<runtime_tensor> vulkan_runtime_module::allocate_output_tensor(size_t index) noexcept
{
return host_runtime_tensor::create(output_desc(index).datatype, output_shape(index));
}
result<void> vulkan_runtime_module::validate_input_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
{
if (tensor.is_host() && tensor.is_contiguous())
return ok();
return err(std::errc::invalid_argument);
}
result<void> vulkan_runtime_module::validate_output_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
{
if (tensor.is_host() && tensor.is_contiguous())
return ok();
return err(std::errc::invalid_argument);
}
result<void> vulkan_runtime_module::initialize_vulkan() noexcept
{
try_(initialize_vulkan_instance());
try_(initialize_vulkan_device());
try_(initialize_vulkan_memory());
try_(initialize_vulkan_commands());
return ok();
}
result<void> vulkan_runtime_module::initialize_vulkan_commands() noexcept
{
vk::CommandBufferBeginInfo cmdb_info;
try_(vk::to_result(cmd_buffer_.begin(cmdb_info)));
try_(visit(text_));
try_(vk::to_result(cmd_buffer_.end()));
return ok();
}
@ -110,30 +78,11 @@ result<void> vulkan_runtime_module::initialize_vulkan_device() noexcept
vk::CommandPoolCreateInfo cmdp_cinfo({}, compute_queue_index_);
try_set(cmd_pool_, vk::to_result(device_.createCommandPool(cmdp_cinfo)));
vk::CommandBufferAllocateInfo cmdb_cinfo(cmd_pool_, vk::CommandBufferLevel::ePrimary, 1);
try_var(cmdbs, vk::to_result(device_.allocateCommandBuffers(cmdb_cinfo)));
cmd_buffer_ = cmdbs[0];
return ok();
}
result<void> vulkan_runtime_module::initialize_vulkan_memory() noexcept
{
auto input_mem = mempool(mem_input);
if (input_mem.size)
{
try_set(input_buffer_, allocate_vulkan_buffer(input_mem.size));
try_set(input_mem_, allocate_vulkan_memory({ vk::MemoryPropertyFlagBits::eHostVisible, vk::MemoryPropertyFlagBits::eHostCached, {} }, input_buffer_));
try_(bind_vulkan_buffer(input_buffer_, input_mem_));
}
auto output_mem = mempool(mem_output);
if (output_mem.size)
{
try_set(output_buffer_, allocate_vulkan_buffer(output_mem.size));
try_set(output_mem_, allocate_vulkan_memory({ vk::MemoryPropertyFlagBits::eHostVisible, vk::MemoryPropertyFlagBits::eHostCached, {} }, output_buffer_));
try_(bind_vulkan_buffer(output_buffer_, output_mem_));
}
auto data_mem = mempool(mem_data);
if (data_mem.size)
{
@ -165,6 +114,19 @@ result<void> vulkan_runtime_module::bind_vulkan_buffer(vk::Buffer buffer, vk::De
return vk::to_result(device_.bindBufferMemory(buffer, memory, 0));
}
result<void> vulkan_runtime_module::add_pipeline(vk::Pipeline pipeline) noexcept
{
try
{
pipelines_owner_.emplace_back(pipeline);
return ok();
}
catch (const std::bad_alloc &)
{
return err(std::errc::not_enough_memory);
}
}
result<vk::PhysicalDevice> vulkan_runtime_module::select_physical_device() noexcept
{
vk::PhysicalDevice *integrated = nullptr;
@ -261,80 +223,29 @@ result<size_t> vulkan_runtime_module::select_memory_type(const vk::PhysicalDeviceMemoryProperties &properties, const select_options<vk::MemoryPropertyFlagBits> &options, size_t required_size) noexcept
return err(std::errc::not_enough_memory);
}
result<void> vulkan_runtime_module::run_core() noexcept
{
try_(preprocess_inputs());
vk::SubmitInfo si({}, {}, cmd_buffer_, {});
try_(vk::to_result(compute_queue_.submit(si)));
try_(vk::to_result(compute_queue_.waitIdle()));
try_(postprocess_outputs());
return ok();
}
result<void> vulkan_runtime_module::preprocess_inputs() noexcept
{
try_var(dest, vk::to_result(device_.mapMemory(input_mem_, 0, VK_WHOLE_SIZE, {})));
for (size_t i = 0; i < inputs_size(); i++)
{
try_var(src_tensor, device_input_tensor(i));
try_var(src_map, hrt::map(src_tensor, hrt::map_read));
auto &desc = input_desc(i);
memcpy((uint8_t *)dest + desc.start, src_map.buffer().data(), desc.size);
}
vk::MappedMemoryRange range(input_mem_, 0, VK_WHOLE_SIZE);
try_(vk::to_result(device_.flushMappedMemoryRanges(range)));
device_.unmapMemory(input_mem_);
return ok();
}
result<void> vulkan_runtime_module::postprocess_outputs() noexcept
{
try_var(src, vk::to_result(device_.mapMemory(output_mem_, 0, VK_WHOLE_SIZE, {})));
vk::MappedMemoryRange range(output_mem_, 0, VK_WHOLE_SIZE);
try_(vk::to_result(device_.invalidateMappedMemoryRanges(range)));
for (size_t i = 0; i < outputs_size(); i++)
{
try_var(dest_tensor, device_output_tensor(i));
try_var(dest_map, hrt::map(dest_tensor, hrt::map_write));
auto &desc = output_desc(i);
memcpy(dest_map.buffer().data(), (const uint8_t *)src + desc.start, desc.size);
}
device_.unmapMemory(output_mem_);
return ok();
}
result<vulkan_runtime_module::buffer_ref> vulkan_runtime_module::pop_buffer_ref() noexcept
{
if (buffer_refs_.empty())
return err(std::errc::result_out_of_range);
auto buffer_ref = std::move(buffer_refs_.back());
buffer_refs_.pop_back();
return ok(std::move(buffer_ref));
}
void vulkan_runtime_module::free_vulkan_resources() noexcept
{
device_.freeCommandBuffers(cmd_pool_, cmd_buffer_);
for (auto &func : functions())
func.reset();
device_.destroyCommandPool(cmd_pool_);
device_.destroyDescriptorPool(buffer_desc_pool_);
for (auto p : pipelines_owner_)
device_.destroyPipeline(p);
device_.destroyBuffer(input_buffer_);
device_.destroyBuffer(output_buffer_);
device_.destroyBuffer(data_buffer_);
device_.freeMemory(input_mem_);
device_.freeMemory(output_mem_);
device_.freeMemory(data_mem_);
device_.destroy({});
instance_.destroy({});
}
result<std::unique_ptr<runtime_function>> vulkan_runtime_module::create_function() noexcept
{
std::unique_ptr<runtime_function> mod(new (std::nothrow) vulkan_runtime_function(*this));
if (mod)
return ok(std::move(mod));
return err(std::errc::not_enough_memory);
}
result<std::unique_ptr<runtime_module>> vulkan::create_vulkan_runtime_module()
{
std::unique_ptr<runtime_module> mod(new (std::nothrow) vulkan_runtime_module());
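
add_pipeline() above shows the error-handling idiom used throughout the noexcept runtime: a container operation that may throw is wrapped and converted into an error code. A minimal sketch of the same conversion, using std::errc directly instead of the real result<T>:

#include <new>
#include <system_error>
#include <vector>

// convert a throwing container operation into an error code, as add_pipeline()
// does for the real result<void>
std::errc add_owned(std::vector<int> &owner, int pipeline) noexcept
{
    try
    {
        owner.emplace_back(pipeline); // vector growth may throw std::bad_alloc
        return {};
    }
    catch (const std::bad_alloc &)
    {
        return std::errc::not_enough_memory;
    }
}

int main()
{
    std::vector<int> pipelines; // module-level owner, outliving any function
    return add_owned(pipelines, 42) == std::errc {} ? 0 : 1;
}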

View File

@ -20,67 +20,54 @@
BEGIN_NS_NNCASE_RT_MODULE(vulkan)
class vulkan_runtime_module : public runtime_module, private op_visitor
template <class T>
struct select_options
{
T required;
T preferred;
T not_preferred;
};
struct buffer_ref
{
vk::Buffer buffer;
size_t start;
size_t size;
};
class vulkan_runtime_module : public runtime_module
{
public:
virtual ~vulkan_runtime_module();
protected:
result<void> initialize_core(runtime_module_init_context &context) noexcept override;
result<runtime_tensor> allocate_input_tensor(size_t index) noexcept override;
result<runtime_tensor> allocate_output_tensor(size_t index) noexcept override;
result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept override;
result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept override;
result<void> run_core() noexcept override;
vk::Buffer data() const noexcept { return data_buffer_; }
vk::Buffer rdata() const noexcept { return {}; }
gsl::span<const gsl::byte> shader() const noexcept { return shader_; }
using op_visitor::visit;
result<void> visit(const ldbuf_op_t &op) noexcept override;
result<void> visit(const ldbufbarrier_op_t &op) noexcept override;
result<void> visit(const ldbufcopy_op_t &op) noexcept override;
result<void> visit(const copybuf_op_t &op) noexcept override;
result<void> visit(const ldpipeline_op_t &op) noexcept override;
result<void> visit(const dispatch_op_t &op) noexcept override;
result<void> visit(const barrier_op_t &op) noexcept override;
vk::Device device() const noexcept { return device_; }
vk::CommandPool command_pool() const noexcept { return cmd_pool_; }
uint32_t compute_queue_index() const noexcept { return compute_queue_index_; }
vk::Queue compute_queue() const noexcept { return compute_queue_; }
vk::DescriptorPool buffer_desc_pool() const noexcept { return buffer_desc_pool_; }
result<vk::DeviceMemory> allocate_vulkan_memory(const select_options<vk::MemoryPropertyFlagBits> &options, vk::Buffer buffer) noexcept;
result<vk::Buffer> allocate_vulkan_buffer(size_t required_size) noexcept;
result<void> bind_vulkan_buffer(vk::Buffer buffer, vk::DeviceMemory memory) noexcept;
result<void> add_pipeline(vk::Pipeline pipeline) noexcept;
protected:
result<void> initialize_before_functions(runtime_module_init_context &context) noexcept override;
result<std::unique_ptr<runtime_function>> create_function() noexcept override;
private:
result<void> initialize_vulkan() noexcept;
result<void> initialize_vulkan_instance() noexcept;
result<void> initialize_vulkan_device() noexcept;
result<void> initialize_vulkan_memory() noexcept;
result<void> initialize_vulkan_commands() noexcept;
result<vk::PhysicalDevice> select_physical_device() noexcept;
result<uint32_t> select_queue_family(const std::vector<vk::QueueFamilyProperties> &families, const select_options<vk::QueueFlagBits> options) noexcept;
result<size_t> select_memory_type(const vk::PhysicalDeviceMemoryProperties &properties, const select_options<vk::MemoryPropertyFlagBits> &options, size_t required_size) noexcept;
result<vk::DeviceMemory> allocate_vulkan_memory(const select_options<vk::MemoryPropertyFlagBits> &options, vk::Buffer buffer) noexcept;
result<vk::Buffer> allocate_vulkan_buffer(size_t required_size) noexcept;
result<void> bind_vulkan_buffer(vk::Buffer buffer, vk::DeviceMemory memory) noexcept;
result<buffer_ref> pop_buffer_ref() noexcept;
result<void> preprocess_inputs() noexcept;
result<void> postprocess_outputs() noexcept;
void free_vulkan_resources() noexcept;
private:
uint32_t descriptors_;
uint32_t descriptor_sets_;
gsl::span<const gsl::byte> rdata_;
gsl::span<const gsl::byte> text_;
gsl::span<const gsl::byte> shader_;
vk::Instance instance_;
@ -88,19 +75,11 @@ private:
vk::Device device_;
uint32_t compute_queue_index_;
vk::Queue compute_queue_;
vk::Buffer input_buffer_;
vk::Buffer output_buffer_;
vk::Buffer data_buffer_;
vk::DeviceMemory input_mem_;
vk::DeviceMemory output_mem_;
vk::DeviceMemory data_mem_;
std::vector<buffer_ref> buffer_refs_;
std::vector<vk::Pipeline> pipelines_owner_;
vk::DescriptorPool buffer_desc_pool_;
vk::CommandPool cmd_pool_;
vk::CommandBuffer cmd_buffer_;
std::vector<vk::BufferMemoryBarrier> buffer_barriers_;
std::vector<vk::BufferCopy> buffer_copies_;
};
END_NS_NNCASE_RT_MODULE

View File

@ -22,6 +22,7 @@
#include <nncase/ir/graph.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
#include <nncase/schedule/scheduler.h>
#include <nncase/version.h>
#include <pybind11/iostream.h>
#include <pybind11/numpy.h>
@ -74,7 +75,7 @@ void LaunchDebugger()
}
#endif
schedule::schedule_result schedule(target &target, ir::graph &graph)
schedule::model_schedule_result schedule(target &target, ir::graph &graph)
{
schedule::scheduler sched(target, graph, graph.outputs());
return sched.schedule(true);
@ -115,7 +116,7 @@ public:
private:
ir::graph &graph_;
schedule::schedule_result schedule_result_;
schedule::model_schedule_result schedule_result_;
ir::evaluator evaluator_;
};
}

View File

@ -13,6 +13,7 @@
* limitations under the License.
*/
#include <nncase/codegen/model_builder.h>
#include <nncase/ir/op_utils.h>
#include <nncase/runtime/model.h>
#include <nncase/targets/target.h>
@ -22,7 +23,7 @@ using namespace nncase::ir;
using namespace nncase::schedule;
using namespace nncase::runtime;
model_builder::model_builder(target &target, const schedule::schedule_result &sched)
model_builder::model_builder(target &target, const schedule::model_schedule_result &sched)
: target_(target), sched_(sched), dump_asm_(false)
{
}
@ -41,6 +42,7 @@ build_model_result model_builder::build(std::ostream &output)
model_header header {};
header.identifier = MODEL_IDENTIFIER;
header.version = MODEL_VERSION;
header.header_size = sizeof(header);
header.flags = 0;
header.alignment = 8;
header.modules = (uint32_t)sched_.modules.size();
@ -49,19 +51,27 @@ build_model_result model_builder::build(std::ostream &output)
auto header_pos = writer.position();
writer.skip(sizeof(header));
uint32_t main_module_id = 0;
for (auto &graph : sched_.graph_orders)
for (auto &mod_sched : sched_.modules)
{
auto &mod = sched_.modules.at(graph);
module_builder_params params { sched_, mod };
auto builder = target_.create_module_builder(graph->module_type(), graph->name(), params);
builder->config_dump(dump_dir_ / graph->escaped_name(), dump_asm_);
module_builder_params params { sched_, mod_sched };
auto builder = target_.create_module_builder(mod_sched.type, mod_sched.type.data(), params);
builder->config_dump(dump_dir_ / mod_sched.type.data(), dump_asm_);
builder->build(writer);
header.alignment = std::max(header.alignment, builder->alignment());
}
if (graph == sched_.main_module)
header.main_module = main_module_id;
main_module_id++;
// Entry point
for (size_t i = 0; i < sched_.modules.size(); i++)
{
auto &mod_sched = sched_.modules[i];
for (size_t j = 0; j < mod_sched.functions.size(); j++)
{
if (sched_.entry_function == &mod_sched.functions[j])
{
header.entry_module = (uint32_t)i;
header.entry_function = (uint32_t)j;
}
}
}
auto end_pos = writer.position();
@ -78,11 +88,39 @@ build_model_result model_builder::build(std::ostream &output)
size_t model_builder::max_usage(memory_location_t location) const
{
size_t usage = 0;
for (auto &mod : sched_.modules)
if (location == mem_input)
{
auto it = mod.second.max_usages.find(location);
if (it != mod.second.max_usages.end())
usage += it->second;
// Only take the entry function's inputs into account
auto graph = sched_.entry_function->graph;
auto &entry_in_allocs = sched_.entry_function->module->allocations;
for (auto in : graph->inputs())
usage += entry_in_allocs.at(&in->output()).size;
}
else if (location == mem_output)
{
// Only take the entry function's outputs into account
auto graph = sched_.entry_function->graph;
auto &entry_out_allocs = sched_.entry_function->module->allocations;
for (auto out : graph->outputs())
usage += entry_out_allocs.at(out->input().connection()).size;
}
else if (location != mem_shared_data)
{
for (auto &mod : sched_.modules)
{
auto it = mod.max_usages.find(location);
if (it != mod.max_usages.end())
usage += it->second;
}
}
else
{
for (auto &mod : sched_.modules)
{
for (auto &shared : mod.shared_max_usages)
usage += shared.second;
}
}
return usage;

View File

@ -65,10 +65,10 @@ section_writer &module_builder::writer(std::string_view section_name)
return it->second.writer;
}
std::vector<nncase::ir::node *> module_builder::generate_runtime_ops()
std::vector<nncase::ir::node *> module_builder::generate_current_runtime_ops()
{
std::vector<nncase::ir::node *> runtime_ops;
for (auto &&node : params_.module_sched.compute_sequence)
for (auto &&node : current_function_->compute_sequence)
{
if (!non_runtime_opcodes.contains(node->runtime_opcode()))
runtime_ops.emplace_back(node);
@ -76,7 +76,8 @@ std::vector<nncase::ir::node *> module_builder::generate_runtime_ops()
if (dump_asm_)
{
std::ofstream file(dump_dir_ / "runtime_ops.txt");
std::filesystem::create_directories(dump_dir_ / current_function_->graph->escaped_name());
std::ofstream file(dump_dir_ / current_function_->graph->escaped_name() / "runtime_ops.txt");
for (auto node : runtime_ops)
file << "[" << node->runtime_opcode().name << "] "
<< node->name() << std::endl;
@ -104,15 +105,18 @@ void module_builder::write_constants()
if (it != params_.module_sched.max_usages.end())
{
auto constants = std::make_unique<std::byte[]>(it->second);
for (auto &&node : params_.module_sched.compute_sequence)
for (auto &func_sched : params_.module_sched.functions)
{
if (auto con = node_cast<constant>(*node))
for (auto &&node : func_sched.compute_sequence)
{
if (con->output().memory_location() == mem_rdata)
if (auto con = node_cast<constant>(*node))
{
auto &alloc = allocation(con->output());
auto data = con->data();
std::memcpy(constants.get() + alloc.start, data.data(), data.size_bytes());
if (con->output().memory_location() == mem_rdata)
{
auto &alloc = allocation(con->output());
auto data = con->data();
std::memcpy(constants.get() + alloc.start, data.data(), data.size_bytes());
}
}
}
}
@ -125,11 +129,23 @@ void module_builder::compile()
{
write_constants();
auto runtime_ops = generate_runtime_ops();
begin_emit();
for (auto node : runtime_ops)
emit(*node);
end_emit();
begin_emit_module();
for (auto &func_sched : params_.module_sched.functions)
{
current_function_ = &func_sched;
auto runtime_ops = generate_current_runtime_ops();
begin_emit_function(func_sched);
for (auto node : runtime_ops)
emit(*node);
end_emit_function(func_sched);
if (!entry_points_.contains(current_function_))
throw std::runtime_error("Entry point for " + func_sched.graph->name() + " is not set");
}
end_emit_module();
if (dump_asm_)
{
@ -146,15 +162,34 @@ void module_builder::merge_to_rdata_section(std::string_view from)
rdata_section_merges_.emplace(from, std::in_place);
}
size_t module_builder::module_id(ir::graph *graph)
function_call_id module_builder::function_id(ir::graph *graph)
{
auto &orders = params_.sched.graph_orders;
auto it = std::find(orders.begin(), orders.end(), graph);
if (it != orders.end())
return (size_t)std::distance(orders.begin(), it);
for (size_t i = 0; i < params_.model_sched.modules.size(); i++)
{
auto &mod_sched = params_.model_sched.modules[i];
auto func_it = mod_sched.functions_map.find(graph);
if (func_it != mod_sched.functions_map.end())
{
auto func_sched = func_it->second;
auto &orders = mod_sched.functions;
if (func_sched >= orders.data() && func_sched < orders.data() + orders.size())
return { i, (size_t)(func_sched - orders.data()) };
}
}
throw std::invalid_argument("Can't find graph " + graph->name() + " in modules");
}
void module_builder::set_current_entry_point(std::streampos pos)
{
entry_points_[current_function_] = pos;
}
void module_builder::set_current_function_text_end(std::streampos pos)
{
function_text_end_[current_function_] = pos;
}
std::unique_ptr<section_decompiler> module_builder::create_decompiler([[maybe_unused]] std::string_view section_name)
{
return nullptr;
@ -272,12 +307,95 @@ void module_builder::link()
void module_builder::write_binary(binary_writer &writer)
{
// Skip module header
auto header_pos = writer.position();
writer.skip(sizeof(module_header));
// mempools
for (auto &mem : params_.module_sched.max_usages)
{
mempool_desc desc {};
desc.location = mem.first;
desc.size = mem.second;
writer.write(desc);
}
// functions
for (auto &func_sched : params_.module_sched.functions)
write_function_binary(writer, func_sched);
// sections
for (auto &section : section_writer_)
{
section_header header {};
strncpy(header.name, section.first.c_str(), std::size(header.name) - 1);
auto merge_it = rdata_section_merges_.find(section.first);
if (merge_it == rdata_section_merges_.end())
{
header.flags = 0;
header.body_start = 0;
header.body_size = (uint32_t)section.second.body.size();
}
else
{
header.flags = SECTION_MERGED_INTO_RDATA;
header.body_start = merge_it->second.start;
header.body_size = merge_it->second.size;
}
// Skip section header
auto header_pos = writer.position();
writer.skip(sizeof(header));
if (merge_it == rdata_section_merges_.end())
{
header.body_start = (uint32_t)writer.align_position(alignment_);
// write content
writer.write_array(std::span<uint8_t const>(section.second.body));
}
// write section header
auto end_pos = writer.position();
writer.position(header_pos);
writer.write(header);
writer.position(end_pos);
}
writer.align_position(8);
auto end_pos = writer.position();
// header
module_header header {};
header.type = module_type();
header.version = module_version();
header.header_size = sizeof(header);
header.size = (uint32_t)(end_pos - header_pos);
header.mempools = (uint32_t)params_.module_sched.max_usages.size();
header.shared_mempools = (uint32_t)params_.module_sched.shared_max_usages.size();
header.functions = (uint32_t)params_.module_sched.functions.size();
header.sections = (uint32_t)section_writer_.size();
header.reserved0 = 0;
writer.position(header_pos);
writer.write(header);
writer.position(end_pos);
}
void module_builder::write_function_binary(binary_writer &writer, const schedule::function_schedule_result &function_sched)
{
auto write_shape = [&](const shape_t &shape) {
writer.write((uint32_t)shape.size());
for (auto dim : shape)
writer.write((uint32_t)dim);
};
std::vector<memory_range> inputs;
std::vector<shape_t> input_shapes;
std::vector<memory_range> outputs;
std::vector<shape_t> output_shapes;
for (auto &&node : params_.module_sched.compute_sequence)
for (auto &&node : function_sched.compute_sequence)
{
if (auto in = node_cast<input_node>(*node))
{
@ -293,24 +411,9 @@ void module_builder::write_binary(binary_writer &writer)
}
}
// Skip module header
// Skip function header
auto header_pos = writer.position();
writer.skip(sizeof(module_header));
auto write_shape = [&](const shape_t &shape) {
writer.write((uint32_t)shape.size());
for (auto dim : shape)
writer.write((uint32_t)dim);
};
// mempools
for (auto &mem : params_.module_sched.max_usages)
{
mempool_desc desc {};
desc.location = mem.first;
desc.size = mem.second;
writer.write(desc);
}
writer.skip(sizeof(function_header));
// inputs
writer.write_array<memory_range>(inputs);
@ -322,55 +425,20 @@ void module_builder::write_binary(binary_writer &writer)
for (auto &shape : output_shapes)
write_shape(shape);
// sections
for (auto &section : section_writer_)
{
section_header header {};
strncpy(header.name, section.first.c_str(), std::size(header.name) - 1);
auto merge_it = rdata_section_merges_.find(section.first);
if (merge_it == rdata_section_merges_.end())
{
header.flags = 0;
header.start = 0;
header.size = (uint32_t)section.second.body.size();
}
else
{
header.flags = SECTION_MERGED_INTO_RDATA;
header.start = merge_it->second.start;
header.size = merge_it->second.size;
}
// Skip section header
auto header_pos = writer.position();
writer.skip(sizeof(header));
if (merge_it == rdata_section_merges_.end())
{
header.start = (uint32_t)writer.align_position(alignment_);
// write content
writer.write_array(std::span<uint8_t const>(section.second.body));
}
// write section header
auto end_pos = writer.position();
writer.position(header_pos);
writer.write(header);
writer.position(end_pos);
}
writer.align_position(8);
auto end_pos = writer.position();
// header
module_header header {};
header.type = module_type();
header.size = (uint32_t)(end_pos - header_pos - sizeof(module_header));
header.mempools = (uint32_t)params_.module_sched.max_usages.size();
function_header header {};
header.header_size = sizeof(header);
header.size = (uint32_t)(end_pos - header_pos);
header.input_pool_size = (uint32_t)function_sched.input_pool_size;
header.output_pool_size = (uint32_t)function_sched.output_pool_size;
header.inputs = (uint32_t)inputs.size();
header.outputs = (uint32_t)outputs.size();
header.sections = (uint32_t)section_writer_.size();
header.reserved0 = 0;
auto entrypoint = entry_points_.at(&function_sched);
header.entrypoint = (uint32_t)entrypoint;
header.text_size = (uint32_t)(function_text_end_.at(&function_sched) - entrypoint);
writer.position(header_pos);
writer.write(header);
@ -383,3 +451,19 @@ void module_builder::build(binary_writer &writer)
link();
write_binary(writer);
}
void module_builder::begin_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
{
}
void module_builder::end_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
{
}
void module_builder::begin_emit_module()
{
}
void module_builder::end_emit_module()
{
}
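
write_binary() and write_function_binary() share one pattern: reserve space for the header, write the body, then seek back and patch the header with the final size. A self-contained sketch of that pattern over a std::stringstream, where the skip is emulated by writing a zeroed placeholder:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

struct header_t
{
    uint32_t header_size;
    uint32_t size; // total size, including the header itself
};

int main()
{
    std::stringstream ss;
    auto header_pos = ss.tellp();
    header_t placeholder {};
    ss.write(reinterpret_cast<const char *>(&placeholder), sizeof placeholder); // reserve header space

    std::string body = "function body bytes";
    ss.write(body.data(), (std::streamsize)body.size());

    auto end_pos = ss.tellp();
    header_t header { (uint32_t)sizeof(header_t), (uint32_t)(end_pos - header_pos) };
    ss.seekp(header_pos);
    ss.write(reinterpret_cast<const char *>(&header), sizeof header); // patch header
    ss.seekp(end_pos);

    std::cout << header.size << std::endl; // sizeof(header_t) + body length
}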

View File

@ -39,11 +39,26 @@ module_type_t stackvm_module_builder::module_type() const noexcept
return stackvm_module_type;
}
uint32_t stackvm_module_builder::module_version() const noexcept
{
return stackvm_module_version;
}
section_writer &stackvm_module_builder::text_writer()
{
return writer(".text");
}
void stackvm_module_builder::begin_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
{
set_current_entry_point(text_writer().position());
}
void stackvm_module_builder::end_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
{
set_current_function_text_end(text_writer().position());
}
void stackvm_module_builder::emit(ir::node &node)
{
stackvm_op_builder builder(node, text_writer());

View File

@ -59,10 +59,13 @@ public:
stackvm_module_builder(std::string_view module_name, const module_builder_params &params);
module_type_t module_type() const noexcept override;
uint32_t module_version() const noexcept override;
protected:
section_writer &text_writer();
void begin_emit_function(const schedule::function_schedule_result &function) override;
void end_emit_function(const schedule::function_schedule_result &function) override;
void emit(ir::node &node) override;
private:

View File

@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:49 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@ -513,9 +513,9 @@ void op_builder::tensor_binary_(datatype_t datatype, uint8_t rshape_src1, uint8_
op_writer<tensor_binary_op_t>()(tensor_binary_op_t(datatype, rshape_src1, rstride_src1, rshape_src2, rstride_src2, rstride_dest, binary_op, fused_clamp_low, fused_clamp_high), writer_);
}
void op_builder::tensor_call_(uint32_t module_id, uint8_t num_src, uint8_t num_dst)
void op_builder::tensor_call_(uint32_t function_id, uint16_t module_id, uint8_t num_src, uint8_t num_dst)
{
op_writer<tensor_call_op_t>()(tensor_call_op_t(module_id, num_src, num_dst), writer_);
op_writer<tensor_call_op_t>()(tensor_call_op_t(function_id, module_id, num_src, num_dst), writer_);
}
void op_builder::tensor_conv2d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_kernel, uint8_t rstride_kernel, uint8_t rstride_bias, uint8_t rstride_dest, uint16_t groups, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high)
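
tensor_call now encodes both ids: function_id widened to uint32_t plus a uint16_t module_id, followed by the operand counts. A sketch of serializing that operand layout, assuming the little-endian field-by-field byte order used by the stackvm op writer:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct tensor_call_op_t
{
    uint32_t function_id;
    uint16_t module_id;
    uint8_t num_src;
    uint8_t num_dst;
};

void write_op(std::vector<uint8_t> &out, const tensor_call_op_t &op)
{
    auto emit = [&](const void *p, size_t n) {
        auto b = static_cast<const uint8_t *>(p);
        out.insert(out.end(), b, b + n);
    };
    emit(&op.function_id, sizeof op.function_id);
    emit(&op.module_id, sizeof op.module_id);
    emit(&op.num_src, 1);
    emit(&op.num_dst, 1);
}

int main()
{
    std::vector<uint8_t> text;
    write_op(text, { 3, 1, 2, 1 }); // call function 3 of module 1
    std::printf("%zu bytes\n", text.size()); // 8
}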

View File

@ -21,7 +21,7 @@ using namespace nncase::ir;
void stackvm_module_builder::emit(call &node, stackvm_op_builder &builder)
{
auto target_id = module_id(&node.target());
auto target_id = function_id(&node.target());
uint8_t rshape = 0;
for (auto in : node.inputs())
@ -46,5 +46,5 @@ void stackvm_module_builder::emit(call &node, stackvm_op_builder &builder)
builder.ldc_i4_(rshape++);
}
builder.tensor_call_(target_id, (uint8_t)node.inputs().size(), (uint8_t)node.outputs().size());
builder.tensor_call_(target_id.function_id, target_id.module_id, (uint8_t)node.inputs().size(), (uint8_t)node.outputs().size());
}

View File

@ -2,6 +2,7 @@
set(SRCS evaluator.cpp
quantizer.cpp
evaluate_context.cpp
ops/neutral/neutral_ops.cpp)
add_library(evaluator OBJECT ${SRCS})

View File

@ -0,0 +1,238 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <nncase/ir/evaluator.h>
#include <nncase/ir/op_utils.h>
#include <nncase/ir/ops/constant.h>
#include <nncase/ir/quantizer.h>
#include <nncase/ir/runtime_type_utils.h>
#include <nncase/targets/target.h>
using namespace nncase;
using namespace nncase::ir;
using namespace nncase::schedule;
using namespace nncase::runtime;
namespace chrono = std::chrono;
namespace
{
std::unordered_map<node_opcode, std::function<void(ir::node &, function_evaluate_context &)>> g_evaluators;
auto &get_evaluator(node_opcode opcode)
{
auto it = g_evaluators.find(opcode);
if (it == std::end(g_evaluators))
throw std::runtime_error("Evaluator for " + std::string(opcode.name) + " is not found");
return it->second;
}
}
void nncase::ir::register_evaluator(ir::node_opcode opcode, std::function<void(ir::node &, function_evaluate_context &)> evaluator)
{
g_evaluators.emplace(opcode, std::move(evaluator));
}
evaluate_tensor::evaluate_tensor(datatype_t datatype, runtime_shape_t shape, runtime_shape_t strides, gsl::span<gsl::byte> buffer)
: datatype_(datatype), shape_(std::move(shape)), strides_(std::move(strides)), buffer_(buffer)
{
}
function_evaluate_context::function_evaluate_context(const function_schedule_result &sched, module_evaluate_context &mod_eval)
: sched_(sched), mod_eval_(mod_eval)
{
input_pool_ = std::make_unique<std::byte[]>(sched.input_pool_size);
output_pool_ = std::make_unique<std::byte[]>(sched.output_pool_size);
for (auto &&node : sched.compute_sequence)
{
auto &opcode = node->runtime_opcode();
if (opcode == op_input_node)
{
inputs_.emplace_back(&node->output_at(0));
}
else if (opcode == op_output_node)
{
outputs_.emplace_back(&node->input_at(0));
}
else if (opcode == op_constant)
{
auto &rnode = static_cast<constant &>(*node);
auto src = rnode.data();
auto dest = memory_at(rnode.output()).buffer().as_span<std::byte>();
std::copy(std::begin(src), std::end(src), dest.begin());
}
}
}
evaluate_tensor function_evaluate_context::memory_at(const output_connector &conn)
{
auto &alloc = module().sched().allocations.at(&conn);
std::byte *base;
switch (alloc.memory_location)
{
case mem_input:
base = input_pool_.get();
break;
case mem_output:
base = output_pool_.get();
break;
default:
base = module().memory_pool(alloc.memory_location);
break;
}
gsl::span<gsl::byte> buffer(reinterpret_cast<gsl::byte *>(base + alloc.start), alloc.size);
return evaluate_tensor(alloc.type, to(alloc.shape), to(alloc.strides), buffer);
}
void function_evaluate_context::evaluate()
{
using clock = chrono::high_resolution_clock;
chrono::nanoseconds total_duration = {};
auto quantizer = module().quantizer();
for (auto &&node : sched_.compute_sequence)
{
auto &evaluator = get_evaluator(node->runtime_opcode());
auto start = clock::now();
evaluator(*node, *this);
auto duration = clock::now() - start;
total_duration += duration;
if (quantizer)
{
for (auto out : node->outputs())
{
if (out->attributes() & cnctr_attr_need_quantize)
{
if (!quantizer->has_record(*out))
{
auto mem = memory_at(*out);
auto dtype = mem.datatype();
if (dtype == dt_bfloat16)
{
auto buffer = mem.buffer().as_span<bfloat16>();
quantizer->record(*out, buffer);
}
else if (dtype == dt_float32)
{
auto buffer = mem.buffer().as_span<float>();
quantizer->record(*out, buffer);
}
else
{
throw std::runtime_error("Quantizer doesn't support datatype of " + std::string(datatype_names(dtype)));
}
}
}
}
}
}
}
module_evaluate_context::module_evaluate_context(const module_schedule_result &sched, model_evaluate_context &model_eval)
: sched_(sched), model_eval_(model_eval), quantizer_(nullptr)
{
for (auto &&usage : sched.max_usages)
memory_pools_.emplace(usage.first, std::make_unique<std::byte[]>(usage.second));
for (auto &func : sched.functions)
{
functions_.emplace(std::piecewise_construct,
std::forward_as_tuple(func.graph),
std::forward_as_tuple(func, *this));
}
}
std::byte *module_evaluate_context::memory_pool(memory_location_t location) const
{
return memory_pools_.at(location).get();
}
function_evaluate_context &module_evaluate_context::function(ir::graph &function)
{
return functions_.at(&function);
}
void module_evaluate_context::enable_ptq(target &target, ir::calibrate_method calib_method)
{
quantizer_ = target.create_quantizer(sched_.type, calib_method);
}
void module_evaluate_context::begin_collect_distribution()
{
if (quantizer_)
quantizer_->begin_collect_distribution();
}
void module_evaluate_context::end_sample()
{
if (quantizer_)
quantizer_->end_sample();
}
void module_evaluate_context::end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress)
{
if (quantizer_)
quantizer_->end_collect_distribution(progress);
}
model_evaluate_context::model_evaluate_context(const schedule::model_schedule_result &sched)
: sched_(sched)
{
for (auto &module : sched.modules)
module_ctxs_.emplace(std::piecewise_construct, std::forward_as_tuple(module.type), std::forward_as_tuple(module, *this));
}
function_evaluate_context &model_evaluate_context::entrypoint()
{
auto func = sched_.entry_function;
return module_ctxs_.at(func->module->type).function(*func->graph);
}
module_evaluate_context &model_evaluate_context::module(const module_type_t &module_type)
{
return module_ctxs_.at(module_type);
}
void model_evaluate_context::enable_ptq(target &target, ir::calibrate_method calib_method)
{
for (auto &mod : module_ctxs_)
mod.second.enable_ptq(target, calib_method);
}
void model_evaluate_context::begin_collect_distribution()
{
for (auto &mod : module_ctxs_)
mod.second.begin_collect_distribution();
}
void model_evaluate_context::end_sample()
{
for (auto &mod : module_ctxs_)
mod.second.end_sample();
}
void model_evaluate_context::end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress)
{
for (auto &mod : module_ctxs_)
mod.second.end_collect_distribution(progress);
}
void model_evaluate_context::evaluate()
{
entrypoint().evaluate();
}
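
Evaluation is now layered the same way as codegen: a model context owns one context per module type, each of which owns one context per function, and entrypoint() resolves the entry function through its module. A sketch of that lookup chain, with strings standing in for module types and graphs:

#include <cassert>
#include <map>
#include <string>

struct function_ctx { std::string name; };

struct module_ctx
{
    std::map<std::string, function_ctx> functions; // keyed by graph name
    function_ctx &function(const std::string &graph) { return functions.at(graph); }
};

struct model_ctx
{
    std::map<std::string, module_ctx> modules; // keyed by module type
    std::string entry_module, entry_graph;

    // mirrors model_evaluate_context::entrypoint()
    function_ctx &entrypoint()
    {
        return modules.at(entry_module).function(entry_graph);
    }
};

int main()
{
    model_ctx model;
    model.modules["stackvm"].functions["main"] = { "main" };
    model.entry_module = "stackvm";
    model.entry_graph = "main";
    assert(model.entrypoint().name == "main");
}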

View File

@ -12,196 +12,66 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nncase/ir/quantizer.h"
#include <chrono>
#include <nncase/ir/evaluator.h>
#include <nncase/ir/op_utils.h>
#include <nncase/ir/ops/constant.h>
#include <nncase/ir/runtime_type_utils.h>
#include <nncase/targets/target.h>
using namespace nncase;
using namespace nncase::ir;
using namespace nncase::schedule;
using namespace nncase::runtime;
namespace chrono = std::chrono;
#define PROFILE 0
namespace
{
std::unordered_map<node_opcode, std::function<void(ir::node &, module_evaluate_context &)>> g_evaluators;
auto &get_evaluator(node_opcode opcode)
{
auto it = g_evaluators.find(opcode);
if (it == std::end(g_evaluators))
throw std::runtime_error("Evaluator for " + std::string(opcode.name) + " is not found");
return it->second;
}
}
evaluator::evaluator(const schedule::model_schedule_result &sched)
: model_eval_(sched)
{
}
void nncase::ir::register_evaluator(ir::node_opcode opcode, std::function<void(ir::node &, module_evaluate_context &)> evaluator)
{
g_evaluators.emplace(opcode, std::move(evaluator));
}
evaluate_tensor::evaluate_tensor(datatype_t datatype, runtime_shape_t shape, runtime_shape_t strides, gsl::span<gsl::byte> buffer)
: datatype_(datatype), shape_(std::move(shape)), strides_(std::move(strides)), buffer_(buffer)
{
}
module_evaluate_context::module_evaluate_context(const module_schedule_result &sched)
: sched_(sched), quantizer_(nullptr)
{
for (auto &&usage : sched.max_usages)
memory_pools_.emplace(usage.first, std::make_unique<std::byte[]>(usage.second));
for (auto &&node : sched.compute_sequence)
{
auto &opcode = node->runtime_opcode();
if (opcode == op_input_node)
{
inputs_.emplace_back(&node->output_at(0));
}
else if (opcode == op_output_node)
{
outputs_.emplace_back(&node->input_at(0));
}
else if (opcode == op_constant)
{
auto &rnode = static_cast<constant &>(*node);
auto src = rnode.data();
auto dest = memory_at(rnode.output()).buffer().as_span<std::byte>();
std::copy(std::begin(src), std::end(src), dest.begin());
}
}
}
evaluate_tensor module_evaluate_context::memory_at(const output_connector &conn)
{
auto &alloc = sched_.allocations.at(&conn);
auto &memory_pool = memory_pools_.at(alloc.memory_location);
gsl::span<gsl::byte> buffer(reinterpret_cast<gsl::byte *>(memory_pool.get() + alloc.start), alloc.size);
return evaluate_tensor(alloc.type, to(alloc.shape), to(alloc.strides), buffer);
}
void module_evaluate_context::enable_ptq(target &target, ir::calibrate_method calib_method)
{
quantizer_ = target.create_quantizer(sched_.graph->module_type(), calib_method);
}
void module_evaluate_context::evaluate()
{
using clock = chrono::high_resolution_clock;
chrono::nanoseconds total_duration = {};
for (auto &&node : sched_.compute_sequence)
{
auto &evaluator = get_evaluator(node->runtime_opcode());
auto start = clock::now();
evaluator(*node, *this);
auto duration = clock::now() - start;
total_duration += duration;
#if PROFILE
std::cout << node->name() << "/" << node->runtime_opcode().name << ": " << std::endl;
#endif
if (quantizer_)
{
for (auto out : node->outputs())
{
if (out->attributes() & cnctr_attr_need_quantize)
{
if (!quantizer_->has_record(*out))
{
auto mem = memory_at(*out);
if (mem.datatype() == dt_bfloat16)
{
auto buffer = mem.buffer().as_span<bfloat16>();
quantizer_->record(*out, buffer);
}
else
{
auto buffer = mem.buffer().as_span<float>();
quantizer_->record(*out, buffer);
}
}
#if PROFILE
std::cout << "\t" << out->name() << " range: { " << quantizer_->get(*out).min << " ~ " << quantizer_->get(*out).max << " } , ";
#endif
}
}
// quantizer_->range()
}
#if PROFILE
std::cout << "\t duration: " << duration.count() / 1e6 << "ms" << std::endl;
#endif
}
if (quantizer_)
quantizer_->reset_record();
#if PROFILE
std::cout << "Total: " << total_duration.count() / 1e6 << "ms" << std::endl;
#endif
}
void module_evaluate_context::begin_collect_distribution()
{
if (quantizer_)
quantizer_->begin_collect_distribution();
}
void module_evaluate_context::end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress)
{
if (quantizer_)
quantizer_->end_collect_distribution(progress);
}
evaluator::evaluator(const schedule::schedule_result &sched)
: sched_(sched)
{
for (auto &module_p : sched.modules)
module_ctxs_.emplace(module_p.first, module_p.second);
}
module_evaluate_context &evaluator::module_context(ir::graph &graph)
{
return module_ctxs_.at(&graph);
}
module_evaluate_context &evaluator::main_module_context()
{
return module_context(*sched_.main_module);
}
evaluate_tensor evaluator::memory_at(const output_connector &conn)
{
return main_module_context().memory_at(conn);
}
void evaluator::enable_ptq(target &target, ir::calibrate_method calib_method)
{
for (auto &module_p : module_ctxs_)
module_p.second.enable_ptq(target, calib_method);
return model_eval_.enable_ptq(target, calib_method);
}
void evaluator::evaluate()
{
module_ctxs_.at(sched_.main_module).evaluate();
return model_eval_.evaluate();
}
quantizer *evaluator::quantizer(const module_type_t &module_type)
{
return model_eval_.module(module_type).quantizer();
}
void evaluator::begin_collect_distribution()
{
for (auto &module_p : module_ctxs_)
module_p.second.begin_collect_distribution();
model_eval_.begin_collect_distribution();
}
void evaluator::end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress)
void evaluator::end_sample()
{
for (auto &module_p : module_ctxs_)
module_p.second.end_collect_distribution(progress);
model_eval_.end_sample();
}
void evaluator::end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress)
{
model_eval_.end_collect_distribution(progress);
}
evaluate_tensor evaluator::memory_at(const output_connector &conn)
{
return model_eval_.memory_at(conn);
}
evaluate_tensor evaluator::memory_at(const input_connector &conn)
{
return model_eval_.memory_at(conn);
}
evaluate_tensor evaluator::input_at(size_t index)
{
return model_eval_.input_at(index);
}
evaluate_tensor evaluator::output_at(size_t index)
{
return model_eval_.output_at(index);
}
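For orientation, a sketch of how a caller exercises this façade (assuming input/output index 0 exists): every call simply forwards to `model_eval_`, so existing call sites keep the old `evaluator` interface while the model/function split happens underneath.

#include <algorithm>
#include <nncase/ir/evaluator.h>

void fill_and_read(nncase::ir::evaluator &eval)
{
    auto in_buf = eval.input_at(0).buffer().as_span<float>();    // view of input 0's memory
    std::fill(in_buf.begin(), in_buf.end(), 0.f);                // write the sample
    eval.evaluate();                                             // forwards to model_evaluate_context::evaluate()
    auto out_buf = eval.output_at(0).buffer().as_span<float>();
    (void)out_buf;                                               // results live here
}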

View File

@ -70,7 +70,7 @@ using namespace nncase::runtime;
namespace
{
void nop_evaluator(ir::node &, module_evaluate_context &)
void nop_evaluator(ir::node &, function_evaluate_context &)
{
}
}
@ -84,7 +84,7 @@ void register_neutral_evaluators()
register_evaluator(op_ignore_node, nop_evaluator);
register_evaluator(op_constant, nop_evaluator);
register_evaluator(op_batch_to_space, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_batch_to_space, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<batch_to_space &>(node);
auto input = context.memory_at(rnode.input());
@ -97,7 +97,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_binary, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_binary, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<binary &>(node);
assert(rnode.input_a().type() == dt_float32);
@ -112,7 +112,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_broadcast, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_broadcast, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<broadcast &>(node);
auto input = context.memory_at(rnode.input());
@ -122,7 +122,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_concat, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_concat, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<concat &>(node);
std::vector<const gsl::byte *> inputs_mem;
@ -141,7 +141,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_conv2d, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_conv2d, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<conv2d &>(node);
assert(rnode.input().type() == dt_float32);
@ -161,7 +161,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_conv2d_transpose, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_conv2d_transpose, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<conv2d_transpose &>(node);
assert(rnode.input().type() == dt_float32);
@ -175,7 +175,7 @@ void register_neutral_evaluators()
rnode.dilation_h(), rnode.dilation_w(), rnode.padding_h(), rnode.padding_w(), rnode.fused_activation());
});
register_evaluator(op_dequantize, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_dequantize, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<dequantize &>(node);
auto output = context.memory_at(rnode.output()).buffer().as_span<float>();
@ -199,7 +199,7 @@ void register_neutral_evaluators()
}
});
register_evaluator(op_fused_unary, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_fused_unary, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<fused_unary &>(node);
auto input = context.memory_at(rnode.input()).buffer().as_span<float>();
@ -217,7 +217,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_matmul, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_matmul, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<matmul &>(node);
assert(rnode.input_a().type() == dt_float32);
@ -233,7 +233,7 @@ void register_neutral_evaluators()
neutral::matmul(input_a.data(), input_b.data(), output.data(), bias.data(), (int32_t)a_shape[0], (int32_t)a_shape[1], (int32_t)b_shape[1], rnode.fused_activation());
});
register_evaluator(op_pad, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_pad, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<pad &>(node);
auto input = context.memory_at(rnode.input());
@ -246,7 +246,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_quantize, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_quantize, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<quantize &>(node);
auto input = context.memory_at(rnode.input()).buffer().as_span<float>();
@ -255,7 +255,7 @@ void register_neutral_evaluators()
neutral::quantize(input.data(), output.data(), xt::compute_size(rnode.input().shape()), rnode.quant_param());
});
register_evaluator(op_reduce, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_reduce, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<reduce &>(node);
assert(rnode.input().type() == dt_float32);
@ -269,7 +269,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_reduce_window2d, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_reduce_window2d, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<reduce_window2d &>(node);
assert(rnode.input().type() == dt_float32);
@ -284,7 +284,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_bitcast, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_bitcast, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<bitcast &>(node);
auto input = context.memory_at(rnode.input()).buffer();
@ -293,7 +293,7 @@ void register_neutral_evaluators()
std::copy(input.begin(), input.end(), output.begin());
});
register_evaluator(op_resize_image, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_resize_image, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<resize_image &>(node);
auto input = context.memory_at(rnode.input());
@ -317,7 +317,7 @@ void register_neutral_evaluators()
}
});
register_evaluator(op_slice, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_slice, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<slice &>(node);
auto input = context.memory_at(rnode.input());
@ -330,7 +330,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_transpose, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_transpose, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<transpose &>(node);
auto input = context.memory_at(rnode.input());
@ -343,7 +343,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_unary, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_unary, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<unary &>(node);
assert(rnode.input().type() == dt_float32);
@ -416,7 +416,7 @@ void register_neutral_evaluators()
}
});
register_evaluator(op_table_lookup1d, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_table_lookup1d, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<table_lookup1d &>(node);
assert(rnode.input().type() == dt_uint8);
@ -427,7 +427,7 @@ void register_neutral_evaluators()
kernels::neutral::table_lookup1d(input.data(), output.data(), input.size(), table.data());
});
register_evaluator(op_clamp, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_clamp, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<clamp &>(node);
assert(rnode.input().type() == dt_float32);
@ -446,7 +446,7 @@ void register_neutral_evaluators()
}
});
register_evaluator(op_convert, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_convert, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<convert &>(node);
auto input = context.memory_at(rnode.input());
@ -459,7 +459,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_gather, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_gather, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<gather &>(node);
auto input = context.memory_at(rnode.input());
@ -473,7 +473,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_gather_nd, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_gather_nd, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<gather_nd &>(node);
auto input = context.memory_at(rnode.input());
@ -487,7 +487,7 @@ void register_neutral_evaluators()
.unwrap_or_throw();
});
register_evaluator(op_onehot, [](ir::node &node, module_evaluate_context &context) {
register_evaluator(op_onehot, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<onehot &>(node);
auto indices = context.memory_at(rnode.indices());

View File

@ -356,7 +356,7 @@ private:
{
auto graph_runner = [&](ir::graph &graph) {
ir::transforms::pass_manager pmgr(graph, *target_);
auto quant = evaluator.module_context(graph).quantizer();
auto quant = evaluator.quantizer(graph.module_type());
if (!compile_options_.use_dataset_as_input_stat)
{
@ -486,6 +486,7 @@ private:
std::memcpy(input_buffer.data(), tensor.data(), input_buffer.size_bytes());
evaluator.evaluate();
evaluator.end_sample();
if (options.progress)
options.progress(i++, dataset.total_size());
}
@ -519,6 +520,7 @@ private:
std::memcpy(input_buffer.data(), options.tensor_data.data() + i * input_buffer.size_bytes(), input_buffer.size_bytes());
evaluator.evaluate();
evaluator.end_sample();
if (options.progress)
options.progress(i++, options.samples_count);
}
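Both hunks above insert an `end_sample()` call after each calibration sample. A condensed sketch of the whole collection loop they belong to; `load_sample_into` and `samples` are hypothetical stand-ins for the dataset plumbing shown in the surrounding code:

// Fragment, not from the diff:
evaluator.begin_collect_distribution();
for (size_t i = 0; i < samples; i++)
{
    load_sample_into(evaluator, i); // hypothetical: memcpy sample i into the input buffer
    evaluator.evaluate();
    evaluator.end_sample();         // new in this PR: closes the per-sample quantizer record
}
evaluator.end_collect_distribution(progress);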

View File

@ -3,9 +3,11 @@
set(SRCS interpreter.cpp
error.cpp
runtime_loader.cpp
runtime_function.cpp
runtime_module.cpp
runtime_tensor.cpp
runtime_tensor_impl.cpp
section.cpp
host_runtime_tensor.cpp
allocator.cpp)

View File

@ -14,6 +14,7 @@
*/
#include <cassert>
#include <iostream>
#include <nncase/runtime/dbg.h>
#include <nncase/runtime/error.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_loader.h>
@ -23,7 +24,7 @@ using namespace nncase;
using namespace nncase::runtime;
interpreter::interpreter() noexcept
: main_module_(nullptr)
: entry_function_(nullptr)
{
}
@ -49,82 +50,79 @@ result<void> interpreter::load_model(gsl::span<const gsl::byte> buffer) noexcept
for (size_t i = 0; i < header->modules; i++)
{
auto mod_header = reader.get_ref<module_header>();
reader.skip(mod_header->size);
try_var(rt_module, runtime_module::create(mod_header->type));
auto mod_type = reader.peek_with_offset<decltype(module_header::type)>(offsetof(module_header, type));
auto mod_size = reader.peek_with_offset<decltype(module_header::size)>(offsetof(module_header, size));
auto payload = reader.read_span(mod_size);
try_var(rt_module, runtime_module::create(mod_type));
try_(rt_module->initialize(*mod_header, *this));
if (i == header->main_module)
main_module_ = rt_module.get();
try_(rt_module->initialize(payload, *this));
if (i == header->entry_module)
try_set(entry_function_, rt_module->find_function_by_id(header->entry_function));
modules_[i] = std::move(rt_module);
}
for (auto &mod : modules_)
try_(mod->initialize_inter_modules(*this));
return ok();
}
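With `load_model` now resolving the entry function itself, running a model end to end is a two-call affair. A sketch under the assumption that the buffer outlives the interpreter (sections are referenced in place, not copied):

#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/result.h>

nncase::runtime::result<void> load_and_run(gsl::span<const gsl::byte> buffer)
{
    nncase::runtime::interpreter interp;
    try_(interp.load_model(buffer)); // also resolves entry_function_ via find_function_by_id
    try_(interp.run());              // forwards to entry_function_->invoke()
    return nncase::runtime::ok();
}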
size_t interpreter::inputs_size() const noexcept
{
return main_module_->inputs_size();
return entry_function_->inputs_size();
}
size_t interpreter::outputs_size() const noexcept
{
return main_module_->outputs_size();
return entry_function_->outputs_size();
}
const memory_range &interpreter::input_desc(size_t index) const noexcept
{
return main_module_->input_desc(index);
return entry_function_->input_desc(index);
}
const memory_range &interpreter::output_desc(size_t index) const noexcept
{
return main_module_->output_desc(index);
return entry_function_->output_desc(index);
}
const runtime_shape_t &interpreter::input_shape(size_t index) const noexcept
{
return main_module_->input_shape(index);
return entry_function_->input_shape(index);
}
const runtime_shape_t &interpreter::output_shape(size_t index) const noexcept
{
return main_module_->output_shape(index);
return entry_function_->output_shape(index);
}
result<runtime_tensor> interpreter::input_tensor(size_t index) noexcept
{
return main_module_->input_tensor(index);
return entry_function_->input_tensor(index);
}
result<void> interpreter::input_tensor(size_t index, runtime_tensor tensor) noexcept
{
return main_module_->input_tensor(index, tensor);
return entry_function_->input_tensor(index, tensor);
}
result<runtime_tensor> interpreter::output_tensor(size_t index) noexcept
{
return main_module_->output_tensor(index);
return entry_function_->output_tensor(index);
}
result<void> interpreter::output_tensor(size_t index, runtime_tensor tensor) noexcept
{
return main_module_->output_tensor(index, tensor);
return entry_function_->output_tensor(index, tensor);
}
result<void> interpreter::run() noexcept
{
return main_module_->run();
return entry_function_->invoke();
}
result<runtime_module *> interpreter::find_module_by_id(size_t index) noexcept
{
if (index < modules_.size())
return ok(modules_[index].get());
return err(std::errc::result_out_of_range);
CHECK_WITH_ERR(index < modules_.size(), std::errc::result_out_of_range);
return ok(modules_[index].get());
}
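The replacement collapses the explicit guard-plus-`err` return into `CHECK_WITH_ERR`. A self-contained sketch of the pattern; the real macro lives in <nncase/runtime/dbg.h> and may differ in detail:

#include <memory>
#include <vector>

#define SKETCH_CHECK_WITH_ERR(cond, errc) \
    if (!(cond))                          \
    return err(errc)

result<runtime_module *> find_by_id(std::vector<std::unique_ptr<runtime_module>> &mods, size_t index) noexcept
{
    SKETCH_CHECK_WITH_ERR(index < mods.size(), std::errc::result_out_of_range);
    return ok(mods[index].get());
}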
options_dict &interpreter::options() noexcept

View File

@ -1,80 +0,0 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <runtime/cpu/cpu_ops_body.h>
#include <runtime/k210/k210_ops_body.h>
#include <runtime/kernel_registry.h>
#include <runtime/neutral/neutral_ops_body.h>
#include <runtime/span_reader.h>
using namespace nncase;
using namespace nncase::runtime;
namespace nncase
{
namespace runtime
{
#define BEGINE_DEFINE_TARGET(target) \
namespace target \
{
#define DEFINE_NEUTRAL_RUNTIME_OP(id, name, value) \
kernel_call_result id(id##_options &, interpreter_t &, interpreter_step_t);
#define DEFINE_RUNTIME_OP(target, id, name, value) \
kernel_call_result id(id##_options &, interpreter_t &, interpreter_step_t);
#define END_DEFINE_TARGET() }
#include <runtime/runtime_op.def>
#undef BEGINE_DEFINE_TARGET
#undef DEFINE_NEUTRAL_RUNTIME_OP
#undef DEFINE_RUNTIME_OP
#undef END_DEFINE_TARGET
}
}
kernel_call_result runtime::call_kernel(runtime_opcode opcode, xtl::span<const uint8_t> body, interpreter_t &interpreter, interpreter_step_t step)
{
span_reader reader(body);
switch (opcode)
{
#define BEGINE_DEFINE_TARGET(...)
#define DEFINE_NEUTRAL_RUNTIME_OP(id, name, value) \
case rop_##id: \
{ \
nncase::runtime::neutral::id##_options options; \
options.deserialize(reader); \
return nncase::runtime::neutral::id(options, interpreter, step); \
}
#define DEFINE_RUNTIME_OP(target, id, name, value) \
case rop_##target##_##id: \
{ \
nncase::runtime::target::id##_options options; \
options.deserialize(reader); \
return nncase::runtime::target::id(options, interpreter, step); \
}
#define END_DEFINE_TARGET()
#include <runtime/runtime_op.def>
#undef BEGINE_DEFINE_TARGET
#undef DEFINE_NEUTRAL_RUNTIME_OP
#undef DEFINE_RUNTIME_OP
#undef END_DEFINE_TARGET
default:
return kcr_error;
}
}

View File

@ -0,0 +1,310 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "section.h"
#include <nncase/runtime/dbg.h>
#include <nncase/runtime/error.h>
#include <nncase/runtime/runtime_function.h>
#include <nncase/runtime/span_reader.h>
using namespace nncase;
using namespace nncase::runtime;
namespace
{
class runtime_function_init_context_impl : public runtime_function_init_context
{
public:
runtime_function_init_context_impl(const function_header &header, runtime_module_init_context &module_init_context, gsl::span<const gsl::byte> body) noexcept
: header_(header), module_init_context_(module_init_context), body_(body)
{
}
runtime_module_init_context &module_init_context() noexcept override
{
return module_init_context_;
}
const function_header &header() noexcept override
{
return header_;
}
gsl::span<const gsl::byte> body() noexcept override
{
return body_;
}
private:
const function_header &header_;
runtime_module_init_context &module_init_context_;
gsl::span<const gsl::byte> body_;
};
}
runtime_function::runtime_function(runtime_module &rt_module)
: rt_module_(rt_module)
{
}
runtime_module &runtime_function::module() const noexcept
{
return rt_module_;
}
uint32_t runtime_function::inputs_size() const noexcept
{
return header_.inputs;
}
uint32_t runtime_function::outputs_size() const noexcept
{
return header_.outputs;
}
const memory_range &runtime_function::input_desc(size_t index) const noexcept
{
assert(index < input_tensors_.size());
return input_tensors_[index].range;
}
const memory_range &runtime_function::output_desc(size_t index) const noexcept
{
assert(index < output_tensors_.size());
return output_tensors_[index].range;
}
const runtime_shape_t &runtime_function::input_shape(size_t index) const noexcept
{
assert(index < input_tensors_.size());
return input_tensors_[index].shape;
}
const runtime_shape_t &runtime_function::output_shape(size_t index) const noexcept
{
assert(index < output_tensors_.size());
return output_tensors_[index].shape;
}
result<void> runtime_function::initialize(gsl::span<const gsl::byte> payload, runtime_module_init_context &module_init_context) noexcept
{
span_reader reader(payload);
reader.read(header_);
try
{
input_tensors_.resize(inputs_size());
output_tensors_.resize(outputs_size());
}
catch (...)
{
return err(std::errc::not_enough_memory);
}
auto read_shape = [&](runtime_shape_t &shape) {
shape.resize(reader.read<uint32_t>());
for (auto &dim : shape)
dim = reader.read<uint32_t>();
};
// inputs
for (auto &in : input_tensors_)
reader.read(in.range);
for (auto &in : input_tensors_)
read_shape(in.shape);
// outputs
for (auto &out : output_tensors_)
reader.read(out.range);
for (auto &out : output_tensors_)
read_shape(out.shape);
runtime_function_init_context_impl init_context(header_, module_init_context, reader.read_avail());
return initialize_core(init_context);
}
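For reference, the wire layout of one function payload, exactly as implied by the reader above:

// function_header                 (inputs, outputs, size, ...)
// memory_range  x inputs          input descriptors
// shape         x inputs          rank:u32, then rank x u32 dims each
// memory_range  x outputs         output descriptors
// shape         x outputs         same encoding as input shapes
// body                            remaining bytes, handed to initialize_core()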
#define INOUT_TENSOR_GETTER_IMPL(name) \
CHECK_WITH_ERR(index < name##_tensors_.size(), std::errc::result_out_of_range); \
\
auto &info = name##_tensors_[index]; \
if (info.bind_tensor.empty()) \
{ \
try_set(info.bind_tensor, allocate_##name##_tensor(index)); \
} \
return ok(info.bind_tensor);
result<runtime_tensor> runtime_function::input_tensor(size_t index) noexcept
{
INOUT_TENSOR_GETTER_IMPL(input);
}
result<runtime_tensor> runtime_function::output_tensor(size_t index) noexcept
{
INOUT_TENSOR_GETTER_IMPL(output);
}
#define DEV_INOUT_TENSOR_GETTER_IMPL(name) \
CHECK_WITH_ERR(index < name##_tensors_.size(), std::errc::result_out_of_range); \
\
auto &info = name##_tensors_[index]; \
if (info.bind_tensor.empty()) \
{ \
try_set(info.bind_tensor, allocate_##name##_tensor(index)); \
} \
if (!info.device_tensor.empty()) \
{ \
return ok(info.device_tensor); \
} \
return ok(info.bind_tensor);
result<runtime_tensor> runtime_function::device_input_tensor(size_t index) noexcept
{
DEV_INOUT_TENSOR_GETTER_IMPL(input);
}
result<runtime_tensor> runtime_function::device_output_tensor(size_t index) noexcept
{
DEV_INOUT_TENSOR_GETTER_IMPL(output);
}
result<void> runtime_function::input_tensor(size_t index, runtime_tensor tensor) noexcept
{
CHECK_WITH_ERR(!tensor.empty(), std::errc::invalid_argument);
CHECK_WITH_ERR(index < input_tensors_.size(), std::errc::result_out_of_range);
auto &info = input_tensors_[index];
CHECK_WITH_ERR(info.range.datatype == tensor.datatype(), nncase_errc::datatype_mismatch);
CHECK_WITH_ERR(info.shape == tensor.shape(), nncase_errc::shape_mismatch);
if (info.bind_tensor != tensor)
{
if (validate_input_tensor(index, tensor).is_err())
{
auto device_tensor = info.device_tensor;
if (device_tensor.empty())
try_var(device_tensor, allocate_input_tensor(index));
if (!tensor.can_copy_to_without_staging(device_tensor))
{
try_set(info.staging_tensor, host_runtime_tensor::create(info.range.datatype, info.shape));
}
else
{
info.staging_tensor.reset();
}
info.device_tensor = device_tensor;
}
else
{
info.device_tensor.reset();
info.staging_tensor.reset();
}
info.bind_tensor = tensor;
}
return ok();
}
result<void> runtime_function::output_tensor(size_t index, runtime_tensor tensor) noexcept
{
CHECK_WITH_ERR(!tensor.empty(), std::errc::invalid_argument);
CHECK_WITH_ERR(index < output_tensors_.size(), std::errc::result_out_of_range);
auto &info = output_tensors_[index];
CHECK_WITH_ERR(info.range.datatype == tensor.datatype(), nncase_errc::datatype_mismatch);
CHECK_WITH_ERR(info.shape == tensor.shape(), nncase_errc::shape_mismatch);
if (info.bind_tensor != tensor)
{
if (validate_output_tensor(index, tensor).is_err())
{
auto device_tensor = info.device_tensor;
if (device_tensor.empty())
try_var(device_tensor, allocate_output_tensor(index));
if (!device_tensor.can_copy_to_without_staging(tensor))
{
try_set(info.staging_tensor, host_runtime_tensor::create(info.range.datatype, info.shape));
}
else
{
info.staging_tensor.reset();
}
info.device_tensor = device_tensor;
}
else
{
info.device_tensor.reset();
info.staging_tensor.reset();
}
info.bind_tensor = tensor;
}
return ok();
}
result<void> runtime_function::invoke() noexcept
{
// 1. Ensure bindings
for (size_t i = 0; i < input_tensors_.size(); i++)
{
auto &info = input_tensors_[i];
if (info.bind_tensor.empty())
try_set(info.bind_tensor, allocate_input_tensor(i));
}
for (size_t i = 0; i < output_tensors_.size(); i++)
{
auto &info = output_tensors_[i];
if (info.bind_tensor.empty())
try_set(info.bind_tensor, allocate_output_tensor(i));
}
// 2. Copy inputs
for (auto &in : input_tensors_)
{
if (in.staging_tensor.empty())
{
if (!in.device_tensor.empty())
try_(in.bind_tensor.copy_to(in.device_tensor));
}
else
{
try_(in.bind_tensor.copy_to(in.staging_tensor));
try_(in.staging_tensor.copy_to(in.device_tensor));
}
}
// 3. Run
try_(invoke_core());
// 4. Copy outputs
for (auto &out : output_tensors_)
{
if (out.staging_tensor.empty())
{
if (!out.device_tensor.empty())
try_(out.device_tensor.copy_to(out.bind_tensor));
}
else
{
try_(out.device_tensor.copy_to(out.staging_tensor));
try_(out.staging_tensor.copy_to(out.bind_tensor));
}
}
return ok();
}
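The staging tensor is only in play when a direct copy between the bound tensor and the device tensor is impossible. The input half of that decision, factored out as a sketch; `io_tensor_info` is a hypothetical name for the private per-tensor struct:

static result<void> copy_input(io_tensor_info &in) noexcept
{
    if (in.staging_tensor.empty())
    {
        if (!in.device_tensor.empty())
            try_(in.bind_tensor.copy_to(in.device_tensor)); // direct device copy
        return ok();                                        // or data is already in place
    }
    try_(in.bind_tensor.copy_to(in.staging_tensor));        // hop through host staging
    return in.staging_tensor.copy_to(in.device_tensor);
}

The output side runs the same logic mirrored: device to staging to bound tensor.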

View File

@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "section.h"
#include <nncase/runtime/dbg.h>
#include <nncase/runtime/error.h>
#include <nncase/runtime/runtime_module.h>
@ -47,33 +48,7 @@ public:
gsl::span<const gsl::byte> section(const char *name) noexcept override
{
span_reader reader(sections_);
while (!reader.empty())
{
auto header = reader.get_ref<section_header>();
if (!strncmp(header->name, name, MAX_SECTION_NAME_LENGTH))
{
gsl::span<const gsl::byte> result;
if (header->flags & SECTION_MERGED_INTO_RDATA)
{
auto rdata_span = section(".rdata");
result = rdata_span.subspan(header->start, header->size);
}
else
{
result = reader.read_avail().subspan(header->start, header->size);
}
return result;
}
else
{
if (!(header->flags & SECTION_MERGED_INTO_RDATA))
reader.skip((size_t)header->start + header->size);
}
}
return {};
return find_section(name, sections_);
}
private:
@ -81,6 +56,21 @@ private:
interpreter &interp_;
gsl::span<const gsl::byte> sections_;
};
gsl::span<const gsl::byte> read_functions(span_reader &sr, size_t functions) noexcept
{
auto nest_sr = sr;
size_t size = 0;
for (size_t i = 0; i < functions; i++)
{
auto func_size = nest_sr.peek_with_offset<decltype(function_header::size)>(offsetof(function_header, size));
nest_sr.skip(func_size);
size += func_size;
}
return sr.read_span(size);
}
}
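`read_functions` measures the total size of all function payloads without consuming the outer reader, by walking a copy. The same measure-then-slice pattern in isolation (fragment; `size_offset` is an assumed field offset):

auto probe = sr;            // span_reader is copyable: an independent cursor over the same bytes
size_t total = 0;
for (size_t i = 0; i < count; i++)
{
    auto sz = probe.peek_with_offset<uint32_t>(size_offset); // read each item's size field in place
    probe.skip(sz);
    total += sz;
}
auto payloads = sr.read_span(total); // one contiguous span covering all items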
const module_type_t &runtime_module::type() const noexcept
@ -112,50 +102,17 @@ mempool_desc runtime_module::mempool(memory_location_t location) const noexcept
return desc;
}
uint32_t runtime_module::inputs_size() const noexcept
{
return header_.inputs;
}
uint32_t runtime_module::outputs_size() const noexcept
{
return header_.outputs;
}
const memory_range &runtime_module::input_desc(size_t index) const noexcept
{
assert(index < input_tensors_.size());
return input_tensors_[index].range;
}
const memory_range &runtime_module::output_desc(size_t index) const noexcept
{
assert(index < output_tensors_.size());
return output_tensors_[index].range;
}
const runtime_shape_t &runtime_module::input_shape(size_t index) const noexcept
{
assert(index < input_tensors_.size());
return input_tensors_[index].shape;
}
const runtime_shape_t &runtime_module::output_shape(size_t index) const noexcept
{
assert(index < output_tensors_.size());
return output_tensors_[index].shape;
}
result<void> runtime_module::initialize(const module_header &header, interpreter &interp) noexcept
result<void> runtime_module::initialize(gsl::span<const gsl::byte> payload, interpreter &interp) noexcept
{
interp_ = &interp;
header_ = header;
span_reader reader(gsl::make_span(reinterpret_cast<const gsl::byte *>(&header) + sizeof(module_header), header.size));
span_reader reader(payload);
reader.read(header_);
try
{
input_tensors_.resize(inputs_size());
output_tensors_.resize(outputs_size());
mempools_.resize(mempools_size());
mempools_.resize(header_.mempools);
shared_mempools_.resize(header_.shared_mempools);
functions_.resize(header_.functions);
}
catch (...)
{
@ -166,204 +123,38 @@ result<void> runtime_module::initialize(const module_header &header, interpreter
for (auto &desc : mempools_)
reader.read(desc);
auto read_shape = [&](runtime_shape_t &shape) {
shape.resize(reader.read<uint32_t>());
for (auto &dim : shape)
dim = reader.read<uint32_t>();
};
// shared mempools
for (auto &desc : shared_mempools_)
reader.read(desc);
// inputs
for (auto &in : input_tensors_)
reader.read(in.range);
for (auto &in : input_tensors_)
read_shape(in.shape);
span_reader func_reader(read_functions(reader, header_.functions));
runtime_module_init_context_impl init_context(header_, interp, read_sections(reader, header_.sections));
try_(initialize_before_functions(init_context));
// outputs
for (auto &out : output_tensors_)
reader.read(out.range);
for (auto &out : output_tensors_)
read_shape(out.shape);
runtime_module_init_context_impl init_context(header, interp, reader.read_avail());
return initialize_core(init_context);
}
#define INOUT_TENSOR_GETTER_IMPL(name) \
CHECK_WITH_ERR(index < name##_tensors_.size(), std::errc::result_out_of_range); \
\
auto &info = name##_tensors_[index]; \
if (info.bind_tensor.empty()) \
{ \
try_set(info.bind_tensor, allocate_##name##_tensor(index)); \
} \
return ok(info.bind_tensor);
result<runtime_tensor> runtime_module::input_tensor(size_t index) noexcept
{
INOUT_TENSOR_GETTER_IMPL(input);
}
result<runtime_tensor> runtime_module::output_tensor(size_t index) noexcept
{
INOUT_TENSOR_GETTER_IMPL(output);
}
#define DEV_INOUT_TENSOR_GETTER_IMPL(name) \
CHECK_WITH_ERR(index < name##_tensors_.size(), std::errc::result_out_of_range); \
\
auto &info = name##_tensors_[index]; \
if (info.bind_tensor.empty()) \
{ \
try_set(info.bind_tensor, allocate_##name##_tensor(index)); \
} \
if (!info.device_tensor.empty()) \
{ \
return ok(info.device_tensor); \
} \
return ok(info.bind_tensor);
result<runtime_tensor> runtime_module::device_input_tensor(size_t index) noexcept
{
DEV_INOUT_TENSOR_GETTER_IMPL(input);
}
result<runtime_tensor> runtime_module::device_output_tensor(size_t index) noexcept
{
DEV_INOUT_TENSOR_GETTER_IMPL(output);
}
result<void> runtime_module::input_tensor(size_t index, runtime_tensor tensor) noexcept
{
CHECK_WITH_ERR(!tensor.empty(), std::errc::invalid_argument);
CHECK_WITH_ERR(index < input_tensors_.size(), std::errc::result_out_of_range);
auto &info = input_tensors_[index];
CHECK_WITH_ERR(info.range.datatype == tensor.datatype(), nncase_errc::datatype_mismatch);
CHECK_WITH_ERR(info.shape == tensor.shape(), nncase_errc::shape_mismatch);
if (info.bind_tensor != tensor)
for (size_t i = 0; i < header_.functions; i++)
{
if (validate_input_tensor(index, tensor).is_err())
{
auto device_tensor = info.device_tensor;
if (device_tensor.empty())
try_var(device_tensor, allocate_input_tensor(index));
if (!tensor.can_copy_to_without_staging(device_tensor))
{
try_set(info.staging_tensor, host_runtime_tensor::create(info.range.datatype, info.shape));
}
else
{
info.staging_tensor.reset();
}
info.device_tensor = device_tensor;
}
else
{
info.device_tensor.reset();
info.staging_tensor.reset();
}
info.bind_tensor = tensor;
auto func_size = func_reader.peek_with_offset<decltype(function_header::size)>(offsetof(function_header, size));
auto payload = func_reader.read_span(func_size);
try_var(func, create_function());
try_(func->initialize(payload, init_context));
functions_[i] = std::move(func);
}
return initialize_after_functions(init_context);
}
result<runtime_function *> runtime_module::find_function_by_id(size_t index) noexcept
{
CHECK_WITH_ERR(index < functions_.size(), std::errc::result_out_of_range);
return ok(functions_[index].get());
}
result<void> runtime_module::initialize_before_functions(NNCASE_UNUSED runtime_module_init_context &context) noexcept
{
return ok();
}
result<void> runtime_module::output_tensor(size_t index, runtime_tensor tensor) noexcept
{
CHECK_WITH_ERR(!tensor.empty(), std::errc::invalid_argument);
CHECK_WITH_ERR(index < output_tensors_.size(), std::errc::result_out_of_range);
auto &info = output_tensors_[index];
CHECK_WITH_ERR(info.range.datatype == tensor.datatype(), nncase_errc::datatype_mismatch);
CHECK_WITH_ERR(info.shape == tensor.shape(), nncase_errc::shape_mismatch);
if (info.bind_tensor != tensor)
{
if (validate_output_tensor(index, tensor).is_err())
{
auto device_tensor = info.device_tensor;
if (device_tensor.empty())
try_var(device_tensor, allocate_output_tensor(index));
if (!device_tensor.can_copy_to_without_staging(tensor))
{
try_set(info.staging_tensor, host_runtime_tensor::create(info.range.datatype, info.shape));
}
else
{
info.staging_tensor.reset();
}
info.device_tensor = device_tensor;
}
else
{
info.device_tensor.reset();
info.staging_tensor.reset();
}
info.bind_tensor = tensor;
}
return ok();
}
result<void> runtime_module::run() noexcept
{
// 1. Ensure bindings
for (size_t i = 0; i < input_tensors_.size(); i++)
{
auto &info = input_tensors_[i];
if (info.bind_tensor.empty())
try_set(info.bind_tensor, allocate_input_tensor(i));
}
for (size_t i = 0; i < output_tensors_.size(); i++)
{
auto &info = output_tensors_[i];
if (info.bind_tensor.empty())
try_set(info.bind_tensor, allocate_output_tensor(i));
}
// 2. Copy inputs
for (auto &in : input_tensors_)
{
if (in.staging_tensor.empty())
{
if (!in.device_tensor.empty())
try_(in.bind_tensor.copy_to(in.device_tensor));
}
else
{
try_(in.bind_tensor.copy_to(in.staging_tensor));
try_(in.staging_tensor.copy_to(in.device_tensor));
}
}
// 3. Run
try_(run_core());
// 4. Copy outputs
for (auto &out : output_tensors_)
{
if (out.staging_tensor.empty())
{
if (!out.device_tensor.empty())
try_(out.device_tensor.copy_to(out.bind_tensor));
}
else
{
try_(out.device_tensor.copy_to(out.staging_tensor));
try_(out.staging_tensor.copy_to(out.bind_tensor));
}
}
return ok();
}
result<void> runtime_module::initialize_inter_modules(NNCASE_UNUSED interpreter &interp) noexcept
result<void> runtime_module::initialize_after_functions(NNCASE_UNUSED runtime_module_init_context &context) noexcept
{
return ok();
}

src/runtime/section.cpp (new file, 71 lines)
View File

@ -0,0 +1,71 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "section.h"
#include <nncase/runtime/span_reader.h>
using namespace nncase;
using namespace nncase::runtime;
gsl::span<const gsl::byte> runtime::find_section(const char *name, gsl::span<const gsl::byte> sections) noexcept
{
span_reader reader(sections);
while (!reader.empty())
{
auto header = reader.get_ref<section_header>();
if (!strncmp(header->name, name, MAX_SECTION_NAME_LENGTH))
{
gsl::span<const gsl::byte> result;
if (header->flags & SECTION_MERGED_INTO_RDATA)
{
auto rdata_span = find_section(".rdata", sections);
result = rdata_span.subspan(header->body_start, header->body_size);
}
else
{
result = reader.read_avail().subspan(header->body_start, header->body_size);
}
return result;
}
else
{
if (!(header->flags & SECTION_MERGED_INTO_RDATA))
reader.skip((size_t)header->body_start + header->body_size);
}
}
return {};
}
gsl::span<const gsl::byte> runtime::read_sections(span_reader &sr, size_t sections) noexcept
{
auto nest_sr = sr;
size_t size = 0;
for (size_t i = 0; i < sections; i++)
{
auto header = nest_sr.get_ref<section_header>();
size += sizeof(section_header);
if (!(header->flags & SECTION_MERGED_INTO_RDATA))
{
auto to_skip = (size_t)header->body_start + header->body_size;
nest_sr.skip(to_skip);
size += to_skip;
}
}
return sr.read_span(size);
}
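These two helpers only touch a handful of header fields. The struct shape they imply, with names taken from the code above; the real definition lives in <nncase/runtime/model.h> and may carry more members:

struct section_header_sketch
{
    char name[MAX_SECTION_NAME_LENGTH]; // matched with strncmp
    uint32_t flags;                     // carries the SECTION_MERGED_INTO_RDATA bit
    uint32_t body_start;                // offset of the body (into .rdata when merged)
    uint32_t body_size;                 // size of the body
};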

View File

@ -13,12 +13,13 @@
* limitations under the License.
*/
#pragma once
#include "../runtime_module.h"
#include <nncase/kernels/kernel_context.h>
NNCASE_MODULES_K210_API
#include <nncase/runtime/model.h>
#include <nncase/runtime/result.h>
#include <nncase/runtime/span_reader.h>
struct NNCASE_API k210_kernel_context : public kernels::kernel_context
{
};
BEGIN_NS_NNCASE_RUNTIME
END_NS_NNCASE_KERNELS_K210
gsl::span<const gsl::byte> find_section(const char *name, gsl::span<const gsl::byte> sections) noexcept;
gsl::span<const gsl::byte> read_sections(span_reader &sr, size_t sections) noexcept;
END_NS_NNCASE_RUNTIME

View File

@ -1,6 +1,7 @@
cmake_minimum_required (VERSION 3.13)
set(SRCS runtime_module.cpp
runtime_function.cpp
op_reader.cpp
evaluate_stack.cpp
ops/control.cpp

View File

@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:48 +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*

View File

@ -12,23 +12,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const nop_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const nop_op_t &op) noexcept
{
return ok();
}
result<void> stackvm_runtime_module::visit(const br_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const br_op_t &op) noexcept
{
return pc_relative(op.target);
}
result<void> stackvm_runtime_module::visit(const br_true_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const br_true_op_t &op) noexcept
{
try_var(value, stack_.pop());
if (value.as_i())
@ -36,7 +36,7 @@ result<void> stackvm_runtime_module::visit(const br_true_op_t &op) noexcept
return ok();
}
result<void> stackvm_runtime_module::visit(const br_false_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const br_false_op_t &op) noexcept
{
try_var(value, stack_.pop());
if (!value.as_i())
@ -44,7 +44,7 @@ result<void> stackvm_runtime_module::visit(const br_false_op_t &op) noexcept
return ok();
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ret_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ret_op_t &op) noexcept
{
if (call_depth_ == 0)
{
@ -57,22 +57,22 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ret_op_t &op) noe
return pc(target.as_u());
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const call_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const call_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ecall_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ecall_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const throw_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const throw_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const break_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const break_op_t &op) noexcept
{
return err(std::errc::not_supported);
}

View File

@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
using namespace nncase;
using namespace nncase::runtime;
@ -25,52 +25,52 @@ using namespace nncase::runtime::stackvm;
else \
return stack_.push((type)value.as_r())
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_i1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_i1_op_t &op) noexcept
{
CONV_IMPL(int8_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_i2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_i2_op_t &op) noexcept
{
CONV_IMPL(int16_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_i4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_i4_op_t &op) noexcept
{
CONV_IMPL(int32_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_i_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_i_op_t &op) noexcept
{
CONV_IMPL(intptr_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_u1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_u1_op_t &op) noexcept
{
CONV_IMPL(uint8_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_u2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_u2_op_t &op) noexcept
{
CONV_IMPL(uint16_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_u4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_u4_op_t &op) noexcept
{
CONV_IMPL(uint32_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_u_op_t &op) noexcept
{
CONV_IMPL(uintptr_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_br2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_br2_op_t &op) noexcept
{
CONV_IMPL(bfloat16);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_r4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_r4_op_t &op) noexcept
{
CONV_IMPL(float);
}

View File

@ -12,33 +12,33 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const ldc_i4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const ldc_i4_op_t &op) noexcept
{
return stack_.push(op.imm);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldnull_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldnull_op_t &op) noexcept
{
return stack_.push((uintptr_t)0);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldc_i4_0_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldc_i4_0_op_t &op) noexcept
{
return stack_.push((int32_t)0);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldc_i4_1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldc_i4_1_op_t &op) noexcept
{
return stack_.push((int32_t)1);
}
result<void> stackvm_runtime_module::visit(const ldc_r4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const ldc_r4_op_t &op) noexcept
{
return stack_.push(op.imm);
}
@ -49,52 +49,52 @@ result<void> stackvm_runtime_module::visit(const ldc_r4_op_t &op) noexcept
return err(std::errc::bad_address); \
return stack_.push(*reinterpret_cast<const type *>(addr.as_u()))
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_i1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_i1_op_t &op) noexcept
{
LDINDIMPL(int8_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_i2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_i2_op_t &op) noexcept
{
LDINDIMPL(int16_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_i4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_i4_op_t &op) noexcept
{
LDINDIMPL(int32_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_i_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_i_op_t &op) noexcept
{
LDINDIMPL(intptr_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_u1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_u1_op_t &op) noexcept
{
LDINDIMPL(uint8_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_u2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_u2_op_t &op) noexcept
{
LDINDIMPL(uint16_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_u4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_u4_op_t &op) noexcept
{
LDINDIMPL(uint32_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_u_op_t &op) noexcept
{
LDINDIMPL(uintptr_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_br2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_br2_op_t &op) noexcept
{
LDINDIMPL(bfloat16);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_r4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_r4_op_t &op) noexcept
{
LDINDIMPL(float);
}
@ -107,44 +107,43 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_r4_op_t &op
*reinterpret_cast<decltype(value.as_##type()) *>(addr.as_u()) = value.as_##type(); \
return ok()
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_i1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_i1_op_t &op) noexcept
{
STINDIMPL(i1);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_i2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_i2_op_t &op) noexcept
{
STINDIMPL(i2);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_i4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_i4_op_t &op) noexcept
{
STINDIMPL(i4);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_i_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_i_op_t &op) noexcept
{
STINDIMPL(i);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_br2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_br2_op_t &op) noexcept
{
STINDIMPL(br2);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_r4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_r4_op_t &op) noexcept
{
STINDIMPL(r4);
}
result<void> stackvm_runtime_module::visit(const lea_gp_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const lea_gp_op_t &op) noexcept
{
if (op.gpid >= regs_.size())
return err(std::errc::result_out_of_range);
return stack_.push((intptr_t)regs_[op.gpid] + op.offset);
try_var(reg, module().reg(op.gpid));
return stack_.push((intptr_t)reg + op.offset);
}
result<void> stackvm_runtime_module::visit(const lea_buffer_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const lea_buffer_op_t &op) noexcept
{
#define ID_NOT_FOUND ((size_t)-1)
// TODO: use subres
@ -206,13 +205,13 @@ result<void> stackvm_runtime_module::visit(const lea_buffer_op_t &op) noexcept
}
else if (op.location == mem_rdata)
{
auto buffer = rdata_.subspan(op.offset);
auto buffer = module().rdata().subspan(op.offset);
return stack_.push((uintptr_t)buffer.data());
}
else if (op.location == mem_data)
{
auto buffer = data_.get() + op.offset;
return stack_.push((uintptr_t)buffer);
auto buffer = module().data().subspan(op.offset);
return stack_.push((uintptr_t)buffer.data());
}
else
{
@ -225,52 +224,52 @@ result<void> stackvm_runtime_module::visit(const lea_buffer_op_t &op) noexcept
try_var(addr, stack_.pop()); \
return stack_.push(reinterpret_cast<const type *>(addr.as_u())[offset.as_u()])
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_i1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_i1_op_t &op) noexcept
{
LDELEM_IMPL(int8_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_i2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_i2_op_t &op) noexcept
{
LDELEM_IMPL(int16_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_i4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_i4_op_t &op) noexcept
{
LDELEM_IMPL(int32_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_i_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_i_op_t &op) noexcept
{
LDELEM_IMPL(intptr_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_u1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_u1_op_t &op) noexcept
{
LDELEM_IMPL(uint8_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_u2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_u2_op_t &op) noexcept
{
LDELEM_IMPL(uint16_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_u4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_u4_op_t &op) noexcept
{
LDELEM_IMPL(uint32_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_u_op_t &op) noexcept
{
LDELEM_IMPL(uintptr_t);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_br2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_br2_op_t &op) noexcept
{
LDELEM_IMPL(bfloat16);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_r4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_r4_op_t &op) noexcept
{
LDELEM_IMPL(float);
}
@ -282,116 +281,110 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_r4_op_t &o
reinterpret_cast<decltype(value.as_##type()) *>(addr.as_u())[offset.as_u()] = value.as_##type(); \
return ok()
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_i1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_i1_op_t &op) noexcept
{
STELEM_IMPL(i1);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_i2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_i2_op_t &op) noexcept
{
STELEM_IMPL(i2);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_i4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_i4_op_t &op) noexcept
{
STELEM_IMPL(i4);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_i_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_i_op_t &op) noexcept
{
STELEM_IMPL(i);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_br2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_br2_op_t &op) noexcept
{
STELEM_IMPL(br2);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_r4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_r4_op_t &op) noexcept
{
STELEM_IMPL(r4);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_0_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_0_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_1_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_1_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_2_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_2_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_3_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_3_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_4_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_4_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_5_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_5_op_t &op) noexcept
{
return err(std::errc::not_supported);
}
result<void> stackvm_runtime_module::visit(const stshape_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const stshape_op_t &op) noexcept
{
if (op.rshape >= shape_regs_.size())
shape_regs_.resize(op.rshape + 1);
auto &reg = shape_regs_[op.rshape];
runtime_shape_t shape;
try
{
reg.resize(op.rank);
shape.resize(op.rank);
}
catch (...)
{
return err(std::errc::not_enough_memory);
}
for (size_t i = 0; i < reg.size(); i++)
for (size_t i = 0; i < shape.size(); i++)
{
try_var(dim, stack_.pop());
reg[op.rank - i - 1] = (size_t)dim.as_u();
shape[op.rank - i - 1] = (size_t)dim.as_u();
}
return ok();
return module().shape_reg(op.rshape, std::move(shape));
}
result<void> stackvm_runtime_module::visit(const stpaddings_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const stpaddings_op_t &op) noexcept
{
if (op.rpaddings >= paddings_regs_.size())
paddings_regs_.resize(op.rpaddings + 1);
auto &reg = paddings_regs_[op.rpaddings];
runtime_paddings_t paddings;
try
{
reg.resize(op.rank);
paddings.resize(op.rank);
}
catch (...)
{
return err(std::errc::not_enough_memory);
}
for (size_t i = 0; i < reg.size(); i++)
for (size_t i = 0; i < paddings.size(); i++)
{
try_var(after, stack_.pop());
try_var(before, stack_.pop());
reg[op.rank - i - 1] = { before.as_i4(), after.as_i4() };
paddings[op.rank - i - 1] = { before.as_i4(), after.as_i4() };
}
return ok();
return module().paddings_reg(op.rpaddings, std::move(paddings));
}
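
The refactor moves shape and paddings registers off the function and onto the owning runtime module, which is why the visitors above now read and write them through module(). A minimal sketch of the module-side accessors this implies (the register-file layout and the out-of-range error code are assumptions, not the shipped implementation):

```cpp
// Hypothetical sketch of the module-owned register file implied by the
// module().shape_reg(...) / module().paddings_reg(...) calls above;
// paddings_reg would mirror shape_reg with runtime_paddings_t.
class stackvm_runtime_module_sketch
{
public:
    // Read form: used by visitors such as tensor_batch_to_space below.
    result<runtime_shape_t> shape_reg(size_t id) noexcept
    {
        if (id >= shape_regs_.size())
            return err(std::errc::result_out_of_range); // assumed error code
        return ok(shape_regs_[id]);
    }

    // Write form: used by stshape above; grows the file on demand, as the
    // old in-function code did with resize().
    result<void> shape_reg(size_t id, runtime_shape_t shape) noexcept
    {
        try
        {
            if (id >= shape_regs_.size())
                shape_regs_.resize(id + 1);
            shape_regs_[id] = std::move(shape);
            return ok();
        }
        catch (...)
        {
            return err(std::errc::not_enough_memory);
        }
    }

private:
    std::vector<runtime_shape_t> shape_regs_;
};
```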


@ -12,13 +12,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const neg_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const neg_op_t &op) noexcept
{
try_var(value, stack_.pop());
if (!value.is_real())
@ -53,73 +53,73 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const neg_op_t &op) noe
try_var(a, stack_.pop()); \
return stack_.push(a.as_u() op b.as_u());
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const add_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const add_op_t &op) noexcept
{
BINARY_IMPL(+);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const sub_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const sub_op_t &op) noexcept
{
BINARY_IMPL(-);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const mul_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const mul_op_t &op) noexcept
{
BINARY_IMPL(*);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const div_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const div_op_t &op) noexcept
{
BINARY_IMPL(/);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const div_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const div_u_op_t &op) noexcept
{
BINARY_U_IMPL(/);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const rem_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const rem_op_t &op) noexcept
{
BINARY_IMPL(/);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const rem_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const rem_u_op_t &op) noexcept
{
BINARY_U_IMPL(/);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const and_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const and_op_t &op) noexcept
{
BINARY_BIT_U_IMPL(&);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const or_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const or_op_t &op) noexcept
{
BINARY_BIT_U_IMPL(|);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const xor_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const xor_op_t &op) noexcept
{
BINARY_BIT_U_IMPL(^);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const not_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const not_op_t &op) noexcept
{
try_var(value, stack_.pop());
return stack_.push(~value.as_u());
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const shl_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const shl_op_t &op) noexcept
{
BINARY_BIT_U_IMPL(<<);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const shr_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const shr_op_t &op) noexcept
{
BINARY_BIT_IMPL(>>);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const shr_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const shr_u_op_t &op) noexcept
{
BINARY_BIT_U_IMPL(>>);
}
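
Only the tail of the binary-op macros survives in this hunk (the two lines shown above the add visitor). A reconstruction of the unsigned variant that tail appears to belong to (the pop order and the as_u() accessor come from the visible lines; the rest is an assumption):

```cpp
// Hypothetical reconstruction of BINARY_U_IMPL from its visible tail:
// pop b, then a, apply the operator on the unsigned views, push result.
#define BINARY_U_IMPL(op)                      \
    try_var(b, stack_.pop());                  \
    try_var(a, stack_.pop());                  \
    return stack_.push(a.as_u() op b.as_u())
```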
@ -140,52 +140,52 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const shr_u_op_t &op) n
else \
return stack_.push(a.as_r() op b.as_r() ? 1 : 0)
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const clt_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const clt_op_t &op) noexcept
{
COMPARE_IMPL(<);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const clt_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const clt_u_op_t &op) noexcept
{
COMPARE_U_IMPL(<);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cle_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cle_op_t &op) noexcept
{
COMPARE_IMPL(<=);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cle_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cle_u_op_t &op) noexcept
{
COMPARE_U_IMPL(<=);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ceq_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ceq_op_t &op) noexcept
{
COMPARE_U_IMPL(==);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cge_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cge_op_t &op) noexcept
{
COMPARE_IMPL(>=);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cge_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cge_u_op_t &op) noexcept
{
COMPARE_U_IMPL(>=);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cgt_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cgt_op_t &op) noexcept
{
COMPARE_IMPL(>);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cgt_u_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cgt_u_op_t &op) noexcept
{
COMPARE_U_IMPL(>);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cne_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cne_op_t &op) noexcept
{
COMPARE_U_IMPL(!=);
}
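
The comparison macros' heads are likewise clipped; the surviving else branch above suggests COMPARE_IMPL dispatches on real versus integer operands. A sketch under that assumption:

```cpp
// Hypothetical reconstruction of COMPARE_IMPL from its visible else
// branch: compare as signed integers unless either operand is real,
// pushing 1 or 0 as the boolean result.
#define COMPARE_IMPL(op)                                      \
    try_var(b, stack_.pop());                                 \
    try_var(a, stack_.pop());                                 \
    if (!a.is_real() && !b.is_real())                         \
        return stack_.push(a.as_i() op b.as_i() ? 1 : 0);     \
    else                                                      \
        return stack_.push(a.as_r() op b.as_r() ? 1 : 0)
```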


@ -12,19 +12,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const dup_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const dup_op_t &op) noexcept
{
try_var(entry, stack_.peek());
return stack_.push(entry);
}
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const pop_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const pop_op_t &op) noexcept
{
try_(stack_.pop());
return ok();


@ -12,23 +12,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_batch_to_space_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_batch_to_space_op_t &op) noexcept
{
try_var(output, pop_addr());
try_var(input, pop_addr());
auto &in_shape = shape_regs_[op.rshape_src];
auto &block_shape = shape_regs_[op.rshape_block];
auto &crops = paddings_regs_[op.rpad_crops];
auto &in_strides = shape_regs_[op.rstride_src];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(in_shape, module().shape_reg(op.rshape_src));
try_var(block_shape, module().shape_reg(op.rshape_block));
try_var(crops, module().paddings_reg(op.rpad_crops));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(out_strides, module().shape_reg(op.rstride_dest));
return kernels::batch_to_space(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output),
in_shape, block_shape, crops, in_strides, out_strides, kernel_context());
in_shape, block_shape, crops, in_strides, out_strides, module().kernel_context());
}
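
Every register read above now returns a result<T> instead of a raw reference, so a missing register becomes a propagated error rather than an out-of-bounds access. nncase's actual try_var/try_ macros live in nncase/runtime/result.h; the sketch below only illustrates the idiom, and its is_ok()/unwrap() interface is assumed rather than the library's exact API:

```cpp
// Illustration of the try_var / try_ propagation idiom only; the
// is_ok()/unwrap()/unwrap_err() interface is assumed for this sketch.
#define TRY_VAR_SKETCH(name, expr)                  \
    auto name##_result = (expr);                    \
    if (!name##_result.is_ok())                     \
        return err(name##_result.unwrap_err());     \
    auto name = std::move(name##_result).unwrap()
```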


@ -12,24 +12,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_binary_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_binary_op_t &op) noexcept
{
try_var(output, pop_addr());
try_var(input_b, pop_addr());
try_var(input_a, pop_addr());
auto &in_a_shape = shape_regs_[op.rshape_src1];
auto &in_a_strides = shape_regs_[op.rstride_src1];
auto &in_b_shape = shape_regs_[op.rshape_src2];
auto &in_b_strides = shape_regs_[op.rstride_src2];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(in_a_shape, module().shape_reg(op.rshape_src1));
try_var(in_a_strides, module().shape_reg(op.rstride_src1));
try_var(in_b_shape, module().shape_reg(op.rshape_src2));
try_var(in_b_strides, module().shape_reg(op.rstride_src2));
try_var(out_strides, module().shape_reg(op.rstride_dest));
return kernels::binary(op.binary_op, reinterpret_cast<const float *>(input_a), reinterpret_cast<const float *>(input_b),
reinterpret_cast<float *>(output), in_a_shape, in_a_strides, in_b_shape, in_b_strides, out_strides, { op.fused_clamp_low, op.fused_clamp_high }, kernel_context());
reinterpret_cast<float *>(output), in_a_shape, in_a_strides, in_b_shape, in_b_strides, out_strides, { op.fused_clamp_low, op.fused_clamp_high }, module().kernel_context());
}


@ -12,22 +12,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_broadcast_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_broadcast_op_t &op) noexcept
{
try_var(output, pop_addr());
try_var(input, pop_addr());
auto &in_shape = shape_regs_[op.rshape_src];
auto &in_strides = shape_regs_[op.rstride_src];
auto &out_shape = shape_regs_[op.rshape_dest];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(in_shape, module().shape_reg(op.rshape_src));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(out_shape, module().shape_reg(op.rshape_dest));
try_var(out_strides, module().shape_reg(op.rstride_dest));
return kernels::broadcast(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output),
in_shape, in_strides, out_shape, out_strides, kernel_context());
in_shape, in_strides, out_shape, out_strides, module().kernel_context());
}


@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
@ -21,15 +21,16 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_call_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_call_op_t &op) noexcept
{
try_var(mod, interp().find_module_by_id(op.module_id));
try_var(mod, module().interp().find_module_by_id(op.module_id));
try_var(func, mod->find_function_by_id(op.function_id));
auto create_tensor = [&]() -> result<runtime_tensor> {
try_var(rstrides, stack_.pop());
auto &strides = shape_regs_[rstrides.as_u4()];
try_var(strides, module().shape_reg(rstrides.as_u4()));
try_var(rshape, stack_.pop());
auto &shape = shape_regs_[rshape.as_u4()];
try_var(shape, module().shape_reg(rshape.as_u4()));
try_var(e_datatype, stack_.pop());
try_var(addr, pop_addr());
@ -40,14 +41,14 @@ result<void> stackvm_runtime_module::visit(const tensor_call_op_t &op) noexcept
for (uint8_t i = 0; i < op.num_dst; i++)
{
try_var(tensor, create_tensor());
try_(mod->output_tensor(op.num_dst - i - 1, tensor));
try_(func->output_tensor((size_t)op.num_dst - i - 1, tensor));
}
for (uint8_t i = 0; i < op.num_src; i++)
{
try_var(tensor, create_tensor());
try_(mod->input_tensor(op.num_src - i - 1, tensor));
try_(func->input_tensor((size_t)op.num_src - i - 1, tensor));
}
return mod->run();
return func->invoke();
}
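
tensor_call is the heart of this commit: the callee is now a function inside the target module, and tensor binding plus invocation happen per function instead of through a single module-level run(). A condensed illustration of the new call contract, using only the handles visible in this hunk:

```cpp
// Condensed sketch of the per-function call contract used by tensor_call
// above: resolve the module, then the function, bind I/O on the function,
// and invoke it (previously input/output binding and run() lived on the
// module itself). Single-input/single-output case for brevity.
result<void> call_function_sketch(interpreter &interp, uint32_t module_id,
    uint32_t function_id, runtime_tensor input, runtime_tensor output) noexcept
{
    try_var(mod, interp.find_module_by_id(module_id));
    try_var(func, mod->find_function_by_id(function_id));
    try_(func->input_tensor(0, input));
    try_(func->output_tensor(0, output));
    return func->invoke();
}
```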


@ -12,14 +12,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/convolution.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_conv2d_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_conv2d_op_t &op) noexcept
{
try_var(padding_w, pop_padding());
try_var(padding_h, pop_padding());
@ -27,16 +27,16 @@ result<void> stackvm_runtime_module::visit(const tensor_conv2d_op_t &op) noexcep
try_var(bias, pop_addr());
try_var(weights, pop_addr());
try_var(input, pop_addr());
auto &in_shape = shape_regs_[op.rshape_src];
auto &in_strides = shape_regs_[op.rstride_src];
auto &w_shape = shape_regs_[op.rshape_kernel];
auto &w_strides = shape_regs_[op.rstride_kernel];
auto &bias_strides = shape_regs_[op.rstride_bias];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(in_shape, module().shape_reg(op.rshape_src));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(w_shape, module().shape_reg(op.rshape_kernel));
try_var(w_strides, module().shape_reg(op.rstride_kernel));
try_var(bias_strides, module().shape_reg(op.rstride_bias));
try_var(out_strides, module().shape_reg(op.rstride_dest));
if (op.datatype != dt_float32)
return err(nncase_errc::datatype_mismatch);
return kernels::conv2d(reinterpret_cast<const float *>(input), reinterpret_cast<const float *>(weights),
reinterpret_cast<const float *>(bias), reinterpret_cast<float *>(output), in_shape, in_strides, w_shape, w_strides, bias_strides, out_strides,
padding_h, padding_w, op.groups, op.stride_h, op.stride_w, op.dilation_h, op.dilation_w, { op.fused_clamp_low, op.fused_clamp_high }, kernel_context());
padding_h, padding_w, op.groups, op.stride_h, op.stride_w, op.dilation_h, op.dilation_w, { op.fused_clamp_low, op.fused_clamp_high }, module().kernel_context());
}


@ -12,20 +12,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_convert_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_convert_op_t &op) noexcept
{
try_var(output, pop_addr());
try_var(input, pop_addr());
auto &shape = shape_regs_[op.rshape_src];
auto &in_strides = shape_regs_[op.rstride_src];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(shape, module().shape_reg(op.rshape_src));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(out_strides, module().shape_reg(op.rstride_dest));
return kernels::convert(op.in_datatype, op.dst_datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, kernel_context());
return kernels::convert(op.in_datatype, op.dst_datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, module().kernel_context());
}


@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
@ -21,13 +21,13 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_copy_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_copy_op_t &op) noexcept
{
try_var(output, pop_addr());
try_var(input, pop_addr());
auto &shape = shape_regs_[op.rshape];
auto &in_strides = shape_regs_[op.rstride_src];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(shape, module().shape_reg(op.rshape));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(out_strides, module().shape_reg(op.rstride_dest));
return kernels::copy(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, kernel_context());
return kernels::copy(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, module().kernel_context());
}


@ -12,24 +12,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_dequantize_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_dequantize_op_t &op) noexcept
{
try_var(bias, stack_.pop());
try_var(scale, stack_.pop());
try_var(output, pop_addr());
try_var(input, pop_addr());
auto &shape = shape_regs_[op.rshape_src];
auto &in_strides = shape_regs_[op.rstride_src];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(shape, module().shape_reg(op.rshape_src));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(out_strides, module().shape_reg(op.rstride_dest));
return kernels::dequantize(op.in_datatype, op.dst_datatype, reinterpret_cast<const gsl::byte *>(input),
reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, scale.as_r4(), bias.as_r4(), kernel_context());
reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, scale.as_r4(), bias.as_r4(), module().kernel_context());
}


@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
@ -21,17 +21,17 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_gather_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_gather_op_t &op) noexcept
{
try_var(indices, pop_addr());
try_var(output, pop_addr());
try_var(input, pop_addr());
auto &in_shape = shape_regs_[op.rshape_src];
auto &in_strides = shape_regs_[op.rstride_src];
auto &out_shape = shape_regs_[op.rshape_dest];
auto &out_strides = shape_regs_[op.rstride_dest];
auto &indices_shape = shape_regs_[op.rshape_indices];
try_var(in_shape, module().shape_reg(op.rshape_src));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(out_shape, module().shape_reg(op.rshape_dest));
try_var(out_strides, module().shape_reg(op.rstride_dest));
try_var(indices_shape, module().shape_reg(op.rshape_indices));
return kernels::gather(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), in_shape, out_shape,
in_strides, out_strides, reinterpret_cast<const int32_t *>(indices), indices_shape, op.axis);


@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
@ -21,17 +21,17 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_gather_nd_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_gather_nd_op_t &op) noexcept
{
try_var(indices, pop_addr());
try_var(output, pop_addr());
try_var(input, pop_addr());
auto &in_shape = shape_regs_[op.rshape_src];
auto &in_strides = shape_regs_[op.rstride_src];
auto &out_shape = shape_regs_[op.rshape_dest];
auto &out_strides = shape_regs_[op.rstride_dest];
auto &indices_shape = shape_regs_[op.rshape_indices];
try_var(in_shape, module().shape_reg(op.rshape_src));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(out_shape, module().shape_reg(op.rshape_dest));
try_var(out_strides, module().shape_reg(op.rstride_dest));
try_var(indices_shape, module().shape_reg(op.rshape_indices));
return kernels::gather_nd(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), in_shape, out_shape,
in_strides, out_strides, reinterpret_cast<const int32_t *>(indices), indices_shape, op.batch_dims);


@ -12,23 +12,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_lut1d_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_lut1d_op_t &op) noexcept
{
try_var(max_value, pop_scalar(op.datatype));
try_var(min_value, pop_scalar(op.datatype));
try_var(output, pop_addr());
try_var(table, pop_addr());
try_var(input, pop_addr());
auto &shape = shape_regs_[op.rshape_src];
auto &in_strides = shape_regs_[op.rstride_src];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(shape, module().shape_reg(op.rshape_src));
try_var(in_strides, module().shape_reg(op.rstride_src));
try_var(out_strides, module().shape_reg(op.rstride_dest));
return kernels::lut1d(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<const gsl::byte *>(table),
reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, min_value, max_value);


@ -13,7 +13,7 @@
* limitations under the License.
*/
#include "../runtime_module.h"
#include "../runtime_function.h"
#include <nncase/kernels/tensor_compute.h>
#include <nncase/runtime/interpreter.h>
#include <nncase/runtime/runtime_op_utility.h>
@ -22,7 +22,7 @@ using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;
result<void> stackvm_runtime_module::visit(const tensor_onehot_op_t &op) noexcept
result<void> stackvm_runtime_function::visit(const tensor_onehot_op_t &op) noexcept
{
try_var(output, pop_addr());
try_var(off_value, pop_addr());
@ -30,11 +30,11 @@ result<void> stackvm_runtime_module::visit(const tensor_onehot_op_t &op) noexcep
try_var(depth, pop_addr());
try_var(indices, pop_addr());
auto &indices_shape = shape_regs_[op.rshape_indices];
auto &out_shape = shape_regs_[op.rshape_dest];
auto &out_strides = shape_regs_[op.rstride_dest];
try_var(indices_shape, module().shape_reg(op.rshape_indices));
try_var(out_shape, module().shape_reg(op.rshape_dest));
try_var(out_strides, module().shape_reg(op.rstride_dest));
return kernels::onehot(op.datatype, reinterpret_cast<const int32_t *>(indices), reinterpret_cast<gsl::byte *>(output),
indices_shape, out_shape, out_strides, reinterpret_cast<gsl::byte *>(depth), reinterpret_cast<gsl::byte *>(off_value),
reinterpret_cast<gsl::byte *>(on_value), op.axis, op.onehot_mode);
reinterpret_cast<gsl::byte *>(on_value), op.axis, op.onehot_mode, module().kernel_context());
}

Some files were not shown because too many files have changed in this diff.