mirror of https://github.com/kendryte/nncase.git
Split runtime function from runtime module (#337)
* Update * Update * Update * Pass unary * Merge from master * Fix build errors * Fix build errors * Fix build errors * Fix schedule * Fix build errors * Update benchmark modelspull/340/head
parent
7ad9920682
commit
3bfbd61ccf
|
@ -24,6 +24,7 @@ import nncase
|
|||
import requests
|
||||
import onnxsim
|
||||
import onnx
|
||||
from io import BytesIO
|
||||
|
||||
TEMP_DIR = "tmp"
|
||||
MODEL_DIR = "models"
|
||||
|
@ -45,7 +46,7 @@ def _download(url, name, in_shapes):
|
|||
if not os.path.exists(filename):
|
||||
req = requests.get(url)
|
||||
onnx_model, check = onnxsim.simplify(
|
||||
onnx.load(req.content), check_n=3, input_shapes=in_shapes)
|
||||
onnx.load_model(BytesIO(req.content)), check_n=3, input_shapes=in_shapes)
|
||||
assert check, "Simplified ONNX model could not be validated"
|
||||
onnx.save(onnx_model, filename)
|
||||
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -25,7 +25,7 @@ struct build_model_result
|
|||
class NNCASE_API model_builder
|
||||
{
|
||||
public:
|
||||
model_builder(target &target, const schedule::schedule_result &sched);
|
||||
model_builder(target &target, const schedule::model_schedule_result &sched);
|
||||
model_builder(model_builder &) = delete;
|
||||
model_builder(model_builder &&) = delete;
|
||||
|
||||
|
@ -36,7 +36,7 @@ public:
|
|||
|
||||
private:
|
||||
target &target_;
|
||||
const schedule::schedule_result &sched_;
|
||||
const schedule::model_schedule_result &sched_;
|
||||
std::filesystem::path dump_dir_;
|
||||
bool dump_asm_;
|
||||
};
|
||||
|
|
|
@ -35,10 +35,16 @@ public:
|
|||
|
||||
struct module_builder_params
|
||||
{
|
||||
const schedule::schedule_result &sched;
|
||||
const schedule::model_schedule_result &model_sched;
|
||||
const schedule::module_schedule_result &module_sched;
|
||||
};
|
||||
|
||||
struct function_call_id
|
||||
{
|
||||
size_t module_id;
|
||||
size_t function_id;
|
||||
};
|
||||
|
||||
class NNCASE_API module_builder
|
||||
{
|
||||
private:
|
||||
|
@ -80,22 +86,27 @@ public:
|
|||
section_writer &writer(std::string_view section_name);
|
||||
|
||||
virtual module_type_t module_type() const noexcept = 0;
|
||||
virtual uint32_t module_version() const noexcept = 0;
|
||||
virtual std::unique_ptr<section_decompiler> create_decompiler(std::string_view section_name);
|
||||
|
||||
protected:
|
||||
void merge_to_rdata_section(std::string_view from);
|
||||
size_t module_id(ir::graph *graph);
|
||||
function_call_id function_id(ir::graph *graph);
|
||||
void set_current_entry_point(std::streampos pos);
|
||||
void set_current_function_text_end(std::streampos pos);
|
||||
|
||||
virtual void begin_emit() { }
|
||||
virtual void begin_emit_module();
|
||||
virtual void begin_emit_function(const schedule::function_schedule_result &function);
|
||||
virtual void end_emit_function(const schedule::function_schedule_result &function);
|
||||
virtual void emit(ir::node &node);
|
||||
virtual void end_emit() { }
|
||||
virtual void end_emit_module();
|
||||
|
||||
protected:
|
||||
std::filesystem::path dump_dir_;
|
||||
bool dump_asm_;
|
||||
|
||||
private:
|
||||
std::vector<nncase::ir::node *> generate_runtime_ops();
|
||||
std::vector<nncase::ir::node *> generate_current_runtime_ops();
|
||||
void compile();
|
||||
void decompile(std::string_view stage, std::string_view section_name, std::span<const uint8_t> input, std::span<const symbol> symbols);
|
||||
|
||||
|
@ -105,6 +116,7 @@ private:
|
|||
void write_symbol_refs();
|
||||
void link();
|
||||
void write_binary(binary_writer &writer);
|
||||
void write_function_binary(binary_writer &writer, const schedule::function_schedule_result &function_sched);
|
||||
|
||||
private:
|
||||
uint32_t alignment_;
|
||||
|
@ -113,5 +125,9 @@ private:
|
|||
std::map<std::string, section, std::less<>> section_writer_;
|
||||
std::map<std::string, rdata_merge_info, std::less<>> rdata_section_merges_;
|
||||
std::unordered_map<std::string_view, std::pair<size_t, std::string_view>> symbol_offsets_;
|
||||
|
||||
const schedule::function_schedule_result *current_function_;
|
||||
std::unordered_map<const schedule::function_schedule_result *, std::streampos> entry_points_;
|
||||
std::unordered_map<const schedule::function_schedule_result *, std::streampos> function_text_end_;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:48 +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -944,6 +944,7 @@ struct op_writer<nncase::runtime::stackvm::tensor_call_op_t>
|
|||
{
|
||||
writer.write(static_cast<uint8_t>(op.opcode));
|
||||
writer.write(static_cast<uint16_t>(op.funct));
|
||||
writer.write(op.function_id);
|
||||
writer.write(op.module_id);
|
||||
writer.write(op.num_src);
|
||||
writer.write(op.num_dst);
|
||||
|
@ -1322,7 +1323,7 @@ public:
|
|||
void tensor_batch_to_space_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_block, uint8_t rpad_crops);
|
||||
void tensor_broadcast_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_dest, uint8_t rstride_dest);
|
||||
void tensor_binary_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rstride_dest, binary_op_t binary_op, float fused_clamp_low, float fused_clamp_high);
|
||||
void tensor_call_(uint32_t module_id, uint8_t num_src, uint8_t num_dst);
|
||||
void tensor_call_(uint32_t function_id, uint16_t module_id, uint8_t num_src, uint8_t num_dst);
|
||||
void tensor_conv2d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_kernel, uint8_t rstride_kernel, uint8_t rstride_bias, uint8_t rstride_dest, uint16_t groups, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
|
||||
void tensor_copy_(datatype_t datatype, uint8_t rshape, uint8_t rstride_src, uint8_t rstride_dest);
|
||||
void tensor_convert_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
|
||||
|
|
|
@ -116,6 +116,7 @@ protected:
|
|||
virtual void process(const std::vector<uint8_t> &src, float *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
|
||||
virtual void process(const std::vector<uint8_t> &src, uint8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
|
||||
virtual void process(const std::vector<uint8_t> &src, int8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) = 0;
|
||||
virtual bool do_normalize() const noexcept { return true; }
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
|
@ -170,5 +171,6 @@ protected:
|
|||
void process(const std::vector<uint8_t> &src, float *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
|
||||
void process(const std::vector<uint8_t> &src, uint8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
|
||||
void process(const std::vector<uint8_t> &src, int8_t *dest, const xt::dynamic_shape<size_t> &shape, std::string layout) override;
|
||||
bool do_normalize() const noexcept override { return false; }
|
||||
};
|
||||
}
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "evaluate_types.h"
|
||||
#include "quantizer.h"
|
||||
#include <nncase/schedule/schedule_types.h>
|
||||
|
||||
namespace nncase
|
||||
{
|
||||
class target;
|
||||
}
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class module_evaluate_context;
|
||||
class model_evaluate_context;
|
||||
|
||||
class NNCASE_API function_evaluate_context
|
||||
{
|
||||
public:
|
||||
function_evaluate_context(const schedule::function_schedule_result &sched, module_evaluate_context &mod_eval);
|
||||
function_evaluate_context(const function_evaluate_context &) = delete;
|
||||
function_evaluate_context(function_evaluate_context &&) = default;
|
||||
|
||||
evaluate_tensor memory_at(const output_connector &conn);
|
||||
|
||||
evaluate_tensor memory_at(const input_connector &conn)
|
||||
{
|
||||
return memory_at(*conn.connection());
|
||||
}
|
||||
|
||||
evaluate_tensor input_at(size_t index)
|
||||
{
|
||||
return memory_at(*inputs_[index]);
|
||||
}
|
||||
|
||||
evaluate_tensor output_at(size_t index)
|
||||
{
|
||||
return memory_at(*outputs_[index]);
|
||||
}
|
||||
|
||||
module_evaluate_context &module() const noexcept { return mod_eval_; }
|
||||
|
||||
void evaluate();
|
||||
|
||||
private:
|
||||
const schedule::function_schedule_result &sched_;
|
||||
module_evaluate_context &mod_eval_;
|
||||
std::unique_ptr<std::byte[]> input_pool_;
|
||||
std::unique_ptr<std::byte[]> output_pool_;
|
||||
|
||||
std::vector<output_connector *> inputs_;
|
||||
std::vector<input_connector *> outputs_;
|
||||
};
|
||||
|
||||
class NNCASE_API module_evaluate_context
|
||||
{
|
||||
public:
|
||||
module_evaluate_context(const schedule::module_schedule_result &sched, model_evaluate_context &model_eval);
|
||||
module_evaluate_context(module_evaluate_context &) = delete;
|
||||
module_evaluate_context(module_evaluate_context &&) = default;
|
||||
|
||||
const schedule::module_schedule_result &sched() const noexcept { return sched_; }
|
||||
std::byte *memory_pool(memory_location_t location) const;
|
||||
ir::quantizer *quantizer() noexcept { return quantizer_.get(); }
|
||||
function_evaluate_context &function(ir::graph &function);
|
||||
model_evaluate_context &model() const noexcept { return model_eval_; }
|
||||
|
||||
void enable_ptq(target &target, ir::calibrate_method calib_method);
|
||||
void begin_collect_distribution();
|
||||
void end_sample();
|
||||
void end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress);
|
||||
|
||||
private:
|
||||
const schedule::module_schedule_result &sched_;
|
||||
model_evaluate_context &model_eval_;
|
||||
std::unordered_map<memory_location_t, std::unique_ptr<std::byte[]>> memory_pools_;
|
||||
|
||||
std::vector<output_connector *> inputs_;
|
||||
std::vector<input_connector *> outputs_;
|
||||
std::unique_ptr<ir::quantizer> quantizer_;
|
||||
std::unordered_map<ir::graph *, function_evaluate_context> functions_;
|
||||
};
|
||||
|
||||
class NNCASE_API model_evaluate_context
|
||||
{
|
||||
public:
|
||||
model_evaluate_context(const schedule::model_schedule_result &sched);
|
||||
model_evaluate_context(const model_evaluate_context &) = delete;
|
||||
model_evaluate_context(model_evaluate_context &&) = default;
|
||||
|
||||
function_evaluate_context &entrypoint();
|
||||
module_evaluate_context &module(const module_type_t &module_type);
|
||||
|
||||
evaluate_tensor memory_at(const output_connector &conn)
|
||||
{
|
||||
return entrypoint().memory_at(conn);
|
||||
}
|
||||
|
||||
evaluate_tensor memory_at(const input_connector &conn)
|
||||
{
|
||||
return memory_at(*conn.connection());
|
||||
}
|
||||
|
||||
evaluate_tensor input_at(size_t index)
|
||||
{
|
||||
return entrypoint().input_at(index);
|
||||
}
|
||||
|
||||
evaluate_tensor output_at(size_t index)
|
||||
{
|
||||
return entrypoint().output_at(index);
|
||||
}
|
||||
|
||||
void enable_ptq(nncase::target &target, ir::calibrate_method calib_method);
|
||||
void begin_collect_distribution();
|
||||
void end_sample();
|
||||
void end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress);
|
||||
|
||||
void evaluate();
|
||||
|
||||
private:
|
||||
const schedule::model_schedule_result &sched_;
|
||||
std::unordered_map<module_type_t, module_evaluate_context> module_ctxs_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <nncase/runtime/datatypes.h>
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class NNCASE_API evaluate_tensor
|
||||
{
|
||||
public:
|
||||
evaluate_tensor(datatype_t datatype, runtime_shape_t shape, runtime_shape_t strides, gsl::span<gsl::byte> buffer);
|
||||
|
||||
datatype_t datatype() const noexcept { return datatype_; }
|
||||
const runtime_shape_t &shape() const noexcept { return shape_; }
|
||||
const runtime_shape_t &strides() const noexcept { return strides_; }
|
||||
gsl::span<gsl::byte> buffer() const noexcept { return buffer_; }
|
||||
|
||||
private:
|
||||
datatype_t datatype_;
|
||||
runtime_shape_t shape_;
|
||||
runtime_shape_t strides_;
|
||||
gsl::span<gsl::byte> buffer_;
|
||||
};
|
||||
}
|
|
@ -13,114 +13,35 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
#include <nncase/ir/graph.h>
|
||||
#include <nncase/ir/op_utils.h>
|
||||
#include <nncase/ir/quantizer.h>
|
||||
#include <nncase/kernels/kernel_context.h>
|
||||
#include <nncase/runtime/compiler_defs.h>
|
||||
#include <nncase/schedule/scheduler.h>
|
||||
#include <unordered_map>
|
||||
#include "evaluate_context.h"
|
||||
#include "evaluate_types.h"
|
||||
|
||||
namespace nncase::ir
|
||||
{
|
||||
class quantizer;
|
||||
|
||||
class NNCASE_API evaluate_tensor
|
||||
{
|
||||
public:
|
||||
evaluate_tensor(datatype_t datatype, runtime_shape_t shape, runtime_shape_t strides, gsl::span<gsl::byte> buffer);
|
||||
|
||||
datatype_t datatype() const noexcept { return datatype_; }
|
||||
const runtime_shape_t &shape() const noexcept { return shape_; }
|
||||
const runtime_shape_t &strides() const noexcept { return strides_; }
|
||||
gsl::span<gsl::byte> buffer() const noexcept { return buffer_; }
|
||||
|
||||
private:
|
||||
datatype_t datatype_;
|
||||
runtime_shape_t shape_;
|
||||
runtime_shape_t strides_;
|
||||
gsl::span<gsl::byte> buffer_;
|
||||
};
|
||||
|
||||
class NNCASE_API module_evaluate_context
|
||||
{
|
||||
public:
|
||||
module_evaluate_context(const schedule::module_schedule_result &sched);
|
||||
module_evaluate_context(module_evaluate_context &) = delete;
|
||||
module_evaluate_context(module_evaluate_context &&) = default;
|
||||
|
||||
evaluate_tensor memory_at(const output_connector &conn);
|
||||
|
||||
evaluate_tensor memory_at(const input_connector &conn)
|
||||
{
|
||||
return memory_at(*conn.connection());
|
||||
}
|
||||
|
||||
evaluate_tensor input_at(size_t index)
|
||||
{
|
||||
return memory_at(*inputs_[index]);
|
||||
}
|
||||
|
||||
evaluate_tensor output_at(size_t index)
|
||||
{
|
||||
return memory_at(*outputs_[index]);
|
||||
}
|
||||
|
||||
ir::quantizer *quantizer() noexcept { return quantizer_.get(); }
|
||||
|
||||
void enable_ptq(target &target, ir::calibrate_method calib_method);
|
||||
void evaluate();
|
||||
|
||||
void begin_collect_distribution();
|
||||
void end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress);
|
||||
|
||||
private:
|
||||
const schedule::module_schedule_result &sched_;
|
||||
std::unordered_map<memory_location_t, std::unique_ptr<std::byte[]>> memory_pools_;
|
||||
|
||||
std::vector<output_connector *> inputs_;
|
||||
std::vector<input_connector *> outputs_;
|
||||
std::unique_ptr<ir::quantizer> quantizer_;
|
||||
};
|
||||
|
||||
class NNCASE_API evaluator
|
||||
{
|
||||
public:
|
||||
evaluator(const schedule::schedule_result &sched);
|
||||
evaluator(const schedule::model_schedule_result &sched);
|
||||
evaluator(evaluator &) = delete;
|
||||
evaluator(evaluator &&) = default;
|
||||
|
||||
module_evaluate_context &module_context(ir::graph &graph);
|
||||
module_evaluate_context &main_module_context();
|
||||
|
||||
void enable_ptq(target &target, ir::calibrate_method calib_method);
|
||||
void evaluate();
|
||||
|
||||
ir::quantizer *quantizer(const module_type_t &module_type);
|
||||
void begin_collect_distribution();
|
||||
void end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress);
|
||||
void end_sample();
|
||||
void end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress);
|
||||
|
||||
evaluate_tensor memory_at(const output_connector &conn);
|
||||
evaluate_tensor memory_at(const input_connector &conn);
|
||||
|
||||
evaluate_tensor memory_at(const input_connector &conn)
|
||||
{
|
||||
return memory_at(*conn.connection());
|
||||
}
|
||||
|
||||
evaluate_tensor input_at(size_t index)
|
||||
{
|
||||
return main_module_context().input_at(index);
|
||||
}
|
||||
|
||||
evaluate_tensor output_at(size_t index)
|
||||
{
|
||||
return main_module_context().output_at(index);
|
||||
}
|
||||
evaluate_tensor input_at(size_t index);
|
||||
evaluate_tensor output_at(size_t index);
|
||||
|
||||
private:
|
||||
const schedule::schedule_result &sched_;
|
||||
std::unordered_map<ir::graph *, module_evaluate_context> module_ctxs_;
|
||||
model_evaluate_context model_eval_;
|
||||
};
|
||||
|
||||
NNCASE_API void register_evaluator(ir::node_opcode opcode, std::function<void(ir::node &, module_evaluate_context &)> evaluator);
|
||||
NNCASE_API void register_evaluator(ir::node_opcode opcode, std::function<void(ir::node &, function_evaluate_context &)> evaluator);
|
||||
}
|
||||
|
|
|
@ -124,7 +124,7 @@ public:
|
|||
void begin_collect_distribution();
|
||||
void end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress);
|
||||
size_t histograms_count() const noexcept { return histograms_.size(); }
|
||||
void reset_record() { has_record_.clear(); }
|
||||
void end_sample() { has_record_.clear(); }
|
||||
|
||||
private:
|
||||
calibrate_method cali_method_;
|
||||
|
|
|
@ -325,6 +325,8 @@ NNCASE_INLINE_VAR constexpr memory_location_t mem_input = 0;
|
|||
NNCASE_INLINE_VAR constexpr memory_location_t mem_output = 1;
|
||||
NNCASE_INLINE_VAR constexpr memory_location_t mem_rdata = 2;
|
||||
NNCASE_INLINE_VAR constexpr memory_location_t mem_data = 3;
|
||||
NNCASE_INLINE_VAR constexpr memory_location_t mem_shared_data = 4;
|
||||
NNCASE_INLINE_VAR constexpr memory_location_t mem_private_base = 64;
|
||||
|
||||
using runtime_shape_t = itlib::small_vector<size_t, 4>;
|
||||
using runtime_axis_t = itlib::small_vector<int32_t, 4>;
|
||||
|
@ -402,7 +404,7 @@ struct memory_range
|
|||
{
|
||||
memory_location_t memory_location;
|
||||
datatype_t datatype;
|
||||
uint16_t reserved0;
|
||||
uint16_t shared_module;
|
||||
uint32_t start;
|
||||
uint32_t size;
|
||||
};
|
||||
|
@ -463,3 +465,16 @@ inline bool operator!=(const scalar &lhs, const scalar &rhs) noexcept
|
|||
return lhs.type != rhs.type || memcmp(&lhs.storage, &rhs.storage, valid_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
struct std::hash<nncase::module_type_t>
|
||||
{
|
||||
auto operator()(const nncase::module_type_t &key) const noexcept
|
||||
{
|
||||
size_t result = 0;
|
||||
const size_t prime = 31;
|
||||
for (auto c : key)
|
||||
result = c + (result * prime);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -74,7 +74,7 @@ public:
|
|||
|
||||
private:
|
||||
std::vector<std::unique_ptr<runtime_module>> modules_;
|
||||
runtime_module *main_module_;
|
||||
runtime_function *entry_function_;
|
||||
options_dict options_;
|
||||
};
|
||||
|
||||
|
|
|
@ -24,26 +24,49 @@ struct model_header
|
|||
{
|
||||
uint32_t identifier;
|
||||
uint32_t version;
|
||||
uint32_t header_size;
|
||||
uint32_t flags;
|
||||
uint32_t alignment;
|
||||
uint32_t modules;
|
||||
uint32_t main_module;
|
||||
uint32_t entry_module;
|
||||
uint32_t entry_function;
|
||||
};
|
||||
|
||||
struct function_header
|
||||
{
|
||||
uint32_t header_size;
|
||||
uint32_t size;
|
||||
uint32_t input_pool_size;
|
||||
uint32_t output_pool_size;
|
||||
uint32_t inputs;
|
||||
uint32_t outputs;
|
||||
uint32_t entrypoint;
|
||||
uint32_t text_size;
|
||||
};
|
||||
|
||||
struct module_header
|
||||
{
|
||||
module_type_t type;
|
||||
uint32_t version;
|
||||
uint32_t header_size;
|
||||
uint32_t size;
|
||||
uint32_t mempools;
|
||||
uint32_t inputs;
|
||||
uint32_t outputs;
|
||||
uint32_t shared_mempools;
|
||||
uint32_t sections;
|
||||
uint32_t functions;
|
||||
uint32_t reserved0;
|
||||
};
|
||||
|
||||
struct mempool_desc
|
||||
{
|
||||
memory_location_t location;
|
||||
uint8_t reserved0[3];
|
||||
uint32_t size;
|
||||
};
|
||||
|
||||
struct shared_mempool_desc
|
||||
{
|
||||
uint32_t module;
|
||||
uint32_t size;
|
||||
};
|
||||
|
||||
|
@ -51,8 +74,8 @@ struct section_header
|
|||
{
|
||||
char name[MAX_SECTION_NAME_LENGTH];
|
||||
uint32_t flags;
|
||||
uint32_t start;
|
||||
uint32_t size;
|
||||
uint32_t body_start;
|
||||
uint32_t body_size;
|
||||
uint32_t reserved0;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "model.h"
|
||||
#include "result.h"
|
||||
#include "runtime_tensor.h"
|
||||
|
||||
BEGIN_NS_NNCASE_RUNTIME
|
||||
|
||||
class interpreter;
|
||||
class runtime_module;
|
||||
struct runtime_module_init_context;
|
||||
|
||||
struct NNCASE_API runtime_function_init_context
|
||||
{
|
||||
virtual runtime_module_init_context &module_init_context() noexcept = 0;
|
||||
virtual const function_header &header() noexcept = 0;
|
||||
virtual gsl::span<const gsl::byte> body() noexcept = 0;
|
||||
};
|
||||
|
||||
class NNCASE_API runtime_function
|
||||
{
|
||||
private:
|
||||
struct inout_tensor_info
|
||||
{
|
||||
runtime_shape_t shape;
|
||||
runtime_shape_t strides;
|
||||
memory_range range;
|
||||
runtime_tensor bind_tensor;
|
||||
runtime_tensor staging_tensor;
|
||||
runtime_tensor device_tensor;
|
||||
};
|
||||
|
||||
public:
|
||||
runtime_function(runtime_module &rt_module);
|
||||
runtime_function(const runtime_function &) = delete;
|
||||
virtual ~runtime_function() = default;
|
||||
runtime_function &operator=(const runtime_function &) = delete;
|
||||
|
||||
result<void> initialize(gsl::span<const gsl::byte> payload, runtime_module_init_context &module_init_context) noexcept;
|
||||
runtime_module &module() const noexcept;
|
||||
|
||||
uint32_t inputs_size() const noexcept;
|
||||
const runtime_shape_t &input_shape(size_t index) const noexcept;
|
||||
const memory_range &input_desc(size_t index) const noexcept;
|
||||
result<runtime_tensor> input_tensor(size_t index) noexcept;
|
||||
result<void> input_tensor(size_t index, runtime_tensor tensor) noexcept;
|
||||
|
||||
uint32_t outputs_size() const noexcept;
|
||||
const runtime_shape_t &output_shape(size_t index) const noexcept;
|
||||
const memory_range &output_desc(size_t index) const noexcept;
|
||||
result<runtime_tensor> output_tensor(size_t index) noexcept;
|
||||
result<void> output_tensor(size_t index, runtime_tensor tensor) noexcept;
|
||||
|
||||
result<void> invoke() noexcept;
|
||||
|
||||
protected:
|
||||
virtual result<void> initialize_core(runtime_function_init_context &context) noexcept = 0;
|
||||
virtual result<runtime_tensor> allocate_input_tensor(size_t index) noexcept = 0;
|
||||
virtual result<runtime_tensor> allocate_output_tensor(size_t index) noexcept = 0;
|
||||
virtual result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
|
||||
virtual result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
|
||||
result<runtime_tensor> device_input_tensor(size_t index) noexcept;
|
||||
result<runtime_tensor> device_output_tensor(size_t index) noexcept;
|
||||
virtual result<void> invoke_core() noexcept = 0;
|
||||
|
||||
private:
|
||||
function_header header_;
|
||||
std::vector<inout_tensor_info> input_tensors_;
|
||||
std::vector<inout_tensor_info> output_tensors_;
|
||||
runtime_module &rt_module_;
|
||||
};
|
||||
|
||||
END_NS_NNCASE_RUNTIME
|
|
@ -15,6 +15,7 @@
|
|||
#pragma once
|
||||
#include "model.h"
|
||||
#include "result.h"
|
||||
#include "runtime_function.h"
|
||||
#include "runtime_tensor.h"
|
||||
|
||||
BEGIN_NS_NNCASE_RUNTIME
|
||||
|
@ -31,26 +32,15 @@ struct NNCASE_API runtime_module_init_context
|
|||
|
||||
class NNCASE_API runtime_module
|
||||
{
|
||||
private:
|
||||
struct inout_tensor_info
|
||||
{
|
||||
runtime_shape_t shape;
|
||||
runtime_shape_t strides;
|
||||
memory_range range;
|
||||
runtime_tensor bind_tensor;
|
||||
runtime_tensor staging_tensor;
|
||||
runtime_tensor device_tensor;
|
||||
};
|
||||
|
||||
public:
|
||||
static result<std::unique_ptr<runtime_module>> create(const module_type_t &type);
|
||||
|
||||
runtime_module() = default;
|
||||
runtime_module(runtime_module &) = delete;
|
||||
runtime_module(const runtime_module &) = delete;
|
||||
virtual ~runtime_module() = default;
|
||||
runtime_module &operator=(const runtime_module &) = delete;
|
||||
|
||||
result<void> initialize(const module_header &header, interpreter &interp) noexcept;
|
||||
virtual result<void> initialize_inter_modules(interpreter &interp) noexcept;
|
||||
result<void> initialize(gsl::span<const gsl::byte> payload, interpreter &interp) noexcept;
|
||||
const module_type_t &type() const noexcept;
|
||||
|
||||
interpreter &interp() const noexcept { return *interp_; }
|
||||
|
@ -59,35 +49,20 @@ public:
|
|||
const mempool_desc &mempool(size_t index) const noexcept;
|
||||
mempool_desc mempool(memory_location_t location) const noexcept;
|
||||
|
||||
uint32_t inputs_size() const noexcept;
|
||||
const runtime_shape_t &input_shape(size_t index) const noexcept;
|
||||
const memory_range &input_desc(size_t index) const noexcept;
|
||||
result<runtime_tensor> input_tensor(size_t index) noexcept;
|
||||
result<void> input_tensor(size_t index, runtime_tensor tensor) noexcept;
|
||||
|
||||
uint32_t outputs_size() const noexcept;
|
||||
const runtime_shape_t &output_shape(size_t index) const noexcept;
|
||||
const memory_range &output_desc(size_t index) const noexcept;
|
||||
result<runtime_tensor> output_tensor(size_t index) noexcept;
|
||||
result<void> output_tensor(size_t index, runtime_tensor tensor) noexcept;
|
||||
|
||||
result<void> run() noexcept;
|
||||
result<runtime_function *> find_function_by_id(size_t index) noexcept;
|
||||
|
||||
protected:
|
||||
virtual result<void> initialize_core(runtime_module_init_context &context) noexcept = 0;
|
||||
virtual result<runtime_tensor> allocate_input_tensor(size_t index) noexcept = 0;
|
||||
virtual result<runtime_tensor> allocate_output_tensor(size_t index) noexcept = 0;
|
||||
virtual result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
|
||||
virtual result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
|
||||
result<runtime_tensor> device_input_tensor(size_t index) noexcept;
|
||||
result<runtime_tensor> device_output_tensor(size_t index) noexcept;
|
||||
virtual result<void> run_core() noexcept = 0;
|
||||
virtual result<void> initialize_before_functions(runtime_module_init_context &context) noexcept;
|
||||
virtual result<void> initialize_after_functions(runtime_module_init_context &context) noexcept;
|
||||
virtual result<std::unique_ptr<runtime_function>> create_function() noexcept = 0;
|
||||
|
||||
gsl::span<std::unique_ptr<runtime_function>> functions() noexcept { return functions_; }
|
||||
|
||||
private:
|
||||
module_header header_;
|
||||
std::vector<mempool_desc> mempools_;
|
||||
std::vector<inout_tensor_info> input_tensors_;
|
||||
std::vector<inout_tensor_info> output_tensors_;
|
||||
std::vector<mempool_desc> shared_mempools_;
|
||||
std::vector<std::unique_ptr<runtime_function>> functions_;
|
||||
interpreter *interp_ = nullptr;
|
||||
};
|
||||
|
||||
|
|
|
@ -23,17 +23,20 @@ inline constexpr size_t get_bytes(datatype_t type)
|
|||
return nncase::detail::datatype_bytes(type);
|
||||
}
|
||||
|
||||
inline size_t compute_size(const runtime_shape_t &shape)
|
||||
template <class TShape>
|
||||
inline size_t compute_size(const TShape &shape)
|
||||
{
|
||||
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
|
||||
}
|
||||
|
||||
inline size_t get_bytes(datatype_t type, const runtime_shape_t &shape)
|
||||
template <class TShape>
|
||||
inline size_t get_bytes(datatype_t type, const TShape &shape)
|
||||
{
|
||||
return compute_size(shape) * get_bytes(type);
|
||||
}
|
||||
|
||||
inline size_t compute_size(const runtime_shape_t &shape, const runtime_shape_t &strides)
|
||||
template <class TShape>
|
||||
inline size_t compute_size(const TShape &shape, const TShape &strides)
|
||||
{
|
||||
size_t max_stride = 0, max_shape = 0;
|
||||
for (size_t i = 0; i < shape.size(); i++)
|
||||
|
@ -48,7 +51,8 @@ inline size_t compute_size(const runtime_shape_t &shape, const runtime_shape_t &
|
|||
return size ? size : 1;
|
||||
}
|
||||
|
||||
inline size_t get_bytes(datatype_t type, const runtime_shape_t &shape, const runtime_shape_t &strides)
|
||||
template <class TShape>
|
||||
inline size_t get_bytes(datatype_t type, const TShape &shape, const TShape &strides)
|
||||
{
|
||||
return compute_size(shape, strides) * get_bytes(type);
|
||||
}
|
||||
|
|
|
@ -88,18 +88,16 @@ public:
|
|||
}
|
||||
|
||||
template <class T>
|
||||
T peek()
|
||||
T peek_with_offset(size_t offset)
|
||||
{
|
||||
auto value = *reinterpret_cast<const T *>(span_.data());
|
||||
auto value = *reinterpret_cast<const T *>(span_.data() + offset);
|
||||
return value;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T peek_unaligned()
|
||||
T peek()
|
||||
{
|
||||
T value;
|
||||
std::memcpy(&value, span_.data(), sizeof(T));
|
||||
return value;
|
||||
return peek_with_offset<T>(0);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
|
@ -110,6 +108,12 @@ public:
|
|||
return value;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T peek_unaligned()
|
||||
{
|
||||
return peek_unaligned_with_offset<T>(0);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
const T *get_ref()
|
||||
{
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:48 +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -1141,7 +1141,8 @@ struct op_reader<tensor_call_op_t>
|
|||
tensor_call_op_t op(default_init);
|
||||
op.opcode = static_cast<opcode_t>(reader.read_unaligned<uint8_t>());
|
||||
op.funct = static_cast<tensor_function_t>(reader.read_unaligned<uint16_t>());
|
||||
op.module_id = reader.read_unaligned<uint32_t>();
|
||||
op.function_id = reader.read_unaligned<uint32_t>();
|
||||
op.module_id = reader.read_unaligned<uint16_t>();
|
||||
op.num_src = reader.read_unaligned<uint8_t>();
|
||||
op.num_dst = reader.read_unaligned<uint8_t>();
|
||||
return op;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:48 +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -1265,13 +1265,14 @@ struct tensor_call_op_t
|
|||
{
|
||||
opcode_t opcode;
|
||||
tensor_function_t funct;
|
||||
uint32_t module_id;
|
||||
uint32_t function_id;
|
||||
uint16_t module_id;
|
||||
uint8_t num_src;
|
||||
uint8_t num_dst;
|
||||
|
||||
tensor_call_op_t(default_init_t) noexcept { }
|
||||
explicit tensor_call_op_t(uint32_t module_id, uint8_t num_src, uint8_t num_dst) noexcept
|
||||
: opcode(opcode_t::TENSOR), funct(tensor_function_t::CALL), module_id(module_id), num_src(num_src), num_dst(num_dst)
|
||||
explicit tensor_call_op_t(uint32_t function_id, uint16_t module_id, uint8_t num_src, uint8_t num_dst) noexcept
|
||||
: opcode(opcode_t::TENSOR), funct(tensor_function_t::CALL), function_id(function_id), module_id(module_id), num_src(num_src), num_dst(num_dst)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
BEGIN_NS_NNCASE_RT_MODULE(stackvm)
|
||||
|
||||
NNCASE_INLINE_VAR constexpr module_type_t stackvm_module_type = to_module_type("stackvm");
|
||||
NNCASE_INLINE_VAR constexpr uint32_t stackvm_module_version = 1;
|
||||
|
||||
NNCASE_API result<std::unique_ptr<runtime_module>> create_stackvm_runtime_module();
|
||||
|
||||
|
|
|
@ -73,4 +73,5 @@ private:
|
|||
};
|
||||
|
||||
using allocator_map_t = std::unordered_map<memory_location_t, buffer_allocator *>;
|
||||
using shared_allocator_map_t = std::unordered_map<module_type_t, buffer_allocator *>;
|
||||
}
|
||||
|
|
|
@ -75,6 +75,9 @@ public:
|
|||
memory_location_t memory_location() const noexcept { return memory_location_; }
|
||||
memory_location_t &memory_location() noexcept { return memory_location_; }
|
||||
|
||||
module_type_t shared_module() const noexcept { return shared_module_; }
|
||||
void shared_module(const module_type_t &type) noexcept { shared_module_ = type; }
|
||||
|
||||
bool no_action_concat_with_strides() const noexcept { return no_action_concat_with_strides_; }
|
||||
bool &no_action_concat_with_strides() noexcept { return no_action_concat_with_strides_; }
|
||||
|
||||
|
@ -82,6 +85,7 @@ private:
|
|||
size_t id_;
|
||||
ir::output_connector &owner_;
|
||||
memory_location_t memory_location_;
|
||||
module_type_t shared_module_;
|
||||
std::optional<sub_buffer_desc> parent_;
|
||||
ir::shape_t strides_shape_;
|
||||
buffer_lifetime lifetime_ {};
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "schedule_types.h"
|
||||
|
||||
namespace nncase::schedule
|
||||
{
|
||||
class lifetime_recorder
|
||||
{
|
||||
public:
|
||||
lifetime_recorder(std::list<logical_buffer> &buffers, std::unordered_map<const ir::output_connector *, logical_buffer *> &buffer_map);
|
||||
|
||||
size_t current_age() const noexcept { return cnt_age_; }
|
||||
void current_age(size_t age);
|
||||
|
||||
void allocate(ir::output_connector &conn, memory_location_t location);
|
||||
void release(ir::output_connector &conn);
|
||||
void grow_age();
|
||||
|
||||
private:
|
||||
size_t next_buffer_id_ = 0;
|
||||
size_t cnt_age_ = 0;
|
||||
std::list<logical_buffer> &buffers_;
|
||||
std::unordered_map<const ir::output_connector *, logical_buffer *> &buffer_map_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,135 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "buffer_allocator.h"
|
||||
#include "liveness_analysis.h"
|
||||
#include "schedule_types.h"
|
||||
#include <filesystem>
|
||||
|
||||
namespace nncase
|
||||
{
|
||||
class target;
|
||||
}
|
||||
|
||||
namespace nncase::schedule
|
||||
{
|
||||
class module_schedule_context;
|
||||
class model_schedule_context;
|
||||
|
||||
struct caller_context
|
||||
{
|
||||
lifetime_recorder &lifetime;
|
||||
};
|
||||
|
||||
class function_schedule_context : public function_schedule_result
|
||||
{
|
||||
public:
|
||||
function_schedule_context(ir::graph &graph, module_schedule_context &mod_sched);
|
||||
function_schedule_context(const function_schedule_context &) = delete;
|
||||
function_schedule_context(function_schedule_context &&) = default;
|
||||
|
||||
function_schedule_context &operator=(const function_schedule_context &) = delete;
|
||||
|
||||
const module_type_t &module_type() const noexcept { return graph->module_type(); }
|
||||
std::span<ir::output_node *> outputs() const noexcept { return outputs_; }
|
||||
std::unordered_map<const ir::output_connector *, logical_buffer *> &logical_buffer_map() noexcept { return logical_buffer_map_; }
|
||||
std::list<logical_buffer> &logical_buffers() noexcept { return logical_buffers_; }
|
||||
std::vector<physical_buffer> &physical_buffers() noexcept { return physical_buffers_; }
|
||||
|
||||
void visit_function(caller_context &caller_ctx);
|
||||
void end_schedule();
|
||||
|
||||
private:
|
||||
void create_allocators();
|
||||
void generate_compute_sequence();
|
||||
void make_logical_buffers(caller_context &caller_ctx);
|
||||
void analyze_buffer_alias();
|
||||
void update_offset();
|
||||
void fix_lifetime();
|
||||
void make_physical_buffers();
|
||||
void allocate_physical_buffers();
|
||||
void assign_allocations();
|
||||
|
||||
void dump(const std::filesystem::path &dump_dir);
|
||||
|
||||
private:
|
||||
module_schedule_context &mod_sched_;
|
||||
std::span<ir::output_node *> outputs_;
|
||||
allocator_map_t allocators_;
|
||||
std::vector<std::shared_ptr<buffer_allocator>> allocator_holder_;
|
||||
std::unordered_map<const ir::output_connector *, logical_buffer *> logical_buffer_map_;
|
||||
std::list<logical_buffer> logical_buffers_;
|
||||
std::vector<physical_buffer> physical_buffers_;
|
||||
};
|
||||
|
||||
class module_schedule_context
|
||||
{
|
||||
public:
|
||||
module_schedule_context(module_schedule_result &result, model_schedule_context &model_sched, module_type_t type);
|
||||
module_schedule_context(const module_schedule_context &) = delete;
|
||||
module_schedule_context(module_schedule_context &&) = default;
|
||||
|
||||
module_schedule_context &operator=(const module_schedule_context &) = delete;
|
||||
|
||||
module_schedule_result &module_result() const noexcept { return result_; }
|
||||
model_schedule_context &model_sched() const noexcept { return model_sched_; }
|
||||
allocator_map_t &allocators() noexcept { return allocators_; }
|
||||
buffer_allocator &shared_allocator(const module_type_t &type);
|
||||
|
||||
void visit_function(ir::graph &graph, caller_context &caller_ctx);
|
||||
void end_schedule();
|
||||
|
||||
private:
|
||||
module_schedule_result &result_;
|
||||
model_schedule_context &model_sched_;
|
||||
module_type_t type_;
|
||||
allocator_map_t allocators_;
|
||||
std::vector<std::shared_ptr<buffer_allocator>> allocator_holder_;
|
||||
shared_allocator_map_t shared_allocators_;
|
||||
std::vector<function_schedule_context> functions_;
|
||||
std::filesystem::path dump_dir_;
|
||||
};
|
||||
|
||||
class model_schedule_context
|
||||
{
|
||||
public:
|
||||
model_schedule_context(model_schedule_result &result, nncase::target &target, bool skip_buffer_alias);
|
||||
model_schedule_context(const model_schedule_context &) = delete;
|
||||
model_schedule_context(model_schedule_context &&) = default;
|
||||
|
||||
model_schedule_context &operator=(const model_schedule_context &) = delete;
|
||||
|
||||
nncase::target &target() const noexcept { return target_; }
|
||||
bool skip_buffer_alias() const noexcept { return skip_buffer_alias_; }
|
||||
void config_dump(std::filesystem::path dump_dir);
|
||||
const std::filesystem::path &dump_dir() const noexcept { return dump_dir_; }
|
||||
model_schedule_result &model_result() const noexcept { return result_; }
|
||||
|
||||
void schedule(ir::graph &entry_function);
|
||||
void visit_function(ir::graph &graph, caller_context &caller_ctx);
|
||||
|
||||
private:
|
||||
void end_schedule();
|
||||
|
||||
private:
|
||||
model_schedule_result &result_;
|
||||
nncase::target &target_;
|
||||
bool skip_buffer_alias_;
|
||||
std::filesystem::path dump_dir_;
|
||||
module_schedule_context *entry_module_;
|
||||
ir::graph *entry_function_;
|
||||
std::unordered_map<module_type_t, module_schedule_context> module_contexts_;
|
||||
};
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "buffer_allocator.h"
|
||||
#include "buffers.h"
|
||||
#include <nncase/ir/graph.h>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace nncase::schedule
|
||||
{
|
||||
struct buffer_allocation
|
||||
{
|
||||
memory_location_t memory_location;
|
||||
datatype_t type;
|
||||
size_t shared_module;
|
||||
size_t start;
|
||||
size_t size;
|
||||
ir::shape_t shape;
|
||||
ir::shape_t strides;
|
||||
ir::shape_t strides_shape;
|
||||
|
||||
size_t linear_end() const noexcept { return start + size; }
|
||||
|
||||
bool overlap(const buffer_allocation &rhs) const noexcept
|
||||
{
|
||||
return size != 0 && rhs.size != 0 && memory_location == rhs.memory_location && (start < rhs.linear_end() && linear_end() > rhs.start);
|
||||
}
|
||||
|
||||
memory_range runtime_type() const
|
||||
{
|
||||
return { .memory_location = memory_location, .datatype = type, .shared_module = (uint16_t)shared_module, .start = (uint32_t)start, .size = (uint32_t)size };
|
||||
}
|
||||
};
|
||||
|
||||
using allocation_map_t = std::unordered_map<const ir::output_connector *, buffer_allocation>;
|
||||
struct module_schedule_result;
|
||||
|
||||
struct function_schedule_result
|
||||
{
|
||||
ir::graph *graph;
|
||||
module_schedule_result *module;
|
||||
std::vector<ir::node *> compute_sequence;
|
||||
size_t input_pool_size;
|
||||
size_t output_pool_size;
|
||||
};
|
||||
|
||||
struct module_schedule_result
|
||||
{
|
||||
module_type_t type;
|
||||
std::vector<function_schedule_result> functions;
|
||||
std::unordered_map<ir::graph *, function_schedule_result *> functions_map;
|
||||
allocation_map_t allocations;
|
||||
std::unordered_map<memory_location_t, size_t> max_usages;
|
||||
std::unordered_map<module_type_t, size_t> shared_max_usages;
|
||||
};
|
||||
|
||||
struct model_schedule_result
|
||||
{
|
||||
std::vector<module_schedule_result> modules;
|
||||
function_schedule_result *entry_function;
|
||||
};
|
||||
}
|
|
@ -13,14 +13,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "buffer_allocator.h"
|
||||
#include "buffers.h"
|
||||
#include "schedule_context.h"
|
||||
#include "schedule_types.h"
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <nncase/ir/graph.h>
|
||||
#include <span>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace nncase
|
||||
{
|
||||
|
@ -28,78 +24,14 @@ class target;
|
|||
|
||||
namespace schedule
|
||||
{
|
||||
struct buffer_allocation
|
||||
{
|
||||
memory_location_t memory_location;
|
||||
datatype_t type;
|
||||
size_t start;
|
||||
size_t size;
|
||||
ir::shape_t shape;
|
||||
ir::shape_t strides;
|
||||
ir::shape_t strides_shape;
|
||||
|
||||
size_t linear_end() const noexcept { return start + size; }
|
||||
|
||||
bool overlap(const buffer_allocation &rhs) const noexcept
|
||||
{
|
||||
return size != 0 && rhs.size != 0 && memory_location == rhs.memory_location && (start < rhs.linear_end() && linear_end() > rhs.start);
|
||||
}
|
||||
|
||||
memory_range runtime_type() const
|
||||
{
|
||||
return { .memory_location = memory_location, .datatype = type, .start = (uint32_t)start, .size = (uint32_t)size };
|
||||
}
|
||||
};
|
||||
|
||||
using allocation_map_t = std::unordered_map<const ir::output_connector *, buffer_allocation>;
|
||||
|
||||
struct module_schedule_result
|
||||
{
|
||||
ir::graph *graph;
|
||||
std::vector<ir::node *> compute_sequence;
|
||||
std::unordered_map<memory_location_t, size_t> max_usages;
|
||||
allocation_map_t allocations;
|
||||
};
|
||||
|
||||
struct schedule_result
|
||||
{
|
||||
std::vector<ir::graph *> graph_orders;
|
||||
std::unordered_map<ir::graph *, module_schedule_result> modules;
|
||||
ir::graph *main_module;
|
||||
};
|
||||
|
||||
struct schedule_context : module_schedule_result
|
||||
{
|
||||
bool skip_buffer_alias = false;
|
||||
nncase::target *target;
|
||||
module_type_t module_type;
|
||||
std::span<ir::output_node *> outputs;
|
||||
std::unordered_map<const ir::output_connector *, logical_buffer *> logical_buffer_map;
|
||||
std::list<logical_buffer> logical_buffers;
|
||||
std::vector<physical_buffer> physical_buffers;
|
||||
|
||||
void generate_compute_sequence();
|
||||
void make_logical_buffers();
|
||||
void analyze_buffer_alias();
|
||||
void update_offset();
|
||||
void fix_lifetime();
|
||||
void make_physical_buffers();
|
||||
void allocate_physical_buffers();
|
||||
void assign_allocations();
|
||||
};
|
||||
|
||||
class NNCASE_API scheduler
|
||||
{
|
||||
public:
|
||||
scheduler(target &target, ir::graph &main_graph, std::span<ir::output_node *> outputs)
|
||||
: target_(target), main_graph_(main_graph), outputs_(outputs) { }
|
||||
scheduler(target &target, ir::graph &main_graph, std::span<ir::output_node *> outputs);
|
||||
|
||||
schedule_result schedule(bool skip_buffer_alias = false);
|
||||
model_schedule_result schedule(bool skip_buffer_alias = false);
|
||||
void config_dump(std::filesystem::path dump_dir);
|
||||
|
||||
private:
|
||||
void dump_schedule(const schedule_context &context);
|
||||
|
||||
private:
|
||||
target &target_;
|
||||
ir::graph &main_graph_;
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace nncase::ir::transforms
|
|||
struct run_pass_options
|
||||
{
|
||||
ir::quantizer *quantizer;
|
||||
schedule::schedule_context *schedule_context;
|
||||
schedule::function_schedule_context *schedule_context;
|
||||
std::optional<std::filesystem::path> dump_dir;
|
||||
};
|
||||
|
||||
|
@ -105,14 +105,14 @@ public:
|
|||
|
||||
void dump_dir(const std::filesystem::path &dir);
|
||||
void quantizer(ir::quantizer *q);
|
||||
void schedule_context(schedule::schedule_context *c);
|
||||
void schedule_context(schedule::function_schedule_context *c);
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<pass>> passes_;
|
||||
graph &graph_;
|
||||
nncase::target &target_;
|
||||
ir::quantizer *quantizer_;
|
||||
schedule::schedule_context *schedule_context_;
|
||||
schedule::function_schedule_context *schedule_context_;
|
||||
std::optional<std::filesystem::path> dump_dir_;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ class target;
|
|||
|
||||
namespace schedule
|
||||
{
|
||||
struct schedule_context;
|
||||
class function_schedule_context;
|
||||
}
|
||||
|
||||
namespace ir
|
||||
|
@ -45,7 +45,7 @@ namespace ir
|
|||
ir::graph &graph;
|
||||
nncase::target ⌖
|
||||
ir::quantizer *quantizer;
|
||||
schedule::schedule_context *schedule_context;
|
||||
schedule::function_schedule_context *schedule_context;
|
||||
std::optional<std::filesystem::path> dump_dir;
|
||||
std::vector<node *> matched_nodes;
|
||||
std::vector<input_connector *> inputs;
|
||||
|
|
|
@ -25,18 +25,6 @@
|
|||
#define NNCASE_MODULES_K210_API __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
#define BEGIN_NS_NNCASE_RT_K210 \
|
||||
namespace nncase \
|
||||
{ \
|
||||
namespace runtime \
|
||||
{ \
|
||||
namespace k210 \
|
||||
{
|
||||
#define END_NS_NNCASE_RT_K210 \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
#define BEGIN_NS_NNCASE_KERNELS_K210 \
|
||||
namespace nncase \
|
||||
{ \
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include "compiler_defs.h"
|
||||
#include <nncase/runtime/error.h>
|
||||
|
||||
BEGIN_NS_NNCASE_RT_K210
|
||||
BEGIN_NS_NNCASE_RT_MODULE(k210)
|
||||
|
||||
enum class nncase_k210_errc
|
||||
{
|
||||
|
@ -26,7 +26,7 @@ enum class nncase_k210_errc
|
|||
NNCASE_MODULES_K210_API const std::error_category &nncase_k210_category() noexcept;
|
||||
NNCASE_MODULES_K210_API std::error_condition make_error_condition(nncase_k210_errc code);
|
||||
|
||||
END_NS_NNCASE_RT_K210
|
||||
END_NS_NNCASE_RT_MODULE
|
||||
|
||||
namespace std
|
||||
{
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include <nncase/runtime/result.h>
|
||||
#include <nncase/runtime/span_reader.h>
|
||||
|
||||
BEGIN_NS_NNCASE_RT_K210
|
||||
BEGIN_NS_NNCASE_RT_MODULE(k210)
|
||||
|
||||
class NNCASE_MODULES_K210_API op_visitor
|
||||
{
|
||||
|
@ -44,4 +44,4 @@ private:
|
|||
result<void> next() noexcept;
|
||||
};
|
||||
|
||||
END_NS_NNCASE_RT_K210
|
||||
END_NS_NNCASE_RT_MODULE
|
||||
|
|
|
@ -16,10 +16,11 @@
|
|||
#include "compiler_defs.h"
|
||||
#include <nncase/runtime/runtime_module.h>
|
||||
|
||||
BEGIN_NS_NNCASE_RT_K210
|
||||
BEGIN_NS_NNCASE_RT_MODULE(k210)
|
||||
|
||||
NNCASE_INLINE_VAR constexpr module_type_t k210_module_type = to_module_type("k210");
|
||||
NNCASE_INLINE_VAR constexpr uint32_t k210_module_version = 1;
|
||||
|
||||
NNCASE_MODULES_K210_API result<std::unique_ptr<runtime_module>> create_k210_runtime_module();
|
||||
|
||||
END_NS_NNCASE_RT_K210
|
||||
END_NS_NNCASE_RT_MODULE
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
#pragma once
|
||||
#include "runtime_types.h"
|
||||
|
||||
BEGIN_NS_NNCASE_RT_K210
|
||||
BEGIN_NS_NNCASE_RT_MODULE(k210)
|
||||
|
||||
struct kpu_layout
|
||||
{
|
||||
|
@ -184,4 +184,4 @@ inline std::array<int32_t, 2> get_kpu_select_pool_offset(kpu_pool_type_t pool_ty
|
|||
}
|
||||
}
|
||||
|
||||
END_NS_NNCASE_RT_K210
|
||||
END_NS_NNCASE_RT_MODULE
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
#include "compiler_defs.h"
|
||||
#include <nncase/runtime/datatypes.h>
|
||||
|
||||
BEGIN_NS_NNCASE_RT_K210
|
||||
BEGIN_NS_NNCASE_RT_MODULE(k210)
|
||||
|
||||
NNCASE_INLINE_VAR constexpr memory_location_t mem_kpu = 4;
|
||||
NNCASE_INLINE_VAR constexpr memory_location_t mem_kpu = mem_private_base + 0;
|
||||
NNCASE_INLINE_VAR constexpr size_t KPU_RAM_SIZE = 2 * 1024 * 1024; // 2MB
|
||||
|
||||
typedef struct
|
||||
|
@ -341,4 +341,4 @@ struct copy_options
|
|||
kpu_shape_t out_strides;
|
||||
};
|
||||
|
||||
END_NS_NNCASE_RT_K210
|
||||
END_NS_NNCASE_RT_MODULE
|
||||
|
|
|
@ -38,11 +38,26 @@ module_type_t k210_module_builder::module_type() const noexcept
|
|||
return k210_module_type;
|
||||
}
|
||||
|
||||
uint32_t k210_module_builder::module_version() const noexcept
|
||||
{
|
||||
return k210_module_version;
|
||||
}
|
||||
|
||||
section_writer &k210_module_builder::text_writer()
|
||||
{
|
||||
return writer(".text");
|
||||
}
|
||||
|
||||
void k210_module_builder::begin_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
|
||||
{
|
||||
set_current_entry_point(text_writer().position());
|
||||
}
|
||||
|
||||
void k210_module_builder::end_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
|
||||
{
|
||||
set_current_function_text_end(text_writer().position());
|
||||
}
|
||||
|
||||
void k210_module_builder::emit(ir::node &node)
|
||||
{
|
||||
#define DEFINE_OP(op) \
|
||||
|
|
|
@ -28,10 +28,13 @@ public:
|
|||
k210_module_builder(std::string_view module_name, const module_builder_params ¶ms);
|
||||
|
||||
module_type_t module_type() const noexcept override;
|
||||
uint32_t module_version() const noexcept override;
|
||||
|
||||
protected:
|
||||
section_writer &text_writer();
|
||||
|
||||
void begin_emit_function(const schedule::function_schedule_result &function) override;
|
||||
void end_emit_function(const schedule::function_schedule_result &function) override;
|
||||
void emit(ir::node &node) override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -27,7 +27,7 @@ using namespace nncase::runtime;
|
|||
|
||||
void ir::k210::register_k210_evaluators()
|
||||
{
|
||||
register_evaluator(op_k210_fake_kpu_conv2d, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_k210_fake_kpu_conv2d, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<fake_kpu_conv2d &>(node);
|
||||
|
||||
assert(rnode.input().type() == dt_float32);
|
||||
|
@ -42,13 +42,13 @@ void ir::k210::register_k210_evaluators()
|
|||
|
||||
auto in_shape = input.shape();
|
||||
shape_t conv_out_shape { in_shape[0], (size_t)rnode.output_channels(), in_shape[2], in_shape[3] };
|
||||
auto conv_out_fmap_size = xt::compute_size(conv_out_shape);
|
||||
auto conv_out_fmap_size = runtime::compute_size(conv_out_shape);
|
||||
|
||||
auto conv_output_tmp = std::make_unique<float[]>(conv_out_fmap_size);
|
||||
auto batch = in_shape[0];
|
||||
auto in_size_per_batch = xt::compute_size(in_shape) / batch;
|
||||
auto in_size_per_batch = runtime::compute_size(in_shape) / batch;
|
||||
auto conv_output_tmp_size_per_batch = conv_out_fmap_size / batch;
|
||||
auto out_size_per_batch = xt::compute_size(rnode.output().shape()) / batch;
|
||||
auto out_size_per_batch = runtime::compute_size(rnode.output().shape()) / batch;
|
||||
auto p_input = input_mem.data();
|
||||
auto p_conv_ouput_tmp = conv_output_tmp.get();
|
||||
auto p_output = output_mem.data();
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
add_subdirectory(kendryte)
|
||||
|
||||
set(SRCS runtime_module.cpp
|
||||
runtime_function.cpp
|
||||
op_reader.cpp
|
||||
error.cpp
|
||||
ops/kpu_conv2d.cpp
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/k210/k210_kernels.h>
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
|
||||
|
@ -20,7 +20,7 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::k210;
|
||||
|
||||
result<void> k210_runtime_module::visit(const copy_options &op) noexcept
|
||||
result<void> k210_runtime_function::visit(const copy_options &op) noexcept
|
||||
{
|
||||
try_var(input, memory_at(op.input));
|
||||
try_var(output, memory_at(op.output));
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/k210/k210_kernels.h>
|
||||
#include <nncase/runtime/dbg.h>
|
||||
#ifndef NNCASE_SIMULATOR
|
||||
|
@ -71,7 +71,7 @@ void kpu_conv2d_normal(runtime::k210::kpu_layer_argument_t &layer, plic_irq_call
|
|||
#endif
|
||||
}
|
||||
|
||||
result<void> k210_runtime_module::visit(const kpu_conv2d_options &op) noexcept
|
||||
result<void> k210_runtime_function::visit(const kpu_conv2d_options &op) noexcept
|
||||
{
|
||||
auto layer = op.layer;
|
||||
|
||||
|
|
|
@ -12,14 +12,14 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/k210/k210_kernels.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::k210;
|
||||
|
||||
result<void> k210_runtime_module::visit(const kpu_download_options &op) noexcept
|
||||
result<void> k210_runtime_function::visit(const kpu_download_options &op) noexcept
|
||||
{
|
||||
try_var(input, memory_at(op.input));
|
||||
try_var(output, memory_at(op.output));
|
||||
|
|
|
@ -12,14 +12,14 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/k210/k210_kernels.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::k210;
|
||||
|
||||
result<void> k210_runtime_module::visit(const kpu_upload_options &op) noexcept
|
||||
result<void> k210_runtime_function::visit(const kpu_upload_options &op) noexcept
|
||||
{
|
||||
try_var(input, memory_at(op.input));
|
||||
try_var(output, memory_at(op.output));
|
||||
|
@ -30,7 +30,7 @@ result<void> k210_runtime_module::visit(const kpu_upload_options &op) noexcept
|
|||
#ifdef NNCASE_SIMULATOR
|
||||
0
|
||||
#else
|
||||
dma_ch_
|
||||
module().dma_ch()
|
||||
#endif
|
||||
);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime_function.h"
|
||||
#include <nncase/runtime/host_runtime_tensor.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/k210/error.h>
|
||||
#include <nncase/runtime/k210/runtime_types.h>
|
||||
#include <nncase/runtime/runtime_loader.h>
|
||||
#ifndef NNCASE_SIMULATOR
|
||||
#include <kpu.h>
|
||||
#endif
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::detail;
|
||||
using namespace nncase::runtime::k210;
|
||||
|
||||
k210_runtime_module &k210_runtime_function::module() const noexcept
|
||||
{
|
||||
return static_cast<k210_runtime_module &>(runtime_function::module());
|
||||
}
|
||||
|
||||
result<void> k210_runtime_function::initialize_core(runtime_function_init_context &context) noexcept
|
||||
{
|
||||
text_ = context.module_init_context().section(".text").subspan(context.header().entrypoint, context.header().text_size);
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<runtime_tensor> k210_runtime_function::allocate_input_tensor(size_t index) noexcept
|
||||
{
|
||||
return hrt::create(input_desc(index).datatype, input_shape(index), hrt::pool_shared);
|
||||
}
|
||||
|
||||
result<runtime_tensor> k210_runtime_function::allocate_output_tensor(size_t index) noexcept
|
||||
{
|
||||
return hrt::create(output_desc(index).datatype, output_shape(index), hrt::pool_shared);
|
||||
}
|
||||
|
||||
result<void> k210_runtime_function::validate_input_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
if (tensor.is_host()
|
||||
&& hrt::memory_pool(tensor).unwrap() == hrt::pool_shared)
|
||||
return ok();
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
result<void> k210_runtime_function::validate_output_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
if (tensor.is_host())
|
||||
return ok();
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
result<void> k210_runtime_function::invoke_core() noexcept
|
||||
{
|
||||
for (size_t i = 0; i < inputs_size(); i++)
|
||||
{
|
||||
try_var(input, device_input_tensor(i));
|
||||
try_(hrt::sync(input, hrt::sync_write_back));
|
||||
}
|
||||
|
||||
try_(visit(text_));
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<gsl::span<gsl::byte>> k210_runtime_function::memory_at(const memory_range &mrange) noexcept
|
||||
{
|
||||
#define ID_NOT_FOUND ((size_t)-1)
|
||||
gsl::byte *base;
|
||||
switch (mrange.memory_location)
|
||||
{
|
||||
case mem_input:
|
||||
{
|
||||
size_t id = ID_NOT_FOUND;
|
||||
for (size_t i = 0; i < inputs_size(); i++)
|
||||
{
|
||||
if (mrange.start == input_desc(i).start)
|
||||
{
|
||||
id = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (id != ID_NOT_FOUND)
|
||||
{
|
||||
try_var(tensor, device_input_tensor(id));
|
||||
base = reinterpret_cast<gsl::byte *>(static_cast<host_runtime_tensor_impl *>(tensor.impl())->memory_block().virtual_address - mrange.start);
|
||||
}
|
||||
else
|
||||
{
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case mem_output:
|
||||
{
|
||||
size_t id = ID_NOT_FOUND;
|
||||
for (size_t i = 0; i < outputs_size(); i++)
|
||||
{
|
||||
if (mrange.start == output_desc(i).start)
|
||||
{
|
||||
id = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (id != ID_NOT_FOUND)
|
||||
{
|
||||
try_var(tensor, device_output_tensor(id));
|
||||
try_var(tensor_map, hrt::map(tensor, hrt::map_read_write));
|
||||
base = tensor_map.buffer().data() - mrange.start;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case mem_rdata:
|
||||
base = const_cast<gsl::byte *>(module().rdata().data());
|
||||
break;
|
||||
case mem_data:
|
||||
base = module().data().data();
|
||||
break;
|
||||
case mem_kpu:
|
||||
base = module().kpu_ram().data();
|
||||
break;
|
||||
default:
|
||||
return err(nncase_errc::invalid_memory_location);
|
||||
}
|
||||
|
||||
return ok(gsl::make_span(base + mrange.start, mrange.size));
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "runtime_module.h"
|
||||
#include <nncase/runtime/k210/op_reader.h>
|
||||
#include <nncase/runtime/k210/runtime_types.h>
|
||||
#include <nncase/runtime/runtime_function.h>
|
||||
|
||||
BEGIN_NS_NNCASE_RT_MODULE(k210)
|
||||
|
||||
class k210_runtime_function : public runtime_function, private op_visitor
|
||||
{
|
||||
public:
|
||||
using runtime_function::runtime_function;
|
||||
|
||||
k210_runtime_module &module() const noexcept;
|
||||
|
||||
protected:
|
||||
result<void> initialize_core(runtime_function_init_context &context) noexcept override;
|
||||
result<runtime_tensor> allocate_input_tensor(size_t index) noexcept override;
|
||||
result<runtime_tensor> allocate_output_tensor(size_t index) noexcept override;
|
||||
result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept override;
|
||||
result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept override;
|
||||
result<void> invoke_core() noexcept override;
|
||||
|
||||
using op_visitor::visit;
|
||||
result<void> visit(const kpu_conv2d_options &op) noexcept override;
|
||||
result<void> visit(const kpu_download_options &op) noexcept override;
|
||||
result<void> visit(const kpu_upload_options &op) noexcept override;
|
||||
result<void> visit(const copy_options &op) noexcept override;
|
||||
|
||||
private:
|
||||
result<gsl::span<gsl::byte>> memory_at(const memory_range &mrange) noexcept;
|
||||
|
||||
private:
|
||||
gsl::span<const gsl::byte> text_;
|
||||
};
|
||||
|
||||
END_NS_NNCASE_RT_MODULE
|
|
@ -13,6 +13,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime_module.h"
|
||||
#include "runtime_function.h"
|
||||
#include <nncase/runtime/host_runtime_tensor.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/k210/error.h>
|
||||
|
@ -27,7 +28,7 @@ using namespace nncase::runtime;
|
|||
using namespace nncase::runtime::detail;
|
||||
using namespace nncase::runtime::k210;
|
||||
|
||||
result<void> k210_runtime_module::initialize_core(runtime_module_init_context &context) noexcept
|
||||
result<void> k210_runtime_module::initialize_before_functions(runtime_module_init_context &context) noexcept
|
||||
{
|
||||
#ifndef NNCASE_SIMULATOR
|
||||
kpu->interrupt_clear.reg = 7;
|
||||
|
@ -56,127 +57,33 @@ result<void> k210_runtime_module::initialize_core(runtime_module_init_context &c
|
|||
return ok();
|
||||
}
|
||||
|
||||
result<runtime_tensor> k210_runtime_module::allocate_input_tensor(size_t index) noexcept
|
||||
gsl::span<gsl::byte> k210_runtime_module::data() const noexcept
|
||||
{
|
||||
return hrt::create(input_desc(index).datatype, input_shape(index), hrt::pool_shared);
|
||||
return { data_.get(), mempool(mem_data).size };
|
||||
}
|
||||
|
||||
result<runtime_tensor> k210_runtime_module::allocate_output_tensor(size_t index) noexcept
|
||||
gsl::span<gsl::byte> k210_runtime_module::kpu_ram() noexcept
|
||||
{
|
||||
return hrt::create(output_desc(index).datatype, output_shape(index), hrt::pool_shared);
|
||||
}
|
||||
|
||||
result<void> k210_runtime_module::validate_input_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
if (tensor.is_host()
|
||||
&& hrt::memory_pool(tensor).unwrap() == hrt::pool_shared)
|
||||
return ok();
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
result<void> k210_runtime_module::validate_output_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
if (tensor.is_host())
|
||||
return ok();
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
result<void> k210_runtime_module::run_core() noexcept
|
||||
{
|
||||
for (size_t i = 0; i < inputs_size(); i++)
|
||||
{
|
||||
try_var(input, device_input_tensor(i));
|
||||
try_(hrt::sync(input, hrt::sync_write_back));
|
||||
}
|
||||
|
||||
#ifndef NNCASE_SIMULATOR
|
||||
auto dma_ch = interp().options().get<uint32_t>("dma_ch");
|
||||
if (dma_ch.is_ok())
|
||||
{
|
||||
dma_ch_ = dma_ch.unwrap();
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("[WARN] KPU DMA channel not set, default to DMAC_CHANNEL5.\n");
|
||||
dma_ch_ = 5;
|
||||
}
|
||||
#endif
|
||||
|
||||
try_(visit(text_));
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<gsl::span<gsl::byte>> k210_runtime_module::memory_at(const memory_range &mrange) noexcept
|
||||
{
|
||||
#define ID_NOT_FOUND ((size_t)-1)
|
||||
gsl::byte *base;
|
||||
switch (mrange.memory_location)
|
||||
{
|
||||
case mem_input:
|
||||
{
|
||||
size_t id = ID_NOT_FOUND;
|
||||
for (size_t i = 0; i < inputs_size(); i++)
|
||||
{
|
||||
if (mrange.start == input_desc(i).start)
|
||||
{
|
||||
id = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (id != ID_NOT_FOUND)
|
||||
{
|
||||
try_var(tensor, device_input_tensor(id));
|
||||
base = reinterpret_cast<gsl::byte *>(static_cast<host_runtime_tensor_impl *>(tensor.impl())->memory_block().virtual_address - mrange.start);
|
||||
}
|
||||
else
|
||||
{
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case mem_output:
|
||||
{
|
||||
size_t id = ID_NOT_FOUND;
|
||||
for (size_t i = 0; i < outputs_size(); i++)
|
||||
{
|
||||
if (mrange.start == output_desc(i).start)
|
||||
{
|
||||
id = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (id != ID_NOT_FOUND)
|
||||
{
|
||||
try_var(tensor, device_output_tensor(id));
|
||||
try_var(tensor_map, hrt::map(tensor, hrt::map_read_write));
|
||||
base = tensor_map.buffer().data() - mrange.start;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case mem_rdata:
|
||||
base = const_cast<gsl::byte *>(rdata_.data());
|
||||
break;
|
||||
case mem_data:
|
||||
base = data_.get();
|
||||
break;
|
||||
case mem_kpu:
|
||||
#ifdef NNCASE_SIMULATOR
|
||||
base = kpu_ram_.data();
|
||||
base = kpu_ram_.data();
|
||||
#else
|
||||
base = reinterpret_cast<gsl::byte *>(AI_IO_BASE_ADDR);
|
||||
base = reinterpret_cast<gsl::byte *>(AI_IO_BASE_ADDR);
|
||||
#endif
|
||||
break;
|
||||
default:
|
||||
return err(nncase_errc::invalid_memory_location);
|
||||
}
|
||||
return { base, KPU_RAM_SIZE };
|
||||
}
|
||||
|
||||
return ok(gsl::make_span(base + mrange.start, mrange.size));
|
||||
gsl::span<const gsl::byte> k210_runtime_module::rdata() const noexcept
|
||||
{
|
||||
return rdata_;
|
||||
}
|
||||
|
||||
result<std::unique_ptr<runtime_function>> k210_runtime_module::create_function() noexcept
|
||||
{
|
||||
std::unique_ptr<runtime_function> mod(new (std::nothrow) k210_runtime_function(*this));
|
||||
if (mod)
|
||||
return ok(std::move(mod));
|
||||
return err(std::errc::not_enough_memory);
|
||||
}
|
||||
|
||||
result<std::unique_ptr<runtime_module>> k210::create_k210_runtime_module()
|
||||
|
|
|
@ -13,31 +13,28 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <nncase/runtime/k210/op_reader.h>
|
||||
#include <nncase/runtime/k210/runtime_module.h>
|
||||
#include <nncase/runtime/k210/runtime_types.h>
|
||||
|
||||
BEGIN_NS_NNCASE_RT_K210
|
||||
BEGIN_NS_NNCASE_RT_MODULE(k210)
|
||||
|
||||
class k210_runtime_module : public runtime_module, private op_visitor
|
||||
class k210_runtime_module : public runtime_module
|
||||
{
|
||||
public:
|
||||
gsl::span<gsl::byte> data() const noexcept;
|
||||
gsl::span<const gsl::byte> rdata() const noexcept;
|
||||
gsl::span<gsl::byte> kpu_ram() noexcept;
|
||||
|
||||
#if !NNCASE_SIMULATOR
|
||||
uint32_t dma_ch() const noexcept
|
||||
{
|
||||
return dma_ch_;
|
||||
}
|
||||
#endif
|
||||
|
||||
protected:
|
||||
result<void> initialize_core(runtime_module_init_context &context) noexcept override;
|
||||
result<runtime_tensor> allocate_input_tensor(size_t index) noexcept override;
|
||||
result<runtime_tensor> allocate_output_tensor(size_t index) noexcept override;
|
||||
result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept override;
|
||||
result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept override;
|
||||
result<void> run_core() noexcept override;
|
||||
|
||||
using op_visitor::visit;
|
||||
result<void> visit(const kpu_conv2d_options &op) noexcept override;
|
||||
result<void> visit(const kpu_download_options &op) noexcept override;
|
||||
result<void> visit(const kpu_upload_options &op) noexcept override;
|
||||
result<void> visit(const copy_options &op) noexcept override;
|
||||
|
||||
private:
|
||||
result<gsl::span<gsl::byte>> memory_at(const memory_range &mrange) noexcept;
|
||||
result<void> initialize_before_functions(runtime_module_init_context &context) noexcept override;
|
||||
result<std::unique_ptr<runtime_function>> create_function() noexcept override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<gsl::byte[]> data_;
|
||||
|
@ -50,4 +47,4 @@ private:
|
|||
#endif
|
||||
};
|
||||
|
||||
END_NS_NNCASE_RT_K210
|
||||
END_NS_NNCASE_RT_MODULE
|
||||
|
|
|
@ -22,6 +22,7 @@ NNCASE_INLINE_VAR constexpr char SHADER_SECTION_NAME[] = ".shader";
|
|||
NNCASE_INLINE_VAR constexpr char DESCRIPTORS_SECTION_NAME[] = ".descriptors";
|
||||
|
||||
NNCASE_INLINE_VAR constexpr module_type_t vulkan_module_type = to_module_type("vulkan");
|
||||
NNCASE_INLINE_VAR constexpr uint32_t vulkan_module_version = 1;
|
||||
|
||||
NNCASE_MODULES_VULKAN_API result<std::unique_ptr<runtime_module>> create_vulkan_runtime_module();
|
||||
|
||||
|
|
|
@ -39,6 +39,11 @@ module_type_t vulkan_module_builder::module_type() const noexcept
|
|||
return vulkan_module_type;
|
||||
}
|
||||
|
||||
uint32_t vulkan_module_builder::module_version() const noexcept
|
||||
{
|
||||
return vulkan_module_version;
|
||||
}
|
||||
|
||||
section_writer &vulkan_module_builder::text_writer()
|
||||
{
|
||||
return writer(".text");
|
||||
|
@ -104,17 +109,27 @@ void vulkan_module_builder::ldpipeline(ir::node &node, size_t shader_index, ldpi
|
|||
tw.write(op);
|
||||
}
|
||||
|
||||
void vulkan_module_builder::begin_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
|
||||
{
|
||||
set_current_entry_point(text_writer().position());
|
||||
}
|
||||
|
||||
void vulkan_module_builder::end_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
|
||||
{
|
||||
set_current_function_text_end(text_writer().position());
|
||||
}
|
||||
|
||||
void vulkan_module_builder::emit(ir::node &node)
|
||||
{
|
||||
#define DEFINE_OP(op) \
|
||||
if (node.runtime_opcode() == ir::op::opcode()) \
|
||||
return emit(static_cast<ir::op &>(node));
|
||||
return emit(static_cast<op &>(node));
|
||||
#include "ops.def"
|
||||
#undef DEFINE_OP
|
||||
module_builder::emit(node);
|
||||
}
|
||||
|
||||
void vulkan_module_builder::end_emit()
|
||||
void vulkan_module_builder::end_emit_module()
|
||||
{
|
||||
auto &sw = writer(DESCRIPTORS_SECTION_NAME);
|
||||
sw.write(descriptor_sets_);
|
||||
|
|
|
@ -29,13 +29,16 @@ public:
|
|||
vulkan_module_builder(std::string_view module_name, const module_builder_params ¶ms);
|
||||
|
||||
module_type_t module_type() const noexcept override;
|
||||
uint32_t module_version() const noexcept override;
|
||||
|
||||
protected:
|
||||
section_writer &text_writer();
|
||||
section_writer &shader_writer();
|
||||
|
||||
void begin_emit_function(const schedule::function_schedule_result &function) override;
|
||||
void end_emit_function(const schedule::function_schedule_result &function) override;
|
||||
void emit(ir::node &node) override;
|
||||
void end_emit() override;
|
||||
void end_emit_module() override;
|
||||
|
||||
private:
|
||||
std::vector<uint32_t> compile_shader(ir::node &node, const std::string &template_name, const nlohmann::json &context);
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
cmake_minimum_required (VERSION 3.13)
|
||||
|
||||
set(SRCS runtime_module.cpp
|
||||
runtime_function.cpp
|
||||
vulkan_error.cpp
|
||||
op_reader.cpp
|
||||
ops/ldbuf.cpp
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include "../vulkan_error.h"
|
||||
#include <nncase/runtime/dbg.h>
|
||||
#include <nncase/runtime/error.h>
|
||||
|
@ -21,7 +21,7 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::vulkan;
|
||||
|
||||
result<void> vulkan_runtime_module::visit(const barrier_op_t &op) noexcept
|
||||
result<void> vulkan_runtime_function::visit(const barrier_op_t &op) noexcept
|
||||
{
|
||||
CHECK_WITH_ERR(op.memory_barriers == 0, std::errc::not_supported);
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include "../vulkan_error.h"
|
||||
#include <nncase/runtime/error.h>
|
||||
|
||||
|
@ -20,7 +20,7 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::vulkan;
|
||||
|
||||
result<void> vulkan_runtime_module::visit(const copybuf_op_t &op) noexcept
|
||||
result<void> vulkan_runtime_function::visit(const copybuf_op_t &op) noexcept
|
||||
{
|
||||
try_var(output, pop_buffer_ref());
|
||||
try_var(input, pop_buffer_ref());
|
||||
|
|
|
@ -12,13 +12,13 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::vulkan;
|
||||
|
||||
result<void> vulkan_runtime_module::visit(const dispatch_op_t &op) noexcept
|
||||
result<void> vulkan_runtime_function::visit(const dispatch_op_t &op) noexcept
|
||||
{
|
||||
cmd_buffer_.dispatch(op.x, op.y, op.z);
|
||||
return ok();
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include "../vulkan_error.h"
|
||||
#include <nncase/runtime/error.h>
|
||||
|
||||
|
@ -20,24 +20,24 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::vulkan;
|
||||
|
||||
result<void> vulkan_runtime_module::visit(const ldbuf_op_t &op) noexcept
|
||||
result<void> vulkan_runtime_function::visit(const ldbuf_op_t &op) noexcept
|
||||
{
|
||||
vk::Buffer *dev_buf;
|
||||
vk::Buffer dev_buf;
|
||||
switch (op.memory.memory_location)
|
||||
{
|
||||
case mem_input:
|
||||
dev_buf = &input_buffer_;
|
||||
dev_buf = input_buffer_;
|
||||
break;
|
||||
case mem_output:
|
||||
dev_buf = &output_buffer_;
|
||||
dev_buf = output_buffer_;
|
||||
break;
|
||||
case mem_data:
|
||||
dev_buf = &data_buffer_;
|
||||
dev_buf = module().data();
|
||||
break;
|
||||
default:
|
||||
return err(nncase_errc::invalid_memory_location);
|
||||
}
|
||||
|
||||
buffer_refs_.emplace_back(buffer_ref { *dev_buf, op.memory.start, op.memory.size });
|
||||
buffer_refs_.emplace_back(buffer_ref { dev_buf, op.memory.start, op.memory.size });
|
||||
return ok();
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include "../vulkan_error.h"
|
||||
#include <nncase/runtime/error.h>
|
||||
|
||||
|
@ -20,26 +20,26 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::vulkan;
|
||||
|
||||
result<void> vulkan_runtime_module::visit(const ldbufbarrier_op_t &op) noexcept
|
||||
result<void> vulkan_runtime_function::visit(const ldbufbarrier_op_t &op) noexcept
|
||||
{
|
||||
vk::Buffer *dev_buf;
|
||||
vk::Buffer dev_buf;
|
||||
switch (op.memory.memory_location)
|
||||
{
|
||||
case mem_input:
|
||||
dev_buf = &input_buffer_;
|
||||
dev_buf = input_buffer_;
|
||||
break;
|
||||
case mem_output:
|
||||
dev_buf = &output_buffer_;
|
||||
dev_buf = output_buffer_;
|
||||
break;
|
||||
case mem_data:
|
||||
dev_buf = &data_buffer_;
|
||||
dev_buf = module().data();
|
||||
break;
|
||||
default:
|
||||
return err(nncase_errc::invalid_memory_location);
|
||||
}
|
||||
|
||||
buffer_barriers_.emplace_back(vk::BufferMemoryBarrier((vk::AccessFlagBits)op.src_access_mask,
|
||||
(vk::AccessFlagBits)op.dest_access_mask, compute_queue_index_, compute_queue_index_, *dev_buf,
|
||||
(vk::DeviceSize)op.memory.start, (vk::DeviceSize)op.memory.size));
|
||||
(vk::AccessFlagBits)op.dest_access_mask, module().compute_queue_index(), module().compute_queue_index(),
|
||||
dev_buf, (vk::DeviceSize)op.memory.start, (vk::DeviceSize)op.memory.size));
|
||||
return ok();
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include "../vulkan_error.h"
|
||||
#include <nncase/runtime/error.h>
|
||||
|
||||
|
@ -20,7 +20,7 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::vulkan;
|
||||
|
||||
result<void> vulkan_runtime_module::visit(const ldbufcopy_op_t &op) noexcept
|
||||
result<void> vulkan_runtime_function::visit(const ldbufcopy_op_t &op) noexcept
|
||||
{
|
||||
buffer_copies_.emplace_back(vk::BufferCopy((vk::DeviceSize)op.src, (vk::DeviceSize)op.dest, (vk::DeviceSize)op.size));
|
||||
return ok();
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include "../vulkan_error.h"
|
||||
#include <vulkan/vulkan.h>
|
||||
|
||||
|
@ -20,11 +20,11 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::vulkan;
|
||||
|
||||
result<void> vulkan_runtime_module::visit(const ldpipeline_op_t &op) noexcept
|
||||
result<void> vulkan_runtime_function::visit(const ldpipeline_op_t &op) noexcept
|
||||
{
|
||||
auto code = shader_.subspan(op.shader_start, op.shader_size).as_span<const uint32_t>();
|
||||
auto code = module().shader().subspan(op.shader_start, op.shader_size).as_span<const uint32_t>();
|
||||
vk::ShaderModuleCreateInfo shader_cinfo({}, op.shader_size, code.data());
|
||||
try_var(shader, vk::to_result(device_.createShaderModule(shader_cinfo)));
|
||||
try_var(shader, vk::to_result(module().device().createShaderModule(shader_cinfo)));
|
||||
|
||||
std::vector<vk::DescriptorSetLayoutBinding> layout_bindings((size_t)op.buffers);
|
||||
for (int32_t i = (int32_t)op.buffers - 1; i >= 0; i--)
|
||||
|
@ -37,20 +37,20 @@ result<void> vulkan_runtime_module::visit(const ldpipeline_op_t &op) noexcept
|
|||
}
|
||||
|
||||
vk::DescriptorSetLayoutCreateInfo desc_layout_cinfo({}, layout_bindings);
|
||||
try_var(desc_layout, vk::to_result(device_.createDescriptorSetLayout(desc_layout_cinfo)));
|
||||
try_var(desc_layout, vk::to_result(module().device().createDescriptorSetLayout(desc_layout_cinfo)));
|
||||
|
||||
vk::PipelineLayoutCreateInfo ppl_layout_cinfo({}, desc_layout);
|
||||
try_var(ppl_layout, vk::to_result(device_.createPipelineLayout(ppl_layout_cinfo)));
|
||||
try_var(ppl_layout, vk::to_result(module().device().createPipelineLayout(ppl_layout_cinfo)));
|
||||
|
||||
if (op.shader_type != shader_type_t::compute)
|
||||
return err(std::errc::not_supported);
|
||||
|
||||
vk::ComputePipelineCreateInfo comp_ppl_cinfo({}, { {}, vk::ShaderStageFlagBits::eCompute, shader, "main" }, ppl_layout);
|
||||
try_var(pipeline, vk::to_result(device_.createComputePipeline({}, comp_ppl_cinfo)));
|
||||
pipelines_owner_.emplace_back(pipeline);
|
||||
try_var(pipeline, vk::to_result(module().device().createComputePipeline({}, comp_ppl_cinfo)));
|
||||
try_(module().add_pipeline(pipeline));
|
||||
|
||||
vk::DescriptorSetAllocateInfo desc_alloc_info(buffer_desc_pool_, desc_layout);
|
||||
try_var(desc_sets, vk::to_result(device_.allocateDescriptorSets(desc_alloc_info)));
|
||||
vk::DescriptorSetAllocateInfo desc_alloc_info(module().buffer_desc_pool(), desc_layout);
|
||||
try_var(desc_sets, vk::to_result(module().device().allocateDescriptorSets(desc_alloc_info)));
|
||||
|
||||
std::vector<vk::DescriptorBufferInfo> buffer_infos((size_t)op.buffers);
|
||||
std::vector<vk::WriteDescriptorSet> write_descs(buffer_infos.size());
|
||||
|
@ -72,7 +72,7 @@ result<void> vulkan_runtime_module::visit(const ldpipeline_op_t &op) noexcept
|
|||
write_desc.setDstSet(desc_sets[0]);
|
||||
}
|
||||
|
||||
device_.updateDescriptorSets(write_descs, {});
|
||||
module().device().updateDescriptorSets(write_descs, {});
|
||||
|
||||
cmd_buffer_.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline);
|
||||
cmd_buffer_.bindDescriptorSets(vk::PipelineBindPoint::eCompute, ppl_layout, 0, desc_sets, {});
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime_function.h"
|
||||
#include "vulkan_error.h"
|
||||
#include <nncase/runtime/dbg.h>
|
||||
#include <nncase/runtime/host_runtime_tensor.h>
|
||||
#include <nncase/runtime/runtime_loader.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::vulkan;
|
||||
|
||||
vulkan_runtime_function::~vulkan_runtime_function()
|
||||
{
|
||||
free_vulkan_resources();
|
||||
}
|
||||
|
||||
vulkan_runtime_module &vulkan_runtime_function::module() const noexcept
|
||||
{
|
||||
return static_cast<vulkan_runtime_module &>(runtime_function::module());
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::initialize_core(runtime_function_init_context &context) noexcept
|
||||
{
|
||||
input_pool_size_ = context.header().input_pool_size;
|
||||
output_pool_size_ = context.header().output_pool_size;
|
||||
text_ = context.module_init_context().section(".text").subspan(context.header().entrypoint, context.header().text_size);
|
||||
|
||||
try_(initialize_vulkan_device());
|
||||
try_(initialize_vulkan_memory());
|
||||
try_(initialize_vulkan_commands());
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<runtime_tensor> vulkan_runtime_function::allocate_input_tensor(size_t index) noexcept
|
||||
{
|
||||
return host_runtime_tensor::create(input_desc(index).datatype, input_shape(index));
|
||||
}
|
||||
|
||||
result<runtime_tensor> vulkan_runtime_function::allocate_output_tensor(size_t index) noexcept
|
||||
{
|
||||
return host_runtime_tensor::create(output_desc(index).datatype, output_shape(index));
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::validate_input_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
if (tensor.is_host() && tensor.is_contiguous())
|
||||
return ok();
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::validate_output_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
if (tensor.is_host() && tensor.is_contiguous())
|
||||
return ok();
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::initialize_vulkan_device() noexcept
|
||||
{
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::initialize_vulkan_memory() noexcept
|
||||
{
|
||||
if (input_pool_size_)
|
||||
{
|
||||
try_set(input_buffer_, module().allocate_vulkan_buffer(input_pool_size_));
|
||||
try_set(input_mem_, module().allocate_vulkan_memory({ vk::MemoryPropertyFlagBits::eHostVisible, vk::MemoryPropertyFlagBits::eHostCached, {} }, input_buffer_));
|
||||
try_(module().bind_vulkan_buffer(input_buffer_, input_mem_));
|
||||
}
|
||||
|
||||
if (output_pool_size_)
|
||||
{
|
||||
try_set(output_buffer_, module().allocate_vulkan_buffer(output_pool_size_));
|
||||
try_set(output_mem_, module().allocate_vulkan_memory({ vk::MemoryPropertyFlagBits::eHostVisible, vk::MemoryPropertyFlagBits::eHostCached, {} }, output_buffer_));
|
||||
try_(module().bind_vulkan_buffer(output_buffer_, output_mem_));
|
||||
}
|
||||
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::initialize_vulkan_commands() noexcept
|
||||
{
|
||||
vk::CommandBufferAllocateInfo cmdb_cinfo(module().command_pool(), vk::CommandBufferLevel::ePrimary, 1);
|
||||
try_var(cmdbs, vk::to_result(module().device().allocateCommandBuffers(cmdb_cinfo)));
|
||||
cmd_buffer_ = cmdbs[0];
|
||||
|
||||
vk::CommandBufferBeginInfo cmdb_info;
|
||||
try_(vk::to_result(cmd_buffer_.begin(cmdb_info)));
|
||||
try_(visit(text_));
|
||||
try_(vk::to_result(cmd_buffer_.end()));
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::preprocess_inputs() noexcept
|
||||
{
|
||||
try_var(dest, vk::to_result(module().device().mapMemory(input_mem_, 0, VK_WHOLE_SIZE, {})));
|
||||
|
||||
for (size_t i = 0; i < inputs_size(); i++)
|
||||
{
|
||||
try_var(src_tensor, device_input_tensor(i));
|
||||
try_var(src_map, hrt::map(src_tensor, hrt::map_read));
|
||||
auto &desc = input_desc(i);
|
||||
memcpy((uint8_t *)dest + desc.start, src_map.buffer().data(), desc.size);
|
||||
}
|
||||
|
||||
vk::MappedMemoryRange range(input_mem_, 0, VK_WHOLE_SIZE);
|
||||
try_(vk::to_result(module().device().flushMappedMemoryRanges(range)));
|
||||
module().device().unmapMemory(input_mem_);
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::invoke_core() noexcept
|
||||
{
|
||||
try_(preprocess_inputs());
|
||||
|
||||
vk::SubmitInfo si({}, {}, cmd_buffer_, {});
|
||||
try_(vk::to_result(module().compute_queue().submit(si)));
|
||||
try_(vk::to_result(module().compute_queue().waitIdle()));
|
||||
|
||||
try_(postprocess_outputs());
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_function::postprocess_outputs() noexcept
|
||||
{
|
||||
try_var(src, vk::to_result(module().device().mapMemory(output_mem_, 0, VK_WHOLE_SIZE, {})));
|
||||
vk::MappedMemoryRange range(output_mem_, 0, VK_WHOLE_SIZE);
|
||||
try_(vk::to_result(module().device().invalidateMappedMemoryRanges(range)));
|
||||
|
||||
for (size_t i = 0; i < outputs_size(); i++)
|
||||
{
|
||||
try_var(dest_tensor, device_output_tensor(i));
|
||||
try_var(dest_map, hrt::map(dest_tensor, hrt::map_write));
|
||||
auto &desc = output_desc(i);
|
||||
memcpy(dest_map.buffer().data(), (const uint8_t *)src + desc.start, desc.size);
|
||||
}
|
||||
|
||||
module().device().unmapMemory(output_mem_);
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<vulkan_runtime_function::buffer_ref> vulkan_runtime_function::pop_buffer_ref() noexcept
|
||||
{
|
||||
if (buffer_refs_.empty())
|
||||
return err(std::errc::result_out_of_range);
|
||||
auto buffer_ref = std::move(buffer_refs_.back());
|
||||
buffer_refs_.pop_back();
|
||||
return ok(std::move(buffer_ref));
|
||||
}
|
||||
|
||||
void vulkan_runtime_function::free_vulkan_resources() noexcept
|
||||
{
|
||||
if (auto device = module().device())
|
||||
{
|
||||
if (module().command_pool())
|
||||
device.freeCommandBuffers(module().command_pool(), cmd_buffer_);
|
||||
device.destroyBuffer(input_buffer_);
|
||||
device.destroyBuffer(output_buffer_);
|
||||
device.freeMemory(input_mem_);
|
||||
device.freeMemory(output_mem_);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "runtime_module.h"
|
||||
#include <nncase/kernels/kernel_context.h>
|
||||
#include <nncase/runtime/runtime_function.h>
|
||||
#include <nncase/runtime/vulkan/op_reader.h>
|
||||
#include <vulkan/vulkan.hpp>
|
||||
|
||||
BEGIN_NS_NNCASE_RT_MODULE(vulkan)
|
||||
|
||||
class vulkan_runtime_function : public runtime_function, private op_visitor
|
||||
{
|
||||
struct buffer_ref
|
||||
{
|
||||
vk::Buffer buffer;
|
||||
size_t start;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
public:
|
||||
using runtime_function::runtime_function;
|
||||
virtual ~vulkan_runtime_function();
|
||||
|
||||
vulkan_runtime_module &module() const noexcept;
|
||||
|
||||
protected:
|
||||
result<void> initialize_core(runtime_function_init_context &context) noexcept override;
|
||||
result<runtime_tensor> allocate_input_tensor(size_t index) noexcept override;
|
||||
result<runtime_tensor> allocate_output_tensor(size_t index) noexcept override;
|
||||
result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept override;
|
||||
result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept override;
|
||||
result<void> invoke_core() noexcept override;
|
||||
|
||||
using op_visitor::visit;
|
||||
result<void> visit(const ldbuf_op_t &op) noexcept override;
|
||||
result<void> visit(const ldbufbarrier_op_t &op) noexcept override;
|
||||
result<void> visit(const ldbufcopy_op_t &op) noexcept override;
|
||||
result<void> visit(const copybuf_op_t &op) noexcept override;
|
||||
result<void> visit(const ldpipeline_op_t &op) noexcept override;
|
||||
result<void> visit(const dispatch_op_t &op) noexcept override;
|
||||
result<void> visit(const barrier_op_t &op) noexcept override;
|
||||
|
||||
private:
|
||||
result<void> initialize_vulkan_device() noexcept;
|
||||
result<void> initialize_vulkan_memory() noexcept;
|
||||
result<void> initialize_vulkan_commands() noexcept;
|
||||
|
||||
result<buffer_ref> pop_buffer_ref() noexcept;
|
||||
result<void> preprocess_inputs() noexcept;
|
||||
result<void> postprocess_outputs() noexcept;
|
||||
|
||||
void free_vulkan_resources() noexcept;
|
||||
|
||||
private:
|
||||
uint32_t input_pool_size_;
|
||||
uint32_t output_pool_size_;
|
||||
gsl::span<const gsl::byte> text_;
|
||||
vk::Buffer input_buffer_;
|
||||
vk::Buffer output_buffer_;
|
||||
vk::DeviceMemory input_mem_;
|
||||
vk::DeviceMemory output_mem_;
|
||||
std::vector<buffer_ref> buffer_refs_;
|
||||
vk::CommandBuffer cmd_buffer_;
|
||||
std::vector<vk::BufferMemoryBarrier> buffer_barriers_;
|
||||
std::vector<vk::BufferCopy> buffer_copies_;
|
||||
};
|
||||
|
||||
END_NS_NNCASE_RT_MODULE
|
|
@ -13,6 +13,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime_module.h"
|
||||
#include "runtime_function.h"
|
||||
#include "vulkan_error.h"
|
||||
#include <nncase/runtime/dbg.h>
|
||||
#include <nncase/runtime/host_runtime_tensor.h>
|
||||
|
@ -28,59 +29,26 @@ vulkan_runtime_module::~vulkan_runtime_module()
|
|||
free_vulkan_resources();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::initialize_core(runtime_module_init_context &context) noexcept
|
||||
result<void> vulkan_runtime_module::initialize_before_functions(runtime_module_init_context &context) noexcept
|
||||
{
|
||||
assert(context.is_section_pinned());
|
||||
auto descs = context.section(DESCRIPTORS_SECTION_NAME).as_span<const uint32_t>();
|
||||
descriptor_sets_ = descs[0];
|
||||
descriptors_ = descs[1];
|
||||
rdata_ = context.section(".rdata");
|
||||
text_ = context.section(".text");
|
||||
shader_ = context.section(".shader");
|
||||
|
||||
try_(initialize_vulkan());
|
||||
|
||||
// TODO: Load rdata
|
||||
//rdata_ = context.section(".rdata");
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<runtime_tensor> vulkan_runtime_module::allocate_input_tensor(size_t index) noexcept
|
||||
{
|
||||
return host_runtime_tensor::create(input_desc(index).datatype, input_shape(index));
|
||||
}
|
||||
|
||||
result<runtime_tensor> vulkan_runtime_module::allocate_output_tensor(size_t index) noexcept
|
||||
{
|
||||
return host_runtime_tensor::create(output_desc(index).datatype, output_shape(index));
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::validate_input_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
if (tensor.is_host() && tensor.is_contiguous())
|
||||
return ok();
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::validate_output_tensor(NNCASE_UNUSED size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
if (tensor.is_host() && tensor.is_contiguous())
|
||||
return ok();
|
||||
return err(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::initialize_vulkan() noexcept
|
||||
{
|
||||
try_(initialize_vulkan_instance());
|
||||
try_(initialize_vulkan_device());
|
||||
try_(initialize_vulkan_memory());
|
||||
try_(initialize_vulkan_commands());
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::initialize_vulkan_commands() noexcept
|
||||
{
|
||||
vk::CommandBufferBeginInfo cmdb_info;
|
||||
try_(vk::to_result(cmd_buffer_.begin(cmdb_info)));
|
||||
try_(visit(text_));
|
||||
try_(vk::to_result(cmd_buffer_.end()));
|
||||
return ok();
|
||||
}
|
||||
|
||||
|
@ -110,30 +78,11 @@ result<void> vulkan_runtime_module::initialize_vulkan_device() noexcept
|
|||
|
||||
vk::CommandPoolCreateInfo cmdp_cinfo({}, compute_queue_index_);
|
||||
try_set(cmd_pool_, vk::to_result(device_.createCommandPool(cmdp_cinfo)));
|
||||
vk::CommandBufferAllocateInfo cmdb_cinfo(cmd_pool_, vk::CommandBufferLevel::ePrimary, 1);
|
||||
try_var(cmdbs, vk::to_result(device_.allocateCommandBuffers(cmdb_cinfo)));
|
||||
cmd_buffer_ = cmdbs[0];
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::initialize_vulkan_memory() noexcept
|
||||
{
|
||||
auto input_mem = mempool(mem_input);
|
||||
if (input_mem.size)
|
||||
{
|
||||
try_set(input_buffer_, allocate_vulkan_buffer(input_mem.size));
|
||||
try_set(input_mem_, allocate_vulkan_memory({ vk::MemoryPropertyFlagBits::eHostVisible, vk::MemoryPropertyFlagBits::eHostCached, {} }, input_buffer_));
|
||||
try_(bind_vulkan_buffer(input_buffer_, input_mem_));
|
||||
}
|
||||
|
||||
auto output_mem = mempool(mem_output);
|
||||
if (output_mem.size)
|
||||
{
|
||||
try_set(output_buffer_, allocate_vulkan_buffer(output_mem.size));
|
||||
try_set(output_mem_, allocate_vulkan_memory({ vk::MemoryPropertyFlagBits::eHostVisible, vk::MemoryPropertyFlagBits::eHostCached, {} }, output_buffer_));
|
||||
try_(bind_vulkan_buffer(output_buffer_, output_mem_));
|
||||
}
|
||||
|
||||
auto data_mem = mempool(mem_data);
|
||||
if (data_mem.size)
|
||||
{
|
||||
|
@ -165,6 +114,19 @@ result<void> vulkan_runtime_module::bind_vulkan_buffer(vk::Buffer buffer, vk::De
|
|||
return vk::to_result(device_.bindBufferMemory(buffer, memory, 0));
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::add_pipeline(vk::Pipeline pipeline) noexcept
|
||||
{
|
||||
try
|
||||
{
|
||||
pipelines_owner_.emplace_back(pipeline);
|
||||
return ok();
|
||||
}
|
||||
catch (const std::bad_alloc &)
|
||||
{
|
||||
return err(std::errc::not_enough_memory);
|
||||
}
|
||||
}
|
||||
|
||||
result<vk::PhysicalDevice> vulkan_runtime_module::select_physical_device() noexcept
|
||||
{
|
||||
vk::PhysicalDevice *intergrated = nullptr;
|
||||
|
@ -261,80 +223,29 @@ result<size_t> vulkan_runtime_module::select_memory_type(const vk::PhysicalDevic
|
|||
return err(std::errc::not_enough_memory);
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::run_core() noexcept
|
||||
{
|
||||
try_(preprocess_inputs());
|
||||
|
||||
vk::SubmitInfo si({}, {}, cmd_buffer_, {});
|
||||
try_(vk::to_result(compute_queue_.submit(si)));
|
||||
try_(vk::to_result(compute_queue_.waitIdle()));
|
||||
|
||||
try_(postprocess_outputs());
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::preprocess_inputs() noexcept
|
||||
{
|
||||
try_var(dest, vk::to_result(device_.mapMemory(input_mem_, 0, VK_WHOLE_SIZE, {})));
|
||||
|
||||
for (size_t i = 0; i < inputs_size(); i++)
|
||||
{
|
||||
try_var(src_tensor, device_input_tensor(i));
|
||||
try_var(src_map, hrt::map(src_tensor, hrt::map_read));
|
||||
auto &desc = input_desc(i);
|
||||
memcpy((uint8_t *)dest + desc.start, src_map.buffer().data(), desc.size);
|
||||
}
|
||||
|
||||
vk::MappedMemoryRange range(input_mem_, 0, VK_WHOLE_SIZE);
|
||||
try_(vk::to_result(device_.flushMappedMemoryRanges(range)));
|
||||
device_.unmapMemory(input_mem_);
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> vulkan_runtime_module::postprocess_outputs() noexcept
|
||||
{
|
||||
try_var(src, vk::to_result(device_.mapMemory(output_mem_, 0, VK_WHOLE_SIZE, {})));
|
||||
vk::MappedMemoryRange range(output_mem_, 0, VK_WHOLE_SIZE);
|
||||
try_(vk::to_result(device_.invalidateMappedMemoryRanges(range)));
|
||||
|
||||
for (size_t i = 0; i < outputs_size(); i++)
|
||||
{
|
||||
try_var(dest_tensor, device_output_tensor(i));
|
||||
try_var(dest_map, hrt::map(dest_tensor, hrt::map_write));
|
||||
auto &desc = output_desc(i);
|
||||
memcpy(dest_map.buffer().data(), (const uint8_t *)src + desc.start, desc.size);
|
||||
}
|
||||
|
||||
device_.unmapMemory(output_mem_);
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<vulkan_runtime_module::buffer_ref> vulkan_runtime_module::pop_buffer_ref() noexcept
|
||||
{
|
||||
if (buffer_refs_.empty())
|
||||
return err(std::errc::result_out_of_range);
|
||||
auto buffer_ref = std::move(buffer_refs_.back());
|
||||
buffer_refs_.pop_back();
|
||||
return ok(std::move(buffer_ref));
|
||||
}
|
||||
|
||||
void vulkan_runtime_module::free_vulkan_resources() noexcept
|
||||
{
|
||||
device_.freeCommandBuffers(cmd_pool_, cmd_buffer_);
|
||||
for (auto &func : functions())
|
||||
func.reset();
|
||||
|
||||
device_.destroyCommandPool(cmd_pool_);
|
||||
device_.destroyDescriptorPool(buffer_desc_pool_);
|
||||
for (auto p : pipelines_owner_)
|
||||
device_.destroyPipeline(p);
|
||||
device_.destroyBuffer(input_buffer_);
|
||||
device_.destroyBuffer(output_buffer_);
|
||||
device_.destroyBuffer(data_buffer_);
|
||||
device_.freeMemory(input_mem_);
|
||||
device_.freeMemory(output_mem_);
|
||||
device_.freeMemory(data_mem_);
|
||||
device_.destroy({});
|
||||
instance_.destroy({});
|
||||
}
|
||||
|
||||
result<std::unique_ptr<runtime_function>> vulkan_runtime_module::create_function() noexcept
|
||||
{
|
||||
std::unique_ptr<runtime_function> mod(new (std::nothrow) vulkan_runtime_function(*this));
|
||||
if (mod)
|
||||
return ok(std::move(mod));
|
||||
return err(std::errc::not_enough_memory);
|
||||
}
|
||||
|
||||
result<std::unique_ptr<runtime_module>> vulkan::create_vulkan_runtime_module()
|
||||
{
|
||||
std::unique_ptr<runtime_module> mod(new (std::nothrow) vulkan_runtime_module());
|
||||
|
|
|
@ -20,67 +20,54 @@
|
|||
|
||||
BEGIN_NS_NNCASE_RT_MODULE(vulkan)
|
||||
|
||||
class vulkan_runtime_module : public runtime_module, private op_visitor
|
||||
template <class T>
|
||||
struct select_options
|
||||
{
|
||||
template <class T>
|
||||
struct select_options
|
||||
{
|
||||
T requried;
|
||||
T preferred;
|
||||
T not_preferred;
|
||||
};
|
||||
T requried;
|
||||
T preferred;
|
||||
T not_preferred;
|
||||
};
|
||||
|
||||
struct buffer_ref
|
||||
{
|
||||
vk::Buffer buffer;
|
||||
size_t start;
|
||||
size_t size;
|
||||
};
|
||||
class vulkan_runtime_module : public runtime_module
|
||||
{
|
||||
|
||||
public:
|
||||
virtual ~vulkan_runtime_module();
|
||||
|
||||
protected:
|
||||
result<void> initialize_core(runtime_module_init_context &context) noexcept override;
|
||||
result<runtime_tensor> allocate_input_tensor(size_t index) noexcept override;
|
||||
result<runtime_tensor> allocate_output_tensor(size_t index) noexcept override;
|
||||
result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept override;
|
||||
result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept override;
|
||||
result<void> run_core() noexcept override;
|
||||
vk::Buffer data() const noexcept { return data_buffer_; }
|
||||
vk::Buffer rdata() const noexcept { return {}; }
|
||||
gsl::span<const gsl::byte> shader() const noexcept { return shader_; }
|
||||
|
||||
using op_visitor::visit;
|
||||
result<void> visit(const ldbuf_op_t &op) noexcept override;
|
||||
result<void> visit(const ldbufbarrier_op_t &op) noexcept override;
|
||||
result<void> visit(const ldbufcopy_op_t &op) noexcept override;
|
||||
result<void> visit(const copybuf_op_t &op) noexcept override;
|
||||
result<void> visit(const ldpipeline_op_t &op) noexcept override;
|
||||
result<void> visit(const dispatch_op_t &op) noexcept override;
|
||||
result<void> visit(const barrier_op_t &op) noexcept override;
|
||||
vk::Device device() const noexcept { return device_; }
|
||||
vk::CommandPool command_pool() const noexcept { return cmd_pool_; }
|
||||
uint32_t compute_queue_index() const noexcept { return compute_queue_index_; }
|
||||
vk::Queue compute_queue() const noexcept { return compute_queue_; }
|
||||
vk::DescriptorPool buffer_desc_pool() const noexcept { return buffer_desc_pool_; }
|
||||
|
||||
result<vk::DeviceMemory> allocate_vulkan_memory(const select_options<vk::MemoryPropertyFlagBits> &options, vk::Buffer buffer) noexcept;
|
||||
result<vk::Buffer> allocate_vulkan_buffer(size_t required_size) noexcept;
|
||||
result<void> bind_vulkan_buffer(vk::Buffer buffer, vk::DeviceMemory memory) noexcept;
|
||||
result<void> add_pipeline(vk::Pipeline pipeline) noexcept;
|
||||
|
||||
protected:
|
||||
result<void> initialize_before_functions(runtime_module_init_context &context) noexcept override;
|
||||
result<std::unique_ptr<runtime_function>> create_function() noexcept override;
|
||||
|
||||
private:
|
||||
result<void> initialize_vulkan() noexcept;
|
||||
result<void> initialize_vulkan_instance() noexcept;
|
||||
result<void> initialize_vulkan_device() noexcept;
|
||||
result<void> initialize_vulkan_memory() noexcept;
|
||||
result<void> initialize_vulkan_commands() noexcept;
|
||||
|
||||
result<vk::PhysicalDevice> select_physical_device() noexcept;
|
||||
result<uint32_t> select_queue_family(const std::vector<vk::QueueFamilyProperties> &families, const select_options<vk::QueueFlagBits> options) noexcept;
|
||||
result<size_t> select_memory_type(const vk::PhysicalDeviceMemoryProperties &properties, const select_options<vk::MemoryPropertyFlagBits> &options, size_t required_size) noexcept;
|
||||
result<vk::DeviceMemory> allocate_vulkan_memory(const select_options<vk::MemoryPropertyFlagBits> &options, vk::Buffer buffer) noexcept;
|
||||
result<vk::Buffer> allocate_vulkan_buffer(size_t required_size) noexcept;
|
||||
result<void> bind_vulkan_buffer(vk::Buffer buffer, vk::DeviceMemory memory) noexcept;
|
||||
|
||||
result<buffer_ref> pop_buffer_ref() noexcept;
|
||||
result<void> preprocess_inputs() noexcept;
|
||||
result<void> postprocess_outputs() noexcept;
|
||||
|
||||
void free_vulkan_resources() noexcept;
|
||||
|
||||
private:
|
||||
uint32_t descriptors_;
|
||||
uint32_t descriptor_sets_;
|
||||
gsl::span<const gsl::byte> rdata_;
|
||||
gsl::span<const gsl::byte> text_;
|
||||
gsl::span<const gsl::byte> shader_;
|
||||
vk::Instance instance_;
|
||||
|
@ -88,19 +75,11 @@ private:
|
|||
vk::Device device_;
|
||||
uint32_t compute_queue_index_;
|
||||
vk::Queue compute_queue_;
|
||||
vk::Buffer input_buffer_;
|
||||
vk::Buffer output_buffer_;
|
||||
vk::Buffer data_buffer_;
|
||||
vk::DeviceMemory input_mem_;
|
||||
vk::DeviceMemory output_mem_;
|
||||
vk::DeviceMemory data_mem_;
|
||||
std::vector<buffer_ref> buffer_refs_;
|
||||
std::vector<vk::Pipeline> pipelines_owner_;
|
||||
vk::DescriptorPool buffer_desc_pool_;
|
||||
vk::CommandPool cmd_pool_;
|
||||
vk::CommandBuffer cmd_buffer_;
|
||||
std::vector<vk::BufferMemoryBarrier> buffer_barriers_;
|
||||
std::vector<vk::BufferCopy> buffer_copies_;
|
||||
};
|
||||
|
||||
END_NS_NNCASE_RT_MODULE
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <nncase/ir/graph.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
#include <nncase/schedule/scheduler.h>
|
||||
#include <nncase/version.h>
|
||||
#include <pybind11/iostream.h>
|
||||
#include <pybind11/numpy.h>
|
||||
|
@ -74,7 +75,7 @@ void LaunchDebugger()
|
|||
}
|
||||
#endif
|
||||
|
||||
schedule::schedule_result schedule(target &target, ir::graph &graph)
|
||||
schedule::model_schedule_result schedule(target &target, ir::graph &graph)
|
||||
{
|
||||
schedule::scheduler sched(target, graph, graph.outputs());
|
||||
return sched.schedule(true);
|
||||
|
@ -115,7 +116,7 @@ public:
|
|||
|
||||
private:
|
||||
ir::graph &graph_;
|
||||
schedule::schedule_result schedule_result_;
|
||||
schedule::model_schedule_result schedule_result_;
|
||||
ir::evaluator evaluator_;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <nncase/codegen/model_builder.h>
|
||||
#include <nncase/ir/op_utils.h>
|
||||
#include <nncase/runtime/model.h>
|
||||
#include <nncase/targets/target.h>
|
||||
|
||||
|
@ -22,7 +23,7 @@ using namespace nncase::ir;
|
|||
using namespace nncase::schedule;
|
||||
using namespace nncase::runtime;
|
||||
|
||||
model_builder::model_builder(target &target, const schedule::schedule_result &sched)
|
||||
model_builder::model_builder(target &target, const schedule::model_schedule_result &sched)
|
||||
: target_(target), sched_(sched), dump_asm_(false)
|
||||
{
|
||||
}
|
||||
|
@ -41,6 +42,7 @@ build_model_result model_builder::build(std::ostream &output)
|
|||
model_header header {};
|
||||
header.identifier = MODEL_IDENTIFIER;
|
||||
header.version = MODEL_VERSION;
|
||||
header.header_size = sizeof(header);
|
||||
header.flags = 0;
|
||||
header.alignment = 8;
|
||||
header.modules = (uint32_t)sched_.modules.size();
|
||||
|
@ -49,19 +51,27 @@ build_model_result model_builder::build(std::ostream &output)
|
|||
auto header_pos = writer.position();
|
||||
writer.skip(sizeof(header));
|
||||
|
||||
uint32_t main_module_id = 0;
|
||||
for (auto &graph : sched_.graph_orders)
|
||||
for (auto &mod_sched : sched_.modules)
|
||||
{
|
||||
auto &mod = sched_.modules.at(graph);
|
||||
module_builder_params params { sched_, mod };
|
||||
auto builder = target_.create_module_builder(graph->module_type(), graph->name(), params);
|
||||
builder->config_dump(dump_dir_ / graph->escaped_name(), dump_asm_);
|
||||
module_builder_params params { sched_, mod_sched };
|
||||
auto builder = target_.create_module_builder(mod_sched.type, mod_sched.type.data(), params);
|
||||
builder->config_dump(dump_dir_ / mod_sched.type.data(), dump_asm_);
|
||||
builder->build(writer);
|
||||
header.alignment = std::max(header.alignment, builder->alignment());
|
||||
}
|
||||
|
||||
if (graph == sched_.main_module)
|
||||
header.main_module = main_module_id;
|
||||
main_module_id++;
|
||||
// Entry point
|
||||
for (size_t i = 0; i < sched_.modules.size(); i++)
|
||||
{
|
||||
auto &mod_sched = sched_.modules[i];
|
||||
for (size_t j = 0; j < mod_sched.functions.size(); j++)
|
||||
{
|
||||
if (sched_.entry_function == &mod_sched.functions[i])
|
||||
{
|
||||
header.entry_module = (uint32_t)i;
|
||||
header.entry_function = (uint32_t)j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto end_pos = writer.position();
|
||||
|
@ -78,11 +88,39 @@ build_model_result model_builder::build(std::ostream &output)
|
|||
size_t model_builder::max_usage(memory_location_t location) const
|
||||
{
|
||||
size_t usage = 0;
|
||||
for (auto &mod : sched_.modules)
|
||||
|
||||
if (location == mem_input)
|
||||
{
|
||||
auto it = mod.second.max_usages.find(location);
|
||||
if (it != mod.second.max_usages.end())
|
||||
usage += it->second;
|
||||
// Only take into account of main function's inputs
|
||||
auto graph = sched_.entry_function->graph;
|
||||
auto &entry_in_allocs = sched_.entry_function->module->allocations;
|
||||
for (auto in : graph->inputs())
|
||||
usage += entry_in_allocs.at(&in->output()).size;
|
||||
}
|
||||
else if (location == mem_output)
|
||||
{
|
||||
// Only take into account of main function's outputs
|
||||
auto graph = sched_.entry_function->graph;
|
||||
auto &entry_out_allocs = sched_.entry_function->module->allocations;
|
||||
for (auto out : graph->outputs())
|
||||
usage += entry_out_allocs.at(out->input().connection()).size;
|
||||
}
|
||||
else if (location != mem_shared_data)
|
||||
{
|
||||
for (auto &mod : sched_.modules)
|
||||
{
|
||||
auto it = mod.max_usages.find(location);
|
||||
if (it != mod.max_usages.end())
|
||||
usage += it->second;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto &mod : sched_.modules)
|
||||
{
|
||||
for (auto &shared : mod.shared_max_usages)
|
||||
usage += shared.second;
|
||||
}
|
||||
}
|
||||
|
||||
return usage;
|
||||
|
|
|
@ -65,10 +65,10 @@ section_writer &module_builder::writer(std::string_view section_name)
|
|||
return it->second.writer;
|
||||
}
|
||||
|
||||
std::vector<nncase::ir::node *> module_builder::generate_runtime_ops()
|
||||
std::vector<nncase::ir::node *> module_builder::generate_current_runtime_ops()
|
||||
{
|
||||
std::vector<nncase::ir::node *> runtime_ops;
|
||||
for (auto &&node : params_.module_sched.compute_sequence)
|
||||
for (auto &&node : current_function_->compute_sequence)
|
||||
{
|
||||
if (!non_runtime_opcodes.contains(node->runtime_opcode()))
|
||||
runtime_ops.emplace_back(node);
|
||||
|
@ -76,7 +76,8 @@ std::vector<nncase::ir::node *> module_builder::generate_runtime_ops()
|
|||
|
||||
if (dump_asm_)
|
||||
{
|
||||
std::ofstream file(dump_dir_ / "runtime_ops.txt");
|
||||
std::ofstream file(dump_dir_ / current_function_->graph->escaped_name() / "runtime_ops.txt");
|
||||
std::filesystem::create_directories(dump_dir_);
|
||||
for (auto node : runtime_ops)
|
||||
file << "[" << node->runtime_opcode().name << "] "
|
||||
<< node->name() << std::endl;
|
||||
|
@ -104,15 +105,18 @@ void module_builder::write_constants()
|
|||
if (it != params_.module_sched.max_usages.end())
|
||||
{
|
||||
auto constants = std::make_unique<std::byte[]>(it->second);
|
||||
for (auto &&node : params_.module_sched.compute_sequence)
|
||||
for (auto &func_sched : params_.module_sched.functions)
|
||||
{
|
||||
if (auto con = node_cast<constant>(*node))
|
||||
for (auto &&node : func_sched.compute_sequence)
|
||||
{
|
||||
if (con->output().memory_location() == mem_rdata)
|
||||
if (auto con = node_cast<constant>(*node))
|
||||
{
|
||||
auto &alloc = allocation(con->output());
|
||||
auto data = con->data();
|
||||
std::memcpy(constants.get() + alloc.start, data.data(), data.size_bytes());
|
||||
if (con->output().memory_location() == mem_rdata)
|
||||
{
|
||||
auto &alloc = allocation(con->output());
|
||||
auto data = con->data();
|
||||
std::memcpy(constants.get() + alloc.start, data.data(), data.size_bytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -125,11 +129,23 @@ void module_builder::compile()
|
|||
{
|
||||
write_constants();
|
||||
|
||||
auto runtime_ops = generate_runtime_ops();
|
||||
begin_emit();
|
||||
for (auto node : runtime_ops)
|
||||
emit(*node);
|
||||
end_emit();
|
||||
begin_emit_module();
|
||||
|
||||
for (auto &func_sched : params_.module_sched.functions)
|
||||
{
|
||||
current_function_ = &func_sched;
|
||||
|
||||
auto runtime_ops = generate_current_runtime_ops();
|
||||
begin_emit_function(func_sched);
|
||||
for (auto node : runtime_ops)
|
||||
emit(*node);
|
||||
end_emit_function(func_sched);
|
||||
|
||||
if (!entry_points_.contains(current_function_))
|
||||
throw std::runtime_error("Entry point for " + func_sched.graph->name() + " is not set");
|
||||
}
|
||||
|
||||
end_emit_module();
|
||||
|
||||
if (dump_asm_)
|
||||
{
|
||||
|
@ -146,15 +162,34 @@ void module_builder::merge_to_rdata_section(std::string_view from)
|
|||
rdata_section_merges_.emplace(from, std::in_place);
|
||||
}
|
||||
|
||||
size_t module_builder::module_id(ir::graph *graph)
|
||||
function_call_id module_builder::function_id(ir::graph *graph)
|
||||
{
|
||||
auto &orders = params_.sched.graph_orders;
|
||||
auto it = std::find(orders.begin(), orders.end(), graph);
|
||||
if (it != orders.end())
|
||||
return (size_t)std::distance(orders.begin(), it);
|
||||
for (size_t i = 0; i < params_.model_sched.modules.size(); i++)
|
||||
{
|
||||
auto &mod_sched = params_.model_sched.modules[i];
|
||||
auto func_it = mod_sched.functions_map.find(graph);
|
||||
if (func_it != mod_sched.functions_map.end())
|
||||
{
|
||||
auto func_sched = func_it->second;
|
||||
auto &orders = mod_sched.functions;
|
||||
if (func_sched >= orders.data() && func_sched < orders.data() + orders.size())
|
||||
return { i, (size_t)(func_sched - orders.data()) };
|
||||
}
|
||||
}
|
||||
|
||||
throw std::invalid_argument("Can't find graph " + graph->name() + " in modules");
|
||||
}
|
||||
|
||||
void module_builder::set_current_entry_point(std::streampos pos)
|
||||
{
|
||||
entry_points_[current_function_] = pos;
|
||||
}
|
||||
|
||||
void module_builder::set_current_function_text_end(std::streampos pos)
|
||||
{
|
||||
function_text_end_[current_function_] = pos;
|
||||
}
|
||||
|
||||
std::unique_ptr<section_decompiler> module_builder::create_decompiler([[maybe_unused]] std::string_view section_name)
|
||||
{
|
||||
return nullptr;
|
||||
|
@ -272,12 +307,95 @@ void module_builder::link()
|
|||
|
||||
void module_builder::write_binary(binary_writer &writer)
|
||||
{
|
||||
// Skip module header
|
||||
auto header_pos = writer.position();
|
||||
writer.skip(sizeof(module_header));
|
||||
|
||||
// mempools
|
||||
for (auto &mem : params_.module_sched.max_usages)
|
||||
{
|
||||
mempool_desc desc {};
|
||||
desc.location = mem.first;
|
||||
desc.size = mem.second;
|
||||
writer.write(desc);
|
||||
}
|
||||
|
||||
// functions
|
||||
for (auto &func_sched : params_.module_sched.functions)
|
||||
write_function_binary(writer, func_sched);
|
||||
|
||||
// sections
|
||||
for (auto §ion : section_writer_)
|
||||
{
|
||||
section_header header {};
|
||||
strncpy(header.name, section.first.c_str(), std::size(header.name) - 1);
|
||||
|
||||
auto merge_it = rdata_section_merges_.find(section.first);
|
||||
if (merge_it == rdata_section_merges_.end())
|
||||
{
|
||||
header.flags = 0;
|
||||
header.body_start = 0;
|
||||
header.body_size = (uint32_t)section.second.body.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
header.flags = SECTION_MERGED_INTO_RDATA;
|
||||
header.body_start = merge_it->second.start;
|
||||
header.body_size = merge_it->second.size;
|
||||
}
|
||||
|
||||
// Skip section header
|
||||
auto header_pos = writer.position();
|
||||
writer.skip(sizeof(header));
|
||||
|
||||
if (merge_it == rdata_section_merges_.end())
|
||||
{
|
||||
header.body_start = (uint32_t)writer.align_position(alignment_);
|
||||
// write content
|
||||
writer.write_array(std::span<uint8_t const>(section.second.body));
|
||||
}
|
||||
|
||||
// write section header
|
||||
auto end_pos = writer.position();
|
||||
writer.position(header_pos);
|
||||
writer.write(header);
|
||||
writer.position(end_pos);
|
||||
}
|
||||
|
||||
writer.align_position(8);
|
||||
auto end_pos = writer.position();
|
||||
|
||||
// header
|
||||
module_header header {};
|
||||
header.type = module_type();
|
||||
header.version = module_version();
|
||||
header.header_size = sizeof(header);
|
||||
header.size = (uint32_t)(end_pos - header_pos);
|
||||
header.mempools = (uint32_t)params_.module_sched.max_usages.size();
|
||||
header.shared_mempools = (uint32_t)params_.module_sched.shared_max_usages.size();
|
||||
header.functions = (uint32_t)params_.module_sched.functions.size();
|
||||
header.sections = (uint32_t)section_writer_.size();
|
||||
header.reserved0 = 0;
|
||||
writer.position(header_pos);
|
||||
writer.write(header);
|
||||
|
||||
writer.position(end_pos);
|
||||
}
|
||||
|
||||
void module_builder::write_function_binary(binary_writer &writer, const schedule::function_schedule_result &function_sched)
|
||||
{
|
||||
auto write_shape = [&](const shape_t &shape) {
|
||||
writer.write((uint32_t)shape.size());
|
||||
for (auto dim : shape)
|
||||
writer.write((uint32_t)dim);
|
||||
};
|
||||
|
||||
std::vector<memory_range> inputs;
|
||||
std::vector<shape_t> input_shapes;
|
||||
std::vector<memory_range> outputs;
|
||||
std::vector<shape_t> output_shapes;
|
||||
|
||||
for (auto &&node : params_.module_sched.compute_sequence)
|
||||
for (auto &&node : function_sched.compute_sequence)
|
||||
{
|
||||
if (auto in = node_cast<input_node>(*node))
|
||||
{
|
||||
|
@ -293,24 +411,9 @@ void module_builder::write_binary(binary_writer &writer)
|
|||
}
|
||||
}
|
||||
|
||||
// Skip module header
|
||||
// Skip function header
|
||||
auto header_pos = writer.position();
|
||||
writer.skip(sizeof(module_header));
|
||||
|
||||
auto write_shape = [&](const shape_t &shape) {
|
||||
writer.write((uint32_t)shape.size());
|
||||
for (auto dim : shape)
|
||||
writer.write((uint32_t)dim);
|
||||
};
|
||||
|
||||
// mempools
|
||||
for (auto &mem : params_.module_sched.max_usages)
|
||||
{
|
||||
mempool_desc desc {};
|
||||
desc.location = mem.first;
|
||||
desc.size = mem.second;
|
||||
writer.write(desc);
|
||||
}
|
||||
writer.skip(sizeof(function_header));
|
||||
|
||||
// inputs
|
||||
writer.write_array<memory_range>(inputs);
|
||||
|
@ -322,55 +425,20 @@ void module_builder::write_binary(binary_writer &writer)
|
|||
for (auto &shape : output_shapes)
|
||||
write_shape(shape);
|
||||
|
||||
// sections
|
||||
for (auto §ion : section_writer_)
|
||||
{
|
||||
section_header header {};
|
||||
strncpy(header.name, section.first.c_str(), std::size(header.name) - 1);
|
||||
|
||||
auto merge_it = rdata_section_merges_.find(section.first);
|
||||
if (merge_it == rdata_section_merges_.end())
|
||||
{
|
||||
header.flags = 0;
|
||||
header.start = 0;
|
||||
header.size = (uint32_t)section.second.body.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
header.flags = SECTION_MERGED_INTO_RDATA;
|
||||
header.start = merge_it->second.start;
|
||||
header.size = merge_it->second.size;
|
||||
}
|
||||
|
||||
// Skip section header
|
||||
auto header_pos = writer.position();
|
||||
writer.skip(sizeof(header));
|
||||
|
||||
if (merge_it == rdata_section_merges_.end())
|
||||
{
|
||||
header.start = (uint32_t)writer.align_position(alignment_);
|
||||
// write content
|
||||
writer.write_array(std::span<uint8_t const>(section.second.body));
|
||||
}
|
||||
|
||||
// write section header
|
||||
auto end_pos = writer.position();
|
||||
writer.position(header_pos);
|
||||
writer.write(header);
|
||||
writer.position(end_pos);
|
||||
}
|
||||
|
||||
writer.align_position(8);
|
||||
auto end_pos = writer.position();
|
||||
|
||||
// header
|
||||
module_header header {};
|
||||
header.type = module_type();
|
||||
header.size = (uint32_t)(end_pos - header_pos - sizeof(module_header));
|
||||
header.mempools = (uint32_t)params_.module_sched.max_usages.size();
|
||||
function_header header {};
|
||||
header.header_size = sizeof(header);
|
||||
header.size = (uint32_t)(end_pos - header_pos);
|
||||
header.input_pool_size = (uint32_t)function_sched.input_pool_size;
|
||||
header.output_pool_size = (uint32_t)function_sched.output_pool_size;
|
||||
header.inputs = (uint32_t)inputs.size();
|
||||
header.outputs = (uint32_t)outputs.size();
|
||||
header.sections = (uint32_t)section_writer_.size();
|
||||
header.reserved0 = 0;
|
||||
auto entrypoint = entry_points_.at(&function_sched);
|
||||
header.entrypoint = (uint32_t)entrypoint;
|
||||
header.text_size = (uint32_t)(function_text_end_.at(&function_sched) - entrypoint);
|
||||
writer.position(header_pos);
|
||||
writer.write(header);
|
||||
|
||||
|
@ -383,3 +451,19 @@ void module_builder::build(binary_writer &writer)
|
|||
link();
|
||||
write_binary(writer);
|
||||
}
|
||||
|
||||
void module_builder::begin_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
|
||||
{
|
||||
}
|
||||
|
||||
void module_builder::end_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
|
||||
{
|
||||
}
|
||||
|
||||
void module_builder::begin_emit_module()
|
||||
{
|
||||
}
|
||||
|
||||
void module_builder::end_emit_module()
|
||||
{
|
||||
}
|
||||
|
|
|
@ -39,11 +39,26 @@ module_type_t stackvm_module_builder::module_type() const noexcept
|
|||
return stackvm_module_type;
|
||||
}
|
||||
|
||||
uint32_t stackvm_module_builder::module_version() const noexcept
|
||||
{
|
||||
return stackvm_module_version;
|
||||
}
|
||||
|
||||
section_writer &stackvm_module_builder::text_writer()
|
||||
{
|
||||
return writer(".text");
|
||||
}
|
||||
|
||||
void stackvm_module_builder::begin_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
|
||||
{
|
||||
set_current_entry_point(text_writer().position());
|
||||
}
|
||||
|
||||
void stackvm_module_builder::end_emit_function([[maybe_unused]] const schedule::function_schedule_result &function)
|
||||
{
|
||||
set_current_function_text_end(text_writer().position());
|
||||
}
|
||||
|
||||
void stackvm_module_builder::emit(ir::node &node)
|
||||
{
|
||||
stackvm_op_builder builder(node, text_writer());
|
||||
|
|
|
@ -59,10 +59,13 @@ public:
|
|||
stackvm_module_builder(std::string_view module_name, const module_builder_params ¶ms);
|
||||
|
||||
module_type_t module_type() const noexcept override;
|
||||
uint32_t module_version() const noexcept override;
|
||||
|
||||
protected:
|
||||
section_writer &text_writer();
|
||||
|
||||
void begin_emit_function(const schedule::function_schedule_result &function) override;
|
||||
void end_emit_function(const schedule::function_schedule_result &function) override;
|
||||
void emit(ir::node &node) override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:49 +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
@ -513,9 +513,9 @@ void op_builder::tensor_binary_(datatype_t datatype, uint8_t rshape_src1, uint8_
|
|||
op_writer<tensor_binary_op_t>()(tensor_binary_op_t(datatype, rshape_src1, rstride_src1, rshape_src2, rstride_src2, rstride_dest, binary_op, fused_clamp_low, fused_clamp_high), writer_);
|
||||
}
|
||||
|
||||
void op_builder::tensor_call_(uint32_t module_id, uint8_t num_src, uint8_t num_dst)
|
||||
void op_builder::tensor_call_(uint32_t function_id, uint16_t module_id, uint8_t num_src, uint8_t num_dst)
|
||||
{
|
||||
op_writer<tensor_call_op_t>()(tensor_call_op_t(module_id, num_src, num_dst), writer_);
|
||||
op_writer<tensor_call_op_t>()(tensor_call_op_t(function_id, module_id, num_src, num_dst), writer_);
|
||||
}
|
||||
|
||||
void op_builder::tensor_conv2d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_kernel, uint8_t rstride_kernel, uint8_t rstride_bias, uint8_t rstride_dest, uint16_t groups, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high)
|
||||
|
|
|
@ -21,7 +21,7 @@ using namespace nncase::ir;
|
|||
|
||||
void stackvm_module_builder::emit(call &node, stackvm_op_builder &builder)
|
||||
{
|
||||
auto target_id = module_id(&node.target());
|
||||
auto target_id = function_id(&node.target());
|
||||
|
||||
uint8_t rshape = 0;
|
||||
for (auto in : node.inputs())
|
||||
|
@ -46,5 +46,5 @@ void stackvm_module_builder::emit(call &node, stackvm_op_builder &builder)
|
|||
builder.ldc_i4_(rshape++);
|
||||
}
|
||||
|
||||
builder.tensor_call_(target_id, (uint8_t)node.inputs().size(), (uint8_t)node.outputs().size());
|
||||
builder.tensor_call_(target_id.function_id, target_id.module_id, (uint8_t)node.inputs().size(), (uint8_t)node.outputs().size());
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
set(SRCS evaluator.cpp
|
||||
quantizer.cpp
|
||||
evaluate_context.cpp
|
||||
ops/neutral/neutral_ops.cpp)
|
||||
|
||||
add_library(evaluator OBJECT ${SRCS})
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <chrono>
|
||||
#include <nncase/ir/evaluator.h>
|
||||
#include <nncase/ir/op_utils.h>
|
||||
#include <nncase/ir/ops/constant.h>
|
||||
#include <nncase/ir/quantizer.h>
|
||||
#include <nncase/ir/runtime_type_utils.h>
|
||||
#include <nncase/targets/target.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::ir;
|
||||
using namespace nncase::schedule;
|
||||
using namespace nncase::runtime;
|
||||
namespace chrono = std::chrono;
|
||||
|
||||
namespace
|
||||
{
|
||||
std::unordered_map<node_opcode, std::function<void(ir::node &, function_evaluate_context &)>> g_evaluators;
|
||||
|
||||
auto &get_evaluator(node_opcode opcode)
|
||||
{
|
||||
auto it = g_evaluators.find(opcode);
|
||||
if (it == std::end(g_evaluators))
|
||||
throw std::runtime_error("Evaluator for " + std::string(opcode.name) + " is not found");
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
void nncase::ir::register_evaluator(ir::node_opcode opcode, std::function<void(ir::node &, function_evaluate_context &)> evaluator)
|
||||
{
|
||||
g_evaluators.emplace(opcode, std::move(evaluator));
|
||||
}
|
||||
|
||||
evaluate_tensor::evaluate_tensor(datatype_t datatype, runtime_shape_t shape, runtime_shape_t strides, gsl::span<gsl::byte> buffer)
|
||||
: datatype_(datatype), shape_(std::move(shape)), strides_(std::move(strides)), buffer_(buffer)
|
||||
{
|
||||
}
|
||||
|
||||
function_evaluate_context::function_evaluate_context(const function_schedule_result &sched, module_evaluate_context &mod_eval)
|
||||
: sched_(sched), mod_eval_(mod_eval)
|
||||
{
|
||||
input_pool_ = std::make_unique<std::byte[]>(sched.input_pool_size);
|
||||
output_pool_ = std::make_unique<std::byte[]>(sched.output_pool_size);
|
||||
|
||||
for (auto &&node : sched.compute_sequence)
|
||||
{
|
||||
auto &opcode = node->runtime_opcode();
|
||||
if (opcode == op_input_node)
|
||||
{
|
||||
inputs_.emplace_back(&node->output_at(0));
|
||||
}
|
||||
else if (opcode == op_output_node)
|
||||
{
|
||||
outputs_.emplace_back(&node->input_at(0));
|
||||
}
|
||||
else if (opcode == op_constant)
|
||||
{
|
||||
auto &rnode = static_cast<constant &>(*node);
|
||||
auto src = rnode.data();
|
||||
auto dest = memory_at(rnode.output()).buffer().as_span<std::byte>();
|
||||
std::copy(std::begin(src), std::end(src), dest.begin());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
evaluate_tensor function_evaluate_context::memory_at(const output_connector &conn)
|
||||
{
|
||||
auto &alloc = module().sched().allocations.at(&conn);
|
||||
std::byte *base;
|
||||
switch (alloc.memory_location)
|
||||
{
|
||||
case mem_input:
|
||||
base = input_pool_.get();
|
||||
break;
|
||||
case mem_output:
|
||||
base = output_pool_.get();
|
||||
break;
|
||||
default:
|
||||
base = module().memory_pool(alloc.memory_location);
|
||||
break;
|
||||
}
|
||||
|
||||
gsl::span<gsl::byte> buffer(reinterpret_cast<gsl::byte *>(base + alloc.start), alloc.size);
|
||||
return evaluate_tensor(alloc.type, to(alloc.shape), to(alloc.strides), buffer);
|
||||
}
|
||||
|
||||
void function_evaluate_context::evaluate()
|
||||
{
|
||||
using clock = chrono::high_resolution_clock;
|
||||
chrono::nanoseconds total_duration = {};
|
||||
auto quantizer = module().quantizer();
|
||||
|
||||
for (auto &&node : sched_.compute_sequence)
|
||||
{
|
||||
auto &evaluator = get_evaluator(node->runtime_opcode());
|
||||
|
||||
auto start = clock::now();
|
||||
evaluator(*node, *this);
|
||||
auto duration = clock::now() - start;
|
||||
total_duration += duration;
|
||||
|
||||
if (quantizer)
|
||||
{
|
||||
for (auto out : node->outputs())
|
||||
{
|
||||
if (out->attributes() & cnctr_attr_need_quantize)
|
||||
{
|
||||
if (!quantizer->has_record(*out))
|
||||
{
|
||||
auto mem = memory_at(*out);
|
||||
auto dtype = mem.datatype();
|
||||
if (dtype == dt_bfloat16)
|
||||
{
|
||||
auto buffer = mem.buffer().as_span<bfloat16>();
|
||||
quantizer->record(*out, buffer);
|
||||
}
|
||||
else if (dtype == dt_float32)
|
||||
{
|
||||
auto buffer = mem.buffer().as_span<float>();
|
||||
quantizer->record(*out, buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Quantizer doesn't support datatype of " + std::string(datatype_names(dtype)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module_evaluate_context::module_evaluate_context(const module_schedule_result &sched, model_evaluate_context &model_eval)
|
||||
: sched_(sched), model_eval_(model_eval), quantizer_(nullptr)
|
||||
{
|
||||
for (auto &&usage : sched.max_usages)
|
||||
memory_pools_.emplace(usage.first, std::make_unique<std::byte[]>(usage.second));
|
||||
|
||||
for (auto &func : sched.functions)
|
||||
{
|
||||
functions_.emplace(std::piecewise_construct,
|
||||
std::forward_as_tuple(func.graph),
|
||||
std::forward_as_tuple(func, *this));
|
||||
}
|
||||
}
|
||||
|
||||
std::byte *module_evaluate_context::memory_pool(memory_location_t location) const
|
||||
{
|
||||
return memory_pools_.at(location).get();
|
||||
}
|
||||
|
||||
function_evaluate_context &module_evaluate_context::function(ir::graph &function)
|
||||
{
|
||||
return functions_.at(&function);
|
||||
}
|
||||
|
||||
void module_evaluate_context::enable_ptq(target &target, ir::calibrate_method calib_method)
|
||||
{
|
||||
quantizer_ = target.create_quantizer(sched_.type, calib_method);
|
||||
}
|
||||
|
||||
void module_evaluate_context::begin_collect_distribution()
|
||||
{
|
||||
if (quantizer_)
|
||||
quantizer_->begin_collect_distribution();
|
||||
}
|
||||
|
||||
void module_evaluate_context::end_sample()
|
||||
{
|
||||
if (quantizer_)
|
||||
quantizer_->end_sample();
|
||||
}
|
||||
|
||||
void module_evaluate_context::end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress)
|
||||
{
|
||||
if (quantizer_)
|
||||
quantizer_->end_collect_distribution(progress);
|
||||
}
|
||||
|
||||
model_evaluate_context::model_evaluate_context(const schedule::model_schedule_result &sched)
|
||||
: sched_(sched)
|
||||
{
|
||||
for (auto &module : sched.modules)
|
||||
module_ctxs_.emplace(std::piecewise_construct, std::forward_as_tuple(module.type), std::forward_as_tuple(module, *this));
|
||||
}
|
||||
|
||||
function_evaluate_context &model_evaluate_context::entrypoint()
|
||||
{
|
||||
auto func = sched_.entry_function;
|
||||
return module_ctxs_.at(func->module->type).function(*func->graph);
|
||||
}
|
||||
|
||||
module_evaluate_context &model_evaluate_context::module(const module_type_t &module_type)
|
||||
{
|
||||
return module_ctxs_.at(module_type);
|
||||
}
|
||||
|
||||
void model_evaluate_context::enable_ptq(target &target, ir::calibrate_method calib_method)
|
||||
{
|
||||
for (auto &mod : module_ctxs_)
|
||||
mod.second.enable_ptq(target, calib_method);
|
||||
}
|
||||
|
||||
void model_evaluate_context::begin_collect_distribution()
|
||||
{
|
||||
for (auto &mod : module_ctxs_)
|
||||
mod.second.begin_collect_distribution();
|
||||
}
|
||||
|
||||
void model_evaluate_context::end_sample()
|
||||
{
|
||||
for (auto &mod : module_ctxs_)
|
||||
mod.second.end_sample();
|
||||
}
|
||||
|
||||
void model_evaluate_context::end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress)
|
||||
{
|
||||
for (auto &mod : module_ctxs_)
|
||||
mod.second.end_collect_distribution(progress);
|
||||
}
|
||||
|
||||
void model_evaluate_context::evaluate()
|
||||
{
|
||||
entrypoint().evaluate();
|
||||
}
|
|
@ -12,196 +12,66 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "nncase/ir/quantizer.h"
|
||||
#include <chrono>
|
||||
#include <nncase/ir/evaluator.h>
|
||||
#include <nncase/ir/op_utils.h>
|
||||
#include <nncase/ir/ops/constant.h>
|
||||
#include <nncase/ir/runtime_type_utils.h>
|
||||
#include <nncase/targets/target.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::ir;
|
||||
using namespace nncase::schedule;
|
||||
using namespace nncase::runtime;
|
||||
namespace chrono = std::chrono;
|
||||
|
||||
#define PROFILE 0
|
||||
|
||||
namespace
|
||||
evaluator::evaluator(const schedule::model_schedule_result &sched)
|
||||
: model_eval_(sched)
|
||||
{
|
||||
std::unordered_map<node_opcode, std::function<void(ir::node &, module_evaluate_context &)>> g_evaluators;
|
||||
|
||||
auto &get_evaluator(node_opcode opcode)
|
||||
{
|
||||
auto it = g_evaluators.find(opcode);
|
||||
if (it == std::end(g_evaluators))
|
||||
throw std::runtime_error("Evaluator for " + std::string(opcode.name) + " is not found");
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
void nncase::ir::register_evaluator(ir::node_opcode opcode, std::function<void(ir::node &, module_evaluate_context &)> evaluator)
|
||||
{
|
||||
g_evaluators.emplace(opcode, std::move(evaluator));
|
||||
}
|
||||
|
||||
evaluate_tensor::evaluate_tensor(datatype_t datatype, runtime_shape_t shape, runtime_shape_t strides, gsl::span<gsl::byte> buffer)
|
||||
: datatype_(datatype), shape_(std::move(shape)), strides_(std::move(strides)), buffer_(buffer)
|
||||
{
|
||||
}
|
||||
|
||||
module_evaluate_context::module_evaluate_context(const module_schedule_result &sched)
|
||||
: sched_(sched), quantizer_(nullptr)
|
||||
{
|
||||
for (auto &&usage : sched.max_usages)
|
||||
memory_pools_.emplace(usage.first, std::make_unique<std::byte[]>(usage.second));
|
||||
|
||||
for (auto &&node : sched.compute_sequence)
|
||||
{
|
||||
auto &opcode = node->runtime_opcode();
|
||||
if (opcode == op_input_node)
|
||||
{
|
||||
inputs_.emplace_back(&node->output_at(0));
|
||||
}
|
||||
else if (opcode == op_output_node)
|
||||
{
|
||||
outputs_.emplace_back(&node->input_at(0));
|
||||
}
|
||||
else if (opcode == op_constant)
|
||||
{
|
||||
auto &rnode = static_cast<constant &>(*node);
|
||||
auto src = rnode.data();
|
||||
auto dest = memory_at(rnode.output()).buffer().as_span<std::byte>();
|
||||
std::copy(std::begin(src), std::end(src), dest.begin());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
evaluate_tensor module_evaluate_context::memory_at(const output_connector &conn)
|
||||
{
|
||||
auto &alloc = sched_.allocations.at(&conn);
|
||||
auto &memory_pool = memory_pools_.at(alloc.memory_location);
|
||||
gsl::span<gsl::byte> buffer(reinterpret_cast<gsl::byte *>(memory_pool.get() + alloc.start), alloc.size);
|
||||
return evaluate_tensor(alloc.type, to(alloc.shape), to(alloc.strides), buffer);
|
||||
}
|
||||
|
||||
void module_evaluate_context::enable_ptq(target &target, ir::calibrate_method calib_method)
|
||||
{
|
||||
quantizer_ = target.create_quantizer(sched_.graph->module_type(), calib_method);
|
||||
}
|
||||
|
||||
void module_evaluate_context::evaluate()
|
||||
{
|
||||
using clock = chrono::high_resolution_clock;
|
||||
chrono::nanoseconds total_duration = {};
|
||||
|
||||
for (auto &&node : sched_.compute_sequence)
|
||||
{
|
||||
auto &evaluator = get_evaluator(node->runtime_opcode());
|
||||
|
||||
auto start = clock::now();
|
||||
evaluator(*node, *this);
|
||||
auto duration = clock::now() - start;
|
||||
total_duration += duration;
|
||||
#if PROFILE
|
||||
std::cout << node->name() << "/" << node->runtime_opcode().name << ": " << std::endl;
|
||||
#endif
|
||||
|
||||
if (quantizer_)
|
||||
{
|
||||
for (auto out : node->outputs())
|
||||
{
|
||||
if (out->attributes() & cnctr_attr_need_quantize)
|
||||
{
|
||||
if (!quantizer_->has_record(*out))
|
||||
{
|
||||
auto mem = memory_at(*out);
|
||||
if (mem.datatype() == dt_bfloat16)
|
||||
{
|
||||
auto buffer = mem.buffer().as_span<bfloat16>();
|
||||
quantizer_->record(*out, buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto buffer = mem.buffer().as_span<float>();
|
||||
quantizer_->record(*out, buffer);
|
||||
}
|
||||
}
|
||||
#if PROFILE
|
||||
std::cout << "\t" << out->name() << " range: { " << quantizer_->get(*out).min << " ~ " << quantizer_->get(*out).max << " } , ";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
// quantizer_->range()
|
||||
}
|
||||
|
||||
#if PROFILE
|
||||
std::cout << "\t duration: " << duration.count() / 1e6 << "ms" << std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
if (quantizer_)
|
||||
quantizer_->reset_record();
|
||||
|
||||
#if PROFILE
|
||||
std::cout << "Total: " << total_duration.count() / 1e6 << "ms" << std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
void module_evaluate_context::begin_collect_distribution()
|
||||
{
|
||||
if (quantizer_)
|
||||
quantizer_->begin_collect_distribution();
|
||||
}
|
||||
|
||||
void module_evaluate_context::end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress)
|
||||
{
|
||||
if (quantizer_)
|
||||
quantizer_->end_collect_distribution(progress);
|
||||
}
|
||||
|
||||
evaluator::evaluator(const schedule::schedule_result &sched)
|
||||
: sched_(sched)
|
||||
{
|
||||
for (auto &module_p : sched.modules)
|
||||
module_ctxs_.emplace(module_p.first, module_p.second);
|
||||
}
|
||||
|
||||
module_evaluate_context &evaluator::module_context(ir::graph &graph)
|
||||
{
|
||||
return module_ctxs_.at(&graph);
|
||||
}
|
||||
|
||||
module_evaluate_context &evaluator::main_module_context()
|
||||
{
|
||||
return module_context(*sched_.main_module);
|
||||
}
|
||||
|
||||
evaluate_tensor evaluator::memory_at(const output_connector &conn)
|
||||
{
|
||||
return main_module_context().memory_at(conn);
|
||||
}
|
||||
|
||||
void evaluator::enable_ptq(target &target, ir::calibrate_method calib_method)
|
||||
{
|
||||
for (auto &module_p : module_ctxs_)
|
||||
module_p.second.enable_ptq(target, calib_method);
|
||||
return model_eval_.enable_ptq(target, calib_method);
|
||||
}
|
||||
|
||||
void evaluator::evaluate()
|
||||
{
|
||||
module_ctxs_.at(sched_.main_module).evaluate();
|
||||
return model_eval_.evaluate();
|
||||
}
|
||||
|
||||
quantizer *evaluator::quantizer(const module_type_t &module_type)
|
||||
{
|
||||
return model_eval_.module(module_type).quantizer();
|
||||
}
|
||||
|
||||
void evaluator::begin_collect_distribution()
|
||||
{
|
||||
for (auto &module_p : module_ctxs_)
|
||||
module_p.second.begin_collect_distribution();
|
||||
model_eval_.end_sample();
|
||||
}
|
||||
|
||||
void evaluator::end_collect_distribution(std::function<void(size_t cnt, size_t total)> progress)
|
||||
void evaluator::end_sample()
|
||||
{
|
||||
for (auto &module_p : module_ctxs_)
|
||||
module_p.second.end_collect_distribution(progress);
|
||||
model_eval_.end_sample();
|
||||
}
|
||||
|
||||
void evaluator::end_collect_distribution(const std::function<void(size_t cnt, size_t total)> &progress)
|
||||
{
|
||||
model_eval_.end_collect_distribution(progress);
|
||||
}
|
||||
|
||||
evaluate_tensor evaluator::memory_at(const output_connector &conn)
|
||||
{
|
||||
return model_eval_.memory_at(conn);
|
||||
}
|
||||
|
||||
evaluate_tensor evaluator::memory_at(const input_connector &conn)
|
||||
{
|
||||
return model_eval_.memory_at(conn);
|
||||
}
|
||||
|
||||
evaluate_tensor evaluator::input_at(size_t index)
|
||||
{
|
||||
return model_eval_.input_at(index);
|
||||
}
|
||||
|
||||
evaluate_tensor evaluator::output_at(size_t index)
|
||||
{
|
||||
return model_eval_.output_at(index);
|
||||
}
|
||||
|
|
|
@ -70,7 +70,7 @@ using namespace nncase::runtime;
|
|||
|
||||
namespace
|
||||
{
|
||||
void nop_evaluator(ir::node &, module_evaluate_context &)
|
||||
void nop_evaluator(ir::node &, function_evaluate_context &)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ void register_neutral_evaluators()
|
|||
register_evaluator(op_ignore_node, nop_evaluator);
|
||||
register_evaluator(op_constant, nop_evaluator);
|
||||
|
||||
register_evaluator(op_batch_to_space, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_batch_to_space, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<batch_to_space &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -97,7 +97,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_binary, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_binary, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<binary &>(node);
|
||||
|
||||
assert(rnode.input_a().type() == dt_float32);
|
||||
|
@ -112,7 +112,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_broadcast, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_broadcast, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<broadcast &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -122,7 +122,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_concat, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_concat, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<concat &>(node);
|
||||
|
||||
std::vector<const gsl::byte *> inputs_mem;
|
||||
|
@ -141,7 +141,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_conv2d, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_conv2d, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<conv2d &>(node);
|
||||
|
||||
assert(rnode.input().type() == dt_float32);
|
||||
|
@ -161,7 +161,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_conv2d_transpose, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_conv2d_transpose, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<conv2d_transpose &>(node);
|
||||
|
||||
assert(rnode.input().type() == dt_float32);
|
||||
|
@ -175,7 +175,7 @@ void register_neutral_evaluators()
|
|||
rnode.dilation_h(), rnode.dilation_w(), rnode.padding_h(), rnode.padding_w(), rnode.fused_activation());
|
||||
});
|
||||
|
||||
register_evaluator(op_dequantize, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_dequantize, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<dequantize &>(node);
|
||||
|
||||
auto output = context.memory_at(rnode.output()).buffer().as_span<float>();
|
||||
|
@ -199,7 +199,7 @@ void register_neutral_evaluators()
|
|||
}
|
||||
});
|
||||
|
||||
register_evaluator(op_fused_unary, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_fused_unary, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<fused_unary &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input()).buffer().as_span<float>();
|
||||
|
@ -217,7 +217,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_matmul, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_matmul, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<matmul &>(node);
|
||||
|
||||
assert(rnode.input_a().type() == dt_float32);
|
||||
|
@ -233,7 +233,7 @@ void register_neutral_evaluators()
|
|||
neutral::matmul(input_a.data(), input_b.data(), output.data(), bias.data(), (int32_t)a_shape[0], (int32_t)a_shape[1], (int32_t)b_shape[1], rnode.fused_activation());
|
||||
});
|
||||
|
||||
register_evaluator(op_pad, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_pad, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<pad &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -246,7 +246,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_quantize, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_quantize, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<quantize &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input()).buffer().as_span<float>();
|
||||
|
@ -255,7 +255,7 @@ void register_neutral_evaluators()
|
|||
neutral::quantize(input.data(), output.data(), xt::compute_size(rnode.input().shape()), rnode.quant_param());
|
||||
});
|
||||
|
||||
register_evaluator(op_reduce, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_reduce, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<reduce &>(node);
|
||||
|
||||
assert(rnode.input().type() == dt_float32);
|
||||
|
@ -269,7 +269,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_reduce_window2d, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_reduce_window2d, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<reduce_window2d &>(node);
|
||||
|
||||
assert(rnode.input().type() == dt_float32);
|
||||
|
@ -284,7 +284,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_bitcast, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_bitcast, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<bitcast &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input()).buffer();
|
||||
|
@ -293,7 +293,7 @@ void register_neutral_evaluators()
|
|||
std::copy(input.begin(), input.end(), output.begin());
|
||||
});
|
||||
|
||||
register_evaluator(op_resize_image, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_resize_image, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<resize_image &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -317,7 +317,7 @@ void register_neutral_evaluators()
|
|||
}
|
||||
});
|
||||
|
||||
register_evaluator(op_slice, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_slice, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<slice &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -330,7 +330,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_transpose, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_transpose, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<transpose &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -343,7 +343,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_unary, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_unary, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<unary &>(node);
|
||||
|
||||
assert(rnode.input().type() == dt_float32);
|
||||
|
@ -416,7 +416,7 @@ void register_neutral_evaluators()
|
|||
}
|
||||
});
|
||||
|
||||
register_evaluator(op_table_lookup1d, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_table_lookup1d, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<table_lookup1d &>(node);
|
||||
|
||||
assert(rnode.input().type() == dt_uint8);
|
||||
|
@ -427,7 +427,7 @@ void register_neutral_evaluators()
|
|||
kernels::neutral::table_lookup1d(input.data(), output.data(), input.size(), table.data());
|
||||
});
|
||||
|
||||
register_evaluator(op_clamp, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_clamp, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<clamp &>(node);
|
||||
|
||||
assert(rnode.input().type() == dt_float32);
|
||||
|
@ -446,7 +446,7 @@ void register_neutral_evaluators()
|
|||
}
|
||||
});
|
||||
|
||||
register_evaluator(op_convert, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_convert, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<convert &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -459,7 +459,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_gather, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_gather, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<gather &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -473,7 +473,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_gather_nd, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_gather_nd, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<gather_nd &>(node);
|
||||
|
||||
auto input = context.memory_at(rnode.input());
|
||||
|
@ -487,7 +487,7 @@ void register_neutral_evaluators()
|
|||
.unwrap_or_throw();
|
||||
});
|
||||
|
||||
register_evaluator(op_onehot, [](ir::node &node, module_evaluate_context &context) {
|
||||
register_evaluator(op_onehot, [](ir::node &node, function_evaluate_context &context) {
|
||||
auto &rnode = static_cast<onehot &>(node);
|
||||
|
||||
auto indices = context.memory_at(rnode.indices());
|
||||
|
|
|
@ -356,7 +356,7 @@ private:
|
|||
{
|
||||
auto graph_runner = [&](ir::graph &graph) {
|
||||
ir::transforms::pass_manager pmgr(graph, *target_);
|
||||
auto quant = evaluator.module_context(graph).quantizer();
|
||||
auto quant = evaluator.quantizer(graph.module_type());
|
||||
|
||||
if (!compile_options_.use_dataset_as_input_stat)
|
||||
{
|
||||
|
@ -486,6 +486,7 @@ private:
|
|||
std::memcpy(input_buffer.data(), tensor.data(), input_buffer.size_bytes());
|
||||
|
||||
evaluator.evaluate();
|
||||
evaluator.end_sample();
|
||||
if (options.progress)
|
||||
options.progress(i++, dataset.total_size());
|
||||
}
|
||||
|
@ -519,6 +520,7 @@ private:
|
|||
std::memcpy(input_buffer.data(), options.tensor_data.data() + i * input_buffer.size_bytes(), input_buffer.size_bytes());
|
||||
|
||||
evaluator.evaluate();
|
||||
evaluator.end_sample();
|
||||
if (options.progress)
|
||||
options.progress(i++, options.samples_count);
|
||||
}
|
||||
|
|
|
@ -3,9 +3,11 @@
|
|||
set(SRCS interpreter.cpp
|
||||
error.cpp
|
||||
runtime_loader.cpp
|
||||
runtime_function.cpp
|
||||
runtime_module.cpp
|
||||
runtime_tensor.cpp
|
||||
runtime_tensor_impl.cpp
|
||||
section.cpp
|
||||
host_runtime_tensor.cpp
|
||||
allocator.cpp)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
*/
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <nncase/runtime/dbg.h>
|
||||
#include <nncase/runtime/error.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_loader.h>
|
||||
|
@ -23,7 +24,7 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
|
||||
interpreter::interpreter() noexcept
|
||||
: main_module_(nullptr)
|
||||
: entry_function_(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -49,82 +50,79 @@ result<void> interpreter::load_model(gsl::span<const gsl::byte> buffer) noexcept
|
|||
|
||||
for (size_t i = 0; i < header->modules; i++)
|
||||
{
|
||||
auto mod_header = reader.get_ref<module_header>();
|
||||
reader.skip(mod_header->size);
|
||||
try_var(rt_module, runtime_module::create(mod_header->type));
|
||||
auto mod_type = reader.peek_with_offset<decltype(module_header::type)>(offsetof(module_header, type));
|
||||
auto mod_size = reader.peek_with_offset<decltype(module_header::size)>(offsetof(module_header, size));
|
||||
auto payload = reader.read_span(mod_size);
|
||||
try_var(rt_module, runtime_module::create(mod_type));
|
||||
|
||||
try_(rt_module->initialize(*mod_header, *this));
|
||||
if (i == header->main_module)
|
||||
main_module_ = rt_module.get();
|
||||
try_(rt_module->initialize(payload, *this));
|
||||
if (i == header->entry_module)
|
||||
try_set(entry_function_, rt_module->find_function_by_id(header->entry_function));
|
||||
modules_[i] = std::move(rt_module);
|
||||
}
|
||||
|
||||
for (auto &mod : modules_)
|
||||
try_(mod->initialize_inter_modules(*this));
|
||||
|
||||
return ok();
|
||||
}
|
||||
|
||||
size_t interpreter::inputs_size() const noexcept
|
||||
{
|
||||
return main_module_->inputs_size();
|
||||
return entry_function_->inputs_size();
|
||||
}
|
||||
|
||||
size_t interpreter::outputs_size() const noexcept
|
||||
{
|
||||
return main_module_->outputs_size();
|
||||
return entry_function_->outputs_size();
|
||||
}
|
||||
|
||||
const memory_range &interpreter::input_desc(size_t index) const noexcept
|
||||
{
|
||||
return main_module_->input_desc(index);
|
||||
return entry_function_->input_desc(index);
|
||||
}
|
||||
|
||||
const memory_range &interpreter::output_desc(size_t index) const noexcept
|
||||
{
|
||||
return main_module_->output_desc(index);
|
||||
return entry_function_->output_desc(index);
|
||||
}
|
||||
|
||||
const runtime_shape_t &interpreter::input_shape(size_t index) const noexcept
|
||||
{
|
||||
return main_module_->input_shape(index);
|
||||
return entry_function_->input_shape(index);
|
||||
}
|
||||
|
||||
const runtime_shape_t &interpreter::output_shape(size_t index) const noexcept
|
||||
{
|
||||
return main_module_->output_shape(index);
|
||||
return entry_function_->output_shape(index);
|
||||
}
|
||||
|
||||
result<runtime_tensor> interpreter::input_tensor(size_t index) noexcept
|
||||
{
|
||||
return main_module_->input_tensor(index);
|
||||
return entry_function_->input_tensor(index);
|
||||
}
|
||||
|
||||
result<void> interpreter::input_tensor(size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
return main_module_->input_tensor(index, tensor);
|
||||
return entry_function_->input_tensor(index, tensor);
|
||||
}
|
||||
|
||||
result<runtime_tensor> interpreter::output_tensor(size_t index) noexcept
|
||||
{
|
||||
return main_module_->output_tensor(index);
|
||||
return entry_function_->output_tensor(index);
|
||||
}
|
||||
|
||||
result<void> interpreter::output_tensor(size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
return main_module_->output_tensor(index, tensor);
|
||||
return entry_function_->output_tensor(index, tensor);
|
||||
}
|
||||
|
||||
result<void> interpreter::run() noexcept
|
||||
{
|
||||
return main_module_->run();
|
||||
return entry_function_->invoke();
|
||||
}
|
||||
|
||||
result<runtime_module *> interpreter::find_module_by_id(size_t index) noexcept
|
||||
{
|
||||
if (index < modules_.size())
|
||||
return ok(modules_[index].get());
|
||||
return err(std::errc::result_out_of_range);
|
||||
CHECK_WITH_ERR(index < modules_.size(), std::errc::result_out_of_range);
|
||||
return ok(modules_[index].get());
|
||||
}
|
||||
|
||||
options_dict &interpreter::options() noexcept
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <runtime/cpu/cpu_ops_body.h>
|
||||
#include <runtime/k210/k210_ops_body.h>
|
||||
#include <runtime/kernel_registry.h>
|
||||
#include <runtime/neutral/neutral_ops_body.h>
|
||||
#include <runtime/span_reader.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
|
||||
namespace nncase
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
#define BEGINE_DEFINE_TARGET(target) \
|
||||
namespace target \
|
||||
{
|
||||
|
||||
#define DEFINE_NEUTRAL_RUNTIME_OP(id, name, value) \
|
||||
kernel_call_result id(id##_options &, interpreter_t &, interpreter_step_t);
|
||||
#define DEFINE_RUNTIME_OP(target, id, name, value) \
|
||||
kernel_call_result id(id##_options &, interpreter_t &, interpreter_step_t);
|
||||
|
||||
#define END_DEFINE_TARGET() }
|
||||
|
||||
#include <runtime/runtime_op.def>
|
||||
|
||||
#undef BEGINE_DEFINE_TARGET
|
||||
#undef DEFINE_NEUTRAL_RUNTIME_OP
|
||||
#undef DEFINE_RUNTIME_OP
|
||||
#undef END_DEFINE_TARGET
|
||||
}
|
||||
}
|
||||
|
||||
kernel_call_result runtime::call_kernel(runtime_opcode opcode, xtl::span<const uint8_t> body, interpreter_t &interpreter, interpreter_step_t step)
|
||||
{
|
||||
span_reader reader(body);
|
||||
|
||||
switch (opcode)
|
||||
{
|
||||
#define BEGINE_DEFINE_TARGET(...)
|
||||
#define DEFINE_NEUTRAL_RUNTIME_OP(id, name, value) \
|
||||
case rop_##id: \
|
||||
{ \
|
||||
nncase::runtime::neutral::id##_options options; \
|
||||
options.deserialize(reader); \
|
||||
return nncase::runtime::neutral::id(options, interpreter, step); \
|
||||
}
|
||||
#define DEFINE_RUNTIME_OP(target, id, name, value) \
|
||||
case rop_##target##_##id: \
|
||||
{ \
|
||||
nncase::runtime::target::id##_options options; \
|
||||
options.deserialize(reader); \
|
||||
return nncase::runtime::target::id(options, interpreter, step); \
|
||||
}
|
||||
#define END_DEFINE_TARGET()
|
||||
|
||||
#include <runtime/runtime_op.def>
|
||||
|
||||
#undef BEGINE_DEFINE_TARGET
|
||||
#undef DEFINE_NEUTRAL_RUNTIME_OP
|
||||
#undef DEFINE_RUNTIME_OP
|
||||
#undef END_DEFINE_TARGET
|
||||
default:
|
||||
return kcr_error;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,310 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "section.h"
|
||||
#include <nncase/runtime/dbg.h>
|
||||
#include <nncase/runtime/error.h>
|
||||
#include <nncase/runtime/runtime_function.h>
|
||||
#include <nncase/runtime/span_reader.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
|
||||
namespace
|
||||
{
|
||||
class runtime_function_init_context_impl : public runtime_function_init_context
|
||||
{
|
||||
public:
|
||||
runtime_function_init_context_impl(const function_header &header, runtime_module_init_context &module_init_context, gsl::span<const gsl::byte> body) noexcept
|
||||
: header_(header), module_init_context_(module_init_context), body_(body)
|
||||
{
|
||||
}
|
||||
|
||||
runtime_module_init_context &module_init_context() noexcept override
|
||||
{
|
||||
return module_init_context_;
|
||||
}
|
||||
|
||||
const function_header &header() noexcept override
|
||||
{
|
||||
return header_;
|
||||
}
|
||||
|
||||
gsl::span<const gsl::byte> body() noexcept override
|
||||
{
|
||||
return body_;
|
||||
}
|
||||
|
||||
private:
|
||||
const function_header &header_;
|
||||
runtime_module_init_context &module_init_context_;
|
||||
gsl::span<const gsl::byte> body_;
|
||||
};
|
||||
}
|
||||
|
||||
runtime_function::runtime_function(runtime_module &rt_module)
|
||||
: rt_module_(rt_module)
|
||||
{
|
||||
}
|
||||
|
||||
runtime_module &runtime_function::module() const noexcept
|
||||
{
|
||||
return rt_module_;
|
||||
}
|
||||
|
||||
uint32_t runtime_function::inputs_size() const noexcept
|
||||
{
|
||||
return header_.inputs;
|
||||
}
|
||||
|
||||
uint32_t runtime_function::outputs_size() const noexcept
|
||||
{
|
||||
return header_.outputs;
|
||||
}
|
||||
|
||||
const memory_range &runtime_function::input_desc(size_t index) const noexcept
|
||||
{
|
||||
assert(index < input_tensors_.size());
|
||||
return input_tensors_[index].range;
|
||||
}
|
||||
|
||||
const memory_range &runtime_function::output_desc(size_t index) const noexcept
|
||||
{
|
||||
assert(index < output_tensors_.size());
|
||||
return output_tensors_[index].range;
|
||||
}
|
||||
|
||||
const runtime_shape_t &runtime_function::input_shape(size_t index) const noexcept
|
||||
{
|
||||
assert(index < input_tensors_.size());
|
||||
return input_tensors_[index].shape;
|
||||
}
|
||||
|
||||
const runtime_shape_t &runtime_function::output_shape(size_t index) const noexcept
|
||||
{
|
||||
assert(index < output_tensors_.size());
|
||||
return output_tensors_[index].shape;
|
||||
}
|
||||
|
||||
result<void> runtime_function::initialize(gsl::span<const gsl::byte> payload, runtime_module_init_context &module_init_context) noexcept
|
||||
{
|
||||
span_reader reader(payload);
|
||||
reader.read(header_);
|
||||
|
||||
try
|
||||
{
|
||||
input_tensors_.resize(inputs_size());
|
||||
output_tensors_.resize(outputs_size());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return err(std::errc::not_enough_memory);
|
||||
}
|
||||
|
||||
auto read_shape = [&](runtime_shape_t &shape) {
|
||||
shape.resize(reader.read<uint32_t>());
|
||||
for (auto &dim : shape)
|
||||
dim = reader.read<uint32_t>();
|
||||
};
|
||||
|
||||
// inputs
|
||||
for (auto &in : input_tensors_)
|
||||
reader.read(in.range);
|
||||
for (auto &in : input_tensors_)
|
||||
read_shape(in.shape);
|
||||
|
||||
// outputs
|
||||
for (auto &out : output_tensors_)
|
||||
reader.read(out.range);
|
||||
for (auto &out : output_tensors_)
|
||||
read_shape(out.shape);
|
||||
|
||||
runtime_function_init_context_impl init_context(header_, module_init_context, reader.read_avail());
|
||||
return initialize_core(init_context);
|
||||
}
|
||||
|
||||
#define INOUT_TENSOR_GETTER_IMPL(name) \
|
||||
CHECK_WITH_ERR(index < name##_tensors_.size(), std::errc::result_out_of_range); \
|
||||
\
|
||||
auto &info = name##_tensors_[index]; \
|
||||
if (info.bind_tensor.empty()) \
|
||||
{ \
|
||||
try_set(info.bind_tensor, allocate_##name##_tensor(index)); \
|
||||
} \
|
||||
return ok(info.bind_tensor);
|
||||
|
||||
result<runtime_tensor> runtime_function::input_tensor(size_t index) noexcept
|
||||
{
|
||||
INOUT_TENSOR_GETTER_IMPL(input);
|
||||
}
|
||||
|
||||
result<runtime_tensor> runtime_function::output_tensor(size_t index) noexcept
|
||||
{
|
||||
INOUT_TENSOR_GETTER_IMPL(output);
|
||||
}
|
||||
|
||||
#define DEV_INOUT_TENSOR_GETTER_IMPL(name) \
|
||||
CHECK_WITH_ERR(index < name##_tensors_.size(), std::errc::result_out_of_range); \
|
||||
\
|
||||
auto &info = name##_tensors_[index]; \
|
||||
if (info.bind_tensor.empty()) \
|
||||
{ \
|
||||
try_set(info.bind_tensor, allocate_##name##_tensor(index)); \
|
||||
} \
|
||||
if (!info.device_tensor.empty()) \
|
||||
{ \
|
||||
return ok(info.device_tensor); \
|
||||
} \
|
||||
return ok(info.bind_tensor);
|
||||
|
||||
result<runtime_tensor> runtime_function::device_input_tensor(size_t index) noexcept
|
||||
{
|
||||
DEV_INOUT_TENSOR_GETTER_IMPL(input);
|
||||
}
|
||||
|
||||
result<runtime_tensor> runtime_function::device_output_tensor(size_t index) noexcept
|
||||
{
|
||||
DEV_INOUT_TENSOR_GETTER_IMPL(output);
|
||||
}
|
||||
|
||||
result<void> runtime_function::input_tensor(size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
CHECK_WITH_ERR(!tensor.empty(), std::errc::invalid_argument);
|
||||
CHECK_WITH_ERR(index < input_tensors_.size(), std::errc::result_out_of_range);
|
||||
|
||||
auto &info = input_tensors_[index];
|
||||
CHECK_WITH_ERR(info.range.datatype == tensor.datatype(), nncase_errc::datatype_mismatch);
|
||||
CHECK_WITH_ERR(info.shape == tensor.shape(), nncase_errc::shape_mismatch);
|
||||
|
||||
if (info.bind_tensor != tensor)
|
||||
{
|
||||
if (validate_input_tensor(index, tensor).is_err())
|
||||
{
|
||||
auto device_tensor = info.device_tensor;
|
||||
if (device_tensor.empty())
|
||||
try_var(device_tensor, allocate_input_tensor(index));
|
||||
if (!tensor.can_copy_to_without_staging(device_tensor))
|
||||
{
|
||||
try_set(info.staging_tensor, host_runtime_tensor::create(info.range.datatype, info.shape));
|
||||
}
|
||||
else
|
||||
{
|
||||
info.staging_tensor.reset();
|
||||
}
|
||||
|
||||
info.device_tensor = device_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
info.device_tensor.reset();
|
||||
info.staging_tensor.reset();
|
||||
}
|
||||
|
||||
info.bind_tensor = tensor;
|
||||
}
|
||||
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> runtime_function::output_tensor(size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
CHECK_WITH_ERR(!tensor.empty(), std::errc::invalid_argument);
|
||||
CHECK_WITH_ERR(index < output_tensors_.size(), std::errc::result_out_of_range);
|
||||
|
||||
auto &info = output_tensors_[index];
|
||||
CHECK_WITH_ERR(info.range.datatype == tensor.datatype(), nncase_errc::datatype_mismatch);
|
||||
CHECK_WITH_ERR(info.shape == tensor.shape(), nncase_errc::shape_mismatch);
|
||||
|
||||
if (info.bind_tensor != tensor)
|
||||
{
|
||||
if (validate_output_tensor(index, tensor).is_err())
|
||||
{
|
||||
auto device_tensor = info.device_tensor;
|
||||
if (device_tensor.empty())
|
||||
try_var(device_tensor, allocate_output_tensor(index));
|
||||
if (!device_tensor.can_copy_to_without_staging(tensor))
|
||||
{
|
||||
try_set(info.staging_tensor, host_runtime_tensor::create(info.range.datatype, info.shape));
|
||||
}
|
||||
else
|
||||
{
|
||||
info.staging_tensor.reset();
|
||||
}
|
||||
|
||||
info.device_tensor = device_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
info.device_tensor.reset();
|
||||
info.staging_tensor.reset();
|
||||
}
|
||||
|
||||
info.bind_tensor = tensor;
|
||||
}
|
||||
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> runtime_function::invoke() noexcept
|
||||
{
|
||||
// 1. Ensure bindings
|
||||
for (size_t i = 0; i < input_tensors_.size(); i++)
|
||||
{
|
||||
auto &info = input_tensors_[i];
|
||||
if (info.bind_tensor.empty())
|
||||
try_set(info.bind_tensor, allocate_input_tensor(i));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < output_tensors_.size(); i++)
|
||||
{
|
||||
auto &info = output_tensors_[i];
|
||||
if (info.bind_tensor.empty())
|
||||
try_set(info.bind_tensor, allocate_output_tensor(i));
|
||||
}
|
||||
|
||||
// 2. Copy inputs
|
||||
for (auto &in : input_tensors_)
|
||||
{
|
||||
if (in.staging_tensor.empty())
|
||||
{
|
||||
if (!in.device_tensor.empty())
|
||||
try_(in.bind_tensor.copy_to(in.device_tensor));
|
||||
}
|
||||
else
|
||||
{
|
||||
try_(in.bind_tensor.copy_to(in.staging_tensor));
|
||||
try_(in.staging_tensor.copy_to(in.device_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Run
|
||||
try_(invoke_core());
|
||||
|
||||
// 4. Copy outputs
|
||||
for (auto &out : output_tensors_)
|
||||
{
|
||||
if (out.staging_tensor.empty())
|
||||
{
|
||||
if (!out.device_tensor.empty())
|
||||
try_(out.device_tensor.copy_to(out.bind_tensor));
|
||||
}
|
||||
else
|
||||
{
|
||||
try_(out.device_tensor.copy_to(out.staging_tensor));
|
||||
try_(out.staging_tensor.copy_to(out.bind_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
return ok();
|
||||
}
|
|
@ -12,6 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "section.h"
|
||||
#include <nncase/runtime/dbg.h>
|
||||
#include <nncase/runtime/error.h>
|
||||
#include <nncase/runtime/runtime_module.h>
|
||||
|
@ -47,33 +48,7 @@ public:
|
|||
|
||||
gsl::span<const gsl::byte> section(const char *name) noexcept override
|
||||
{
|
||||
span_reader reader(sections_);
|
||||
while (!reader.empty())
|
||||
{
|
||||
auto header = reader.get_ref<section_header>();
|
||||
if (!strncmp(header->name, name, MAX_SECTION_NAME_LENGTH))
|
||||
{
|
||||
gsl::span<const gsl::byte> result;
|
||||
if (header->flags & SECTION_MERGED_INTO_RDATA)
|
||||
{
|
||||
auto rdata_span = section(".rdata");
|
||||
result = rdata_span.subspan(header->start, header->size);
|
||||
}
|
||||
else
|
||||
{
|
||||
result = reader.read_avail().subspan(header->start, header->size);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!(header->flags & SECTION_MERGED_INTO_RDATA))
|
||||
reader.skip((size_t)header->start + header->size);
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
return find_section(name, sections_);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -81,6 +56,21 @@ private:
|
|||
interpreter &interp_;
|
||||
gsl::span<const gsl::byte> sections_;
|
||||
};
|
||||
|
||||
gsl::span<const gsl::byte> read_functions(span_reader &sr, size_t functions) noexcept
|
||||
{
|
||||
auto nest_sr = sr;
|
||||
size_t size = 0;
|
||||
|
||||
for (size_t i = 0; i < functions; i++)
|
||||
{
|
||||
auto func_size = nest_sr.peek_with_offset<decltype(function_header::size)>(offsetof(function_header, size));
|
||||
nest_sr.skip(func_size);
|
||||
size += func_size;
|
||||
}
|
||||
|
||||
return sr.read_span(size);
|
||||
}
|
||||
}
|
||||
|
||||
const module_type_t &runtime_module::type() const noexcept
|
||||
|
@ -112,50 +102,17 @@ mempool_desc runtime_module::mempool(memory_location_t location) const noexcept
|
|||
return desc;
|
||||
}
|
||||
|
||||
uint32_t runtime_module::inputs_size() const noexcept
|
||||
{
|
||||
return header_.inputs;
|
||||
}
|
||||
|
||||
uint32_t runtime_module::outputs_size() const noexcept
|
||||
{
|
||||
return header_.outputs;
|
||||
}
|
||||
|
||||
const memory_range &runtime_module::input_desc(size_t index) const noexcept
|
||||
{
|
||||
assert(index < input_tensors_.size());
|
||||
return input_tensors_[index].range;
|
||||
}
|
||||
|
||||
const memory_range &runtime_module::output_desc(size_t index) const noexcept
|
||||
{
|
||||
assert(index < output_tensors_.size());
|
||||
return output_tensors_[index].range;
|
||||
}
|
||||
|
||||
const runtime_shape_t &runtime_module::input_shape(size_t index) const noexcept
|
||||
{
|
||||
assert(index < input_tensors_.size());
|
||||
return input_tensors_[index].shape;
|
||||
}
|
||||
|
||||
const runtime_shape_t &runtime_module::output_shape(size_t index) const noexcept
|
||||
{
|
||||
assert(index < output_tensors_.size());
|
||||
return output_tensors_[index].shape;
|
||||
}
|
||||
|
||||
result<void> runtime_module::initialize(const module_header &header, interpreter &interp) noexcept
|
||||
result<void> runtime_module::initialize(gsl::span<const gsl::byte> payload, interpreter &interp) noexcept
|
||||
{
|
||||
interp_ = &interp;
|
||||
header_ = header;
|
||||
span_reader reader(gsl::make_span(reinterpret_cast<const gsl::byte *>(&header) + sizeof(module_header), header.size));
|
||||
span_reader reader(payload);
|
||||
reader.read(header_);
|
||||
|
||||
try
|
||||
{
|
||||
input_tensors_.resize(inputs_size());
|
||||
output_tensors_.resize(outputs_size());
|
||||
mempools_.resize(mempools_size());
|
||||
mempools_.resize(header_.mempools);
|
||||
shared_mempools_.resize(header_.shared_mempools);
|
||||
functions_.resize(header_.functions);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
|
@ -166,204 +123,38 @@ result<void> runtime_module::initialize(const module_header &header, interpreter
|
|||
for (auto &desc : mempools_)
|
||||
reader.read(desc);
|
||||
|
||||
auto read_shape = [&](runtime_shape_t &shape) {
|
||||
shape.resize(reader.read<uint32_t>());
|
||||
for (auto &dim : shape)
|
||||
dim = reader.read<uint32_t>();
|
||||
};
|
||||
// shared mempools
|
||||
for (auto &desc : shared_mempools_)
|
||||
reader.read(desc);
|
||||
|
||||
// inputs
|
||||
for (auto &in : input_tensors_)
|
||||
reader.read(in.range);
|
||||
for (auto &in : input_tensors_)
|
||||
read_shape(in.shape);
|
||||
span_reader func_reader(read_functions(reader, header_.functions));
|
||||
runtime_module_init_context_impl init_context(header_, interp, read_sections(reader, header_.sections));
|
||||
try_(initialize_before_functions(init_context));
|
||||
|
||||
// outputs
|
||||
for (auto &out : output_tensors_)
|
||||
reader.read(out.range);
|
||||
for (auto &out : output_tensors_)
|
||||
read_shape(out.shape);
|
||||
|
||||
runtime_module_init_context_impl init_context(header, interp, reader.read_avail());
|
||||
return initialize_core(init_context);
|
||||
}
|
||||
|
||||
#define INOUT_TENSOR_GETTER_IMPL(name) \
|
||||
CHECK_WITH_ERR(index < name##_tensors_.size(), std::errc::result_out_of_range); \
|
||||
\
|
||||
auto &info = name##_tensors_[index]; \
|
||||
if (info.bind_tensor.empty()) \
|
||||
{ \
|
||||
try_set(info.bind_tensor, allocate_##name##_tensor(index)); \
|
||||
} \
|
||||
return ok(info.bind_tensor);
|
||||
|
||||
result<runtime_tensor> runtime_module::input_tensor(size_t index) noexcept
|
||||
{
|
||||
INOUT_TENSOR_GETTER_IMPL(input);
|
||||
}
|
||||
|
||||
result<runtime_tensor> runtime_module::output_tensor(size_t index) noexcept
|
||||
{
|
||||
INOUT_TENSOR_GETTER_IMPL(output);
|
||||
}
|
||||
|
||||
#define DEV_INOUT_TENSOR_GETTER_IMPL(name) \
|
||||
CHECK_WITH_ERR(index < name##_tensors_.size(), std::errc::result_out_of_range); \
|
||||
\
|
||||
auto &info = name##_tensors_[index]; \
|
||||
if (info.bind_tensor.empty()) \
|
||||
{ \
|
||||
try_set(info.bind_tensor, allocate_##name##_tensor(index)); \
|
||||
} \
|
||||
if (!info.device_tensor.empty()) \
|
||||
{ \
|
||||
return ok(info.device_tensor); \
|
||||
} \
|
||||
return ok(info.bind_tensor);
|
||||
|
||||
result<runtime_tensor> runtime_module::device_input_tensor(size_t index) noexcept
|
||||
{
|
||||
DEV_INOUT_TENSOR_GETTER_IMPL(input);
|
||||
}
|
||||
|
||||
result<runtime_tensor> runtime_module::device_output_tensor(size_t index) noexcept
|
||||
{
|
||||
DEV_INOUT_TENSOR_GETTER_IMPL(output);
|
||||
}
|
||||
|
||||
result<void> runtime_module::input_tensor(size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
CHECK_WITH_ERR(!tensor.empty(), std::errc::invalid_argument);
|
||||
CHECK_WITH_ERR(index < input_tensors_.size(), std::errc::result_out_of_range);
|
||||
|
||||
auto &info = input_tensors_[index];
|
||||
CHECK_WITH_ERR(info.range.datatype == tensor.datatype(), nncase_errc::datatype_mismatch);
|
||||
CHECK_WITH_ERR(info.shape == tensor.shape(), nncase_errc::shape_mismatch);
|
||||
|
||||
if (info.bind_tensor != tensor)
|
||||
for (size_t i = 0; i < header_.functions; i++)
|
||||
{
|
||||
if (validate_input_tensor(index, tensor).is_err())
|
||||
{
|
||||
auto device_tensor = info.device_tensor;
|
||||
if (device_tensor.empty())
|
||||
try_var(device_tensor, allocate_input_tensor(index));
|
||||
if (!tensor.can_copy_to_without_staging(device_tensor))
|
||||
{
|
||||
try_set(info.staging_tensor, host_runtime_tensor::create(info.range.datatype, info.shape));
|
||||
}
|
||||
else
|
||||
{
|
||||
info.staging_tensor.reset();
|
||||
}
|
||||
|
||||
info.device_tensor = device_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
info.device_tensor.reset();
|
||||
info.staging_tensor.reset();
|
||||
}
|
||||
|
||||
info.bind_tensor = tensor;
|
||||
auto func_size = func_reader.peek_with_offset<decltype(function_header::size)>(offsetof(function_header, size));
|
||||
auto payload = func_reader.read_span(func_size);
|
||||
try_var(func, create_function());
|
||||
try_(func->initialize(payload, init_context));
|
||||
functions_[i] = std::move(func);
|
||||
}
|
||||
|
||||
return initialize_after_functions(init_context);
|
||||
}
|
||||
|
||||
result<runtime_function *> runtime_module::find_function_by_id(size_t index) noexcept
|
||||
{
|
||||
CHECK_WITH_ERR(index < functions_.size(), std::errc::result_out_of_range);
|
||||
return ok(functions_[index].get());
|
||||
}
|
||||
|
||||
result<void> runtime_module::initialize_before_functions(NNCASE_UNUSED runtime_module_init_context &context) noexcept
|
||||
{
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> runtime_module::output_tensor(size_t index, runtime_tensor tensor) noexcept
|
||||
{
|
||||
CHECK_WITH_ERR(!tensor.empty(), std::errc::invalid_argument);
|
||||
CHECK_WITH_ERR(index < output_tensors_.size(), std::errc::result_out_of_range);
|
||||
|
||||
auto &info = output_tensors_[index];
|
||||
CHECK_WITH_ERR(info.range.datatype == tensor.datatype(), nncase_errc::datatype_mismatch);
|
||||
CHECK_WITH_ERR(info.shape == tensor.shape(), nncase_errc::shape_mismatch);
|
||||
|
||||
if (info.bind_tensor != tensor)
|
||||
{
|
||||
if (validate_output_tensor(index, tensor).is_err())
|
||||
{
|
||||
auto device_tensor = info.device_tensor;
|
||||
if (device_tensor.empty())
|
||||
try_var(device_tensor, allocate_output_tensor(index));
|
||||
if (!device_tensor.can_copy_to_without_staging(tensor))
|
||||
{
|
||||
try_set(info.staging_tensor, host_runtime_tensor::create(info.range.datatype, info.shape));
|
||||
}
|
||||
else
|
||||
{
|
||||
info.staging_tensor.reset();
|
||||
}
|
||||
|
||||
info.device_tensor = device_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
info.device_tensor.reset();
|
||||
info.staging_tensor.reset();
|
||||
}
|
||||
|
||||
info.bind_tensor = tensor;
|
||||
}
|
||||
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> runtime_module::run() noexcept
|
||||
{
|
||||
// 1. Ensure bindings
|
||||
for (size_t i = 0; i < input_tensors_.size(); i++)
|
||||
{
|
||||
auto &info = input_tensors_[i];
|
||||
if (info.bind_tensor.empty())
|
||||
try_set(info.bind_tensor, allocate_input_tensor(i));
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < output_tensors_.size(); i++)
|
||||
{
|
||||
auto &info = output_tensors_[i];
|
||||
if (info.bind_tensor.empty())
|
||||
try_set(info.bind_tensor, allocate_output_tensor(i));
|
||||
}
|
||||
|
||||
// 2. Copy inputs
|
||||
for (auto &in : input_tensors_)
|
||||
{
|
||||
if (in.staging_tensor.empty())
|
||||
{
|
||||
if (!in.device_tensor.empty())
|
||||
try_(in.bind_tensor.copy_to(in.device_tensor));
|
||||
}
|
||||
else
|
||||
{
|
||||
try_(in.bind_tensor.copy_to(in.staging_tensor));
|
||||
try_(in.staging_tensor.copy_to(in.device_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Run
|
||||
try_(run_core());
|
||||
|
||||
// 4. Copy outputs
|
||||
for (auto &out : output_tensors_)
|
||||
{
|
||||
if (out.staging_tensor.empty())
|
||||
{
|
||||
if (!out.device_tensor.empty())
|
||||
try_(out.device_tensor.copy_to(out.bind_tensor));
|
||||
}
|
||||
else
|
||||
{
|
||||
try_(out.device_tensor.copy_to(out.staging_tensor));
|
||||
try_(out.staging_tensor.copy_to(out.bind_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> runtime_module::initialize_inter_modules(NNCASE_UNUSED interpreter &interp) noexcept
|
||||
result<void> runtime_module::initialize_after_functions(NNCASE_UNUSED runtime_module_init_context &context) noexcept
|
||||
{
|
||||
return ok();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "section.h"
|
||||
#include <nncase/runtime/span_reader.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
|
||||
gsl::span<const gsl::byte> runtime::find_section(const char *name, gsl::span<const gsl::byte> sections) noexcept
|
||||
{
|
||||
span_reader reader(sections);
|
||||
while (!reader.empty())
|
||||
{
|
||||
auto header = reader.get_ref<section_header>();
|
||||
if (!strncmp(header->name, name, MAX_SECTION_NAME_LENGTH))
|
||||
{
|
||||
gsl::span<const gsl::byte> result;
|
||||
if (header->flags & SECTION_MERGED_INTO_RDATA)
|
||||
{
|
||||
auto rdata_span = find_section(".rdata", sections);
|
||||
result = rdata_span.subspan(header->body_start, header->body_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
result = reader.read_avail().subspan(header->body_start, header->body_size);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!(header->flags & SECTION_MERGED_INTO_RDATA))
|
||||
reader.skip((size_t)header->body_start + header->body_size);
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
gsl::span<const gsl::byte> runtime::read_sections(span_reader &sr, size_t sections) noexcept
|
||||
{
|
||||
auto nest_sr = sr;
|
||||
size_t size = 0;
|
||||
|
||||
for (size_t i = 0; i < sections; i++)
|
||||
{
|
||||
auto header = nest_sr.get_ref<section_header>();
|
||||
size += sizeof(section_header);
|
||||
|
||||
if (!(header->flags & SECTION_MERGED_INTO_RDATA))
|
||||
{
|
||||
auto to_skip = (size_t)header->body_start + header->body_size;
|
||||
nest_sr.skip(to_skip);
|
||||
size += to_skip;
|
||||
}
|
||||
}
|
||||
|
||||
return sr.read_span(size);
|
||||
}
|
|
@ -13,12 +13,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "../runtime_module.h"
|
||||
#include <nncase/kernels/kernel_context.h>
|
||||
NNCASE_MODULES_K210_API
|
||||
#include <nncase/runtime/model.h>
|
||||
#include <nncase/runtime/result.h>
|
||||
#include <nncase/runtime/span_reader.h>
|
||||
|
||||
struct NNCASE_API k210_kernel_context : public kernels::kernel_context
|
||||
{
|
||||
};
|
||||
BEGIN_NS_NNCASE_RUNTIME
|
||||
|
||||
END_NS_NNCASE_KERNELS_K210
|
||||
gsl::span<const gsl::byte> find_section(const char *name, gsl::span<const gsl::byte> sections) noexcept;
|
||||
gsl::span<const gsl::byte> read_sections(span_reader &sr, size_t sections) noexcept;
|
||||
|
||||
END_NS_NNCASE_RUNTIME
|
|
@ -1,6 +1,7 @@
|
|||
cmake_minimum_required (VERSION 3.13)
|
||||
|
||||
set(SRCS runtime_module.cpp
|
||||
runtime_function.cpp
|
||||
op_reader.cpp
|
||||
evaluate_stack.cpp
|
||||
ops/control.cpp
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/7/14 19:17:48 +08:00.
|
||||
/* This file is generated by tools/stackvm_gen/IsaGen at 2021/8/11 17:40:11 +08:00.
|
||||
*
|
||||
* Copyright 2019-2021 Canaan Inc.
|
||||
*
|
||||
|
|
|
@ -12,23 +12,23 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const nop_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const nop_op_t &op) noexcept
|
||||
{
|
||||
return ok();
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const br_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const br_op_t &op) noexcept
|
||||
{
|
||||
return pc_relative(op.target);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const br_true_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const br_true_op_t &op) noexcept
|
||||
{
|
||||
try_var(value, stack_.pop());
|
||||
if (value.as_i())
|
||||
|
@ -36,7 +36,7 @@ result<void> stackvm_runtime_module::visit(const br_true_op_t &op) noexcept
|
|||
return ok();
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const br_false_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const br_false_op_t &op) noexcept
|
||||
{
|
||||
try_var(value, stack_.pop());
|
||||
if (!value.as_i())
|
||||
|
@ -44,7 +44,7 @@ result<void> stackvm_runtime_module::visit(const br_false_op_t &op) noexcept
|
|||
return ok();
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ret_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ret_op_t &op) noexcept
|
||||
{
|
||||
if (call_depth_ == 0)
|
||||
{
|
||||
|
@ -57,22 +57,22 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ret_op_t &op) noe
|
|||
return pc(target.as_u());
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const call_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const call_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ecall_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ecall_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const throw_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const throw_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const break_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const break_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
|
@ -25,52 +25,52 @@ using namespace nncase::runtime::stackvm;
|
|||
else \
|
||||
return stack_.push((type)value.as_r())
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_i1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_i1_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(int8_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_i2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_i2_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(int16_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_i4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_i4_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(int32_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_i_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_i_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(intptr_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_u1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_u1_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(uint8_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_u2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_u2_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(uint16_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_u4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_u4_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(uint32_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_u_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(uintptr_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_br2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_br2_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(bfloat16);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const conv_r4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const conv_r4_op_t &op) noexcept
|
||||
{
|
||||
CONV_IMPL(float);
|
||||
}
|
||||
|
|
|
@ -12,33 +12,33 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const ldc_i4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const ldc_i4_op_t &op) noexcept
|
||||
{
|
||||
return stack_.push(op.imm);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldnull_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldnull_op_t &op) noexcept
|
||||
{
|
||||
return stack_.push((uintptr_t)0);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldc_i4_0_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldc_i4_0_op_t &op) noexcept
|
||||
{
|
||||
return stack_.push((int32_t)0);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldc_i4_1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldc_i4_1_op_t &op) noexcept
|
||||
{
|
||||
return stack_.push((int32_t)1);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const ldc_r4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const ldc_r4_op_t &op) noexcept
|
||||
{
|
||||
return stack_.push(op.imm);
|
||||
}
|
||||
|
@ -49,52 +49,52 @@ result<void> stackvm_runtime_module::visit(const ldc_r4_op_t &op) noexcept
|
|||
return err(std::errc::bad_address); \
|
||||
return stack_.push(*reinterpret_cast<const type *>(addr.as_u()))
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_i1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_i1_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(int8_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_i2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_i2_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(int16_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_i4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_i4_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(int32_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_i_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_i_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(intptr_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_u1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_u1_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(uint8_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_u2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_u2_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(uint16_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_u4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_u4_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(uint32_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_u_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(uintptr_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_br2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_br2_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(bfloat16);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_r4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldind_r4_op_t &op) noexcept
|
||||
{
|
||||
LDINDIMPL(float);
|
||||
}
|
||||
|
@ -107,44 +107,43 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldind_r4_op_t &op
|
|||
*reinterpret_cast<decltype(value.as_##type()) *>(addr.as_u()) = value.as_##type(); \
|
||||
return ok()
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_i1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_i1_op_t &op) noexcept
|
||||
{
|
||||
STINDIMPL(i1);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_i2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_i2_op_t &op) noexcept
|
||||
{
|
||||
STINDIMPL(i2);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_i4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_i4_op_t &op) noexcept
|
||||
{
|
||||
STINDIMPL(i4);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_i_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_i_op_t &op) noexcept
|
||||
{
|
||||
STINDIMPL(i);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_br2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_br2_op_t &op) noexcept
|
||||
{
|
||||
STINDIMPL(br2);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stind_r4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stind_r4_op_t &op) noexcept
|
||||
{
|
||||
STINDIMPL(r4);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const lea_gp_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const lea_gp_op_t &op) noexcept
|
||||
{
|
||||
if (op.gpid >= regs_.size())
|
||||
return err(std::errc::result_out_of_range);
|
||||
return stack_.push((intptr_t)regs_[op.gpid] + op.offset);
|
||||
try_var(reg, module().reg(op.gpid));
|
||||
return stack_.push((intptr_t)reg + op.offset);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const lea_buffer_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const lea_buffer_op_t &op) noexcept
|
||||
{
|
||||
#define ID_NOT_FOUND ((size_t)-1)
|
||||
// TODO: use subres
|
||||
|
@ -206,13 +205,13 @@ result<void> stackvm_runtime_module::visit(const lea_buffer_op_t &op) noexcept
|
|||
}
|
||||
else if (op.location == mem_rdata)
|
||||
{
|
||||
auto buffer = rdata_.subspan(op.offset);
|
||||
auto buffer = module().rdata().subspan(op.offset);
|
||||
return stack_.push((uintptr_t)buffer.data());
|
||||
}
|
||||
else if (op.location == mem_data)
|
||||
{
|
||||
auto buffer = data_.get() + op.offset;
|
||||
return stack_.push((uintptr_t)buffer);
|
||||
auto buffer = module().data().subspan(op.offset);
|
||||
return stack_.push((uintptr_t)buffer.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -225,52 +224,52 @@ result<void> stackvm_runtime_module::visit(const lea_buffer_op_t &op) noexcept
|
|||
try_var(addr, stack_.pop()); \
|
||||
return stack_.push(reinterpret_cast<const type *>(addr.as_u())[offset.as_u()])
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_i1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_i1_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(int8_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_i2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_i2_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(int16_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_i4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_i4_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(int32_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_i_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_i_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(intptr_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_u1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_u1_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(uint8_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_u2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_u2_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(uint16_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_u4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_u4_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(uint32_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_u_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(uintptr_t);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_br2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_br2_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(bfloat16);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_r4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldelem_r4_op_t &op) noexcept
|
||||
{
|
||||
LDELEM_IMPL(float);
|
||||
}
|
||||
|
@ -282,116 +281,110 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldelem_r4_op_t &o
|
|||
reinterpret_cast<decltype(value.as_##type()) *>(addr.as_u())[offset.as_u()] = value.as_##type(); \
|
||||
return ok()
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_i1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_i1_op_t &op) noexcept
|
||||
{
|
||||
STELEM_IMPL(i1);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_i2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_i2_op_t &op) noexcept
|
||||
{
|
||||
STELEM_IMPL(i2);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_i4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_i4_op_t &op) noexcept
|
||||
{
|
||||
STELEM_IMPL(i4);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_i_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_i_op_t &op) noexcept
|
||||
{
|
||||
STELEM_IMPL(i);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_br2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_br2_op_t &op) noexcept
|
||||
{
|
||||
STELEM_IMPL(br2);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const stelem_r4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const stelem_r4_op_t &op) noexcept
|
||||
{
|
||||
STELEM_IMPL(r4);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_0_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_0_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_1_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_1_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_2_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_2_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_3_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_3_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_4_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_4_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ldarg_5_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ldarg_5_op_t &op) noexcept
|
||||
{
|
||||
return err(std::errc::not_supported);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const stshape_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const stshape_op_t &op) noexcept
|
||||
{
|
||||
if (op.rshape >= shape_regs_.size())
|
||||
shape_regs_.resize(op.rshape + 1);
|
||||
|
||||
auto ® = shape_regs_[op.rshape];
|
||||
runtime_shape_t shape;
|
||||
try
|
||||
{
|
||||
reg.resize(op.rank);
|
||||
shape.resize(op.rank);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return err(std::errc::not_enough_memory);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < reg.size(); i++)
|
||||
for (size_t i = 0; i < shape.size(); i++)
|
||||
{
|
||||
try_var(dim, stack_.pop());
|
||||
reg[op.rank - i - 1] = (size_t)dim.as_u();
|
||||
shape[op.rank - i - 1] = (size_t)dim.as_u();
|
||||
}
|
||||
|
||||
return ok();
|
||||
return module().shape_reg(op.rshape, std::move(shape));
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const stpaddings_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const stpaddings_op_t &op) noexcept
|
||||
{
|
||||
if (op.rpaddings >= paddings_regs_.size())
|
||||
paddings_regs_.resize(op.rpaddings + 1);
|
||||
|
||||
auto ® = paddings_regs_[op.rpaddings];
|
||||
runtime_paddings_t paddings;
|
||||
try
|
||||
{
|
||||
reg.resize(op.rank);
|
||||
paddings.resize(op.rank);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
return err(std::errc::not_enough_memory);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < reg.size(); i++)
|
||||
for (size_t i = 0; i < paddings.size(); i++)
|
||||
{
|
||||
try_var(after, stack_.pop());
|
||||
try_var(before, stack_.pop());
|
||||
reg[op.rank - i - 1] = { before.as_i4(), after.as_i4() };
|
||||
paddings[op.rank - i - 1] = { before.as_i4(), after.as_i4() };
|
||||
}
|
||||
|
||||
return ok();
|
||||
return module().paddings_reg(op.rpaddings, std::move(paddings));
|
||||
}
|
||||
|
|
|
@ -12,13 +12,13 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const neg_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const neg_op_t &op) noexcept
|
||||
{
|
||||
try_var(value, stack_.pop());
|
||||
if (!value.is_real())
|
||||
|
@ -53,73 +53,73 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const neg_op_t &op) noe
|
|||
try_var(a, stack_.pop()); \
|
||||
return stack_.push(a.as_u() op b.as_u());
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const add_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const add_op_t &op) noexcept
|
||||
{
|
||||
BINARY_IMPL(+);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const sub_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const sub_op_t &op) noexcept
|
||||
{
|
||||
BINARY_IMPL(-);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const mul_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const mul_op_t &op) noexcept
|
||||
{
|
||||
BINARY_IMPL(*);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const div_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const div_op_t &op) noexcept
|
||||
{
|
||||
BINARY_IMPL(/);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const div_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const div_u_op_t &op) noexcept
|
||||
{
|
||||
BINARY_U_IMPL(/);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const rem_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const rem_op_t &op) noexcept
|
||||
{
|
||||
BINARY_IMPL(/);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const rem_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const rem_u_op_t &op) noexcept
|
||||
{
|
||||
BINARY_U_IMPL(/);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const and_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const and_op_t &op) noexcept
|
||||
{
|
||||
BINARY_BIT_U_IMPL(&);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const or_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const or_op_t &op) noexcept
|
||||
{
|
||||
BINARY_BIT_U_IMPL(|);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const xor_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const xor_op_t &op) noexcept
|
||||
{
|
||||
BINARY_BIT_U_IMPL(^);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const not_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const not_op_t &op) noexcept
|
||||
{
|
||||
try_var(value, stack_.pop());
|
||||
return stack_.push(~value.as_u());
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const shl_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const shl_op_t &op) noexcept
|
||||
{
|
||||
BINARY_BIT_U_IMPL(<<);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const shr_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const shr_op_t &op) noexcept
|
||||
{
|
||||
BINARY_BIT_IMPL(>>);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const shr_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const shr_u_op_t &op) noexcept
|
||||
{
|
||||
BINARY_BIT_U_IMPL(>>);
|
||||
}
|
||||
|
@ -140,52 +140,52 @@ result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const shr_u_op_t &op) n
|
|||
else \
|
||||
return stack_.push(a.as_r() op b.as_r() ? 1 : 0)
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const clt_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const clt_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_IMPL(<);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const clt_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const clt_u_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_U_IMPL(<);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cle_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cle_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_IMPL(<=);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cle_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cle_u_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_U_IMPL(<=);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const ceq_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const ceq_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_U_IMPL(==);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cge_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cge_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_IMPL(>=);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cge_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cge_u_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_U_IMPL(>=);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cgt_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cgt_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_IMPL(>);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cgt_u_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cgt_u_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_U_IMPL(>);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const cne_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const cne_op_t &op) noexcept
|
||||
{
|
||||
COMPARE_U_IMPL(!=);
|
||||
}
|
||||
|
|
|
@ -12,19 +12,19 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const dup_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const dup_op_t &op) noexcept
|
||||
{
|
||||
try_var(entry, stack_.peek());
|
||||
return stack_.push(entry);
|
||||
}
|
||||
|
||||
result<void> stackvm_runtime_module::visit(NNCASE_UNUSED const pop_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(NNCASE_UNUSED const pop_op_t &op) noexcept
|
||||
{
|
||||
try_(stack_.pop());
|
||||
return ok();
|
||||
|
|
|
@ -12,23 +12,23 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_batch_to_space_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_batch_to_space_op_t &op) noexcept
|
||||
{
|
||||
try_var(output, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
auto &in_shape = shape_regs_[op.rshape_src];
|
||||
auto &block_shape = shape_regs_[op.rshape_block];
|
||||
auto &crops = paddings_regs_[op.rpad_crops];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(in_shape, module().shape_reg(op.rshape_src));
|
||||
try_var(block_shape, module().shape_reg(op.rshape_block));
|
||||
try_var(crops, module().paddings_reg(op.rpad_crops));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
return kernels::batch_to_space(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output),
|
||||
in_shape, block_shape, crops, in_strides, out_strides, kernel_context());
|
||||
in_shape, block_shape, crops, in_strides, out_strides, module().kernel_context());
|
||||
}
|
||||
|
|
|
@ -12,24 +12,24 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_binary_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_binary_op_t &op) noexcept
|
||||
{
|
||||
try_var(output, pop_addr());
|
||||
try_var(input_b, pop_addr());
|
||||
try_var(input_a, pop_addr());
|
||||
auto &in_a_shape = shape_regs_[op.rshape_src1];
|
||||
auto &in_a_strides = shape_regs_[op.rstride_src1];
|
||||
auto &in_b_shape = shape_regs_[op.rshape_src2];
|
||||
auto &in_b_strides = shape_regs_[op.rstride_src2];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(in_a_shape, module().shape_reg(op.rshape_src1));
|
||||
try_var(in_a_strides, module().shape_reg(op.rstride_src1));
|
||||
try_var(in_b_shape, module().shape_reg(op.rshape_src2));
|
||||
try_var(in_b_strides, module().shape_reg(op.rstride_src2));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
return kernels::binary(op.binary_op, reinterpret_cast<const float *>(input_a), reinterpret_cast<const float *>(input_b),
|
||||
reinterpret_cast<float *>(output), in_a_shape, in_a_strides, in_b_shape, in_b_strides, out_strides, { op.fused_clamp_low, op.fused_clamp_high }, kernel_context());
|
||||
reinterpret_cast<float *>(output), in_a_shape, in_a_strides, in_b_shape, in_b_strides, out_strides, { op.fused_clamp_low, op.fused_clamp_high }, module().kernel_context());
|
||||
}
|
||||
|
|
|
@ -12,22 +12,22 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_broadcast_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_broadcast_op_t &op) noexcept
|
||||
{
|
||||
try_var(output, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
auto &in_shape = shape_regs_[op.rshape_src];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &out_shape = shape_regs_[op.rshape_dest];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(in_shape, module().shape_reg(op.rshape_src));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(out_shape, module().shape_reg(op.rshape_dest));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
return kernels::broadcast(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output),
|
||||
in_shape, in_strides, out_shape, out_strides, kernel_context());
|
||||
in_shape, in_strides, out_shape, out_strides, module().kernel_context());
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
|
@ -21,15 +21,16 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_call_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_call_op_t &op) noexcept
|
||||
{
|
||||
try_var(mod, interp().find_module_by_id(op.module_id));
|
||||
try_var(mod, module().interp().find_module_by_id(op.module_id));
|
||||
try_var(func, mod->find_function_by_id(op.function_id));
|
||||
|
||||
auto create_tensor = [&]() -> result<runtime_tensor> {
|
||||
try_var(rstrides, stack_.pop());
|
||||
auto &strides = shape_regs_[rstrides.as_u4()];
|
||||
try_var(strides, module().shape_reg(rstrides.as_u4()));
|
||||
try_var(rshape, stack_.pop());
|
||||
auto &shape = shape_regs_[rshape.as_u4()];
|
||||
try_var(shape, module().shape_reg(rshape.as_u4()));
|
||||
try_var(e_datatype, stack_.pop());
|
||||
try_var(addr, pop_addr());
|
||||
|
||||
|
@ -40,14 +41,14 @@ result<void> stackvm_runtime_module::visit(const tensor_call_op_t &op) noexcept
|
|||
for (uint8_t i = 0; i < op.num_dst; i++)
|
||||
{
|
||||
try_var(tensor, create_tensor());
|
||||
try_(mod->output_tensor(op.num_dst - i - 1, tensor));
|
||||
try_(func->output_tensor((size_t)op.num_dst - i - 1, tensor));
|
||||
}
|
||||
|
||||
for (uint8_t i = 0; i < op.num_src; i++)
|
||||
{
|
||||
try_var(tensor, create_tensor());
|
||||
try_(mod->input_tensor(op.num_src - i - 1, tensor));
|
||||
try_(func->input_tensor((size_t)op.num_src - i - 1, tensor));
|
||||
}
|
||||
|
||||
return mod->run();
|
||||
return func->invoke();
|
||||
}
|
||||
|
|
|
@ -12,14 +12,14 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/convolution.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_conv2d_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_conv2d_op_t &op) noexcept
|
||||
{
|
||||
try_var(padding_w, pop_padding());
|
||||
try_var(padding_h, pop_padding());
|
||||
|
@ -27,16 +27,16 @@ result<void> stackvm_runtime_module::visit(const tensor_conv2d_op_t &op) noexcep
|
|||
try_var(bias, pop_addr());
|
||||
try_var(weights, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
auto &in_shape = shape_regs_[op.rshape_src];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &w_shape = shape_regs_[op.rshape_kernel];
|
||||
auto &w_strides = shape_regs_[op.rstride_kernel];
|
||||
auto &bias_strides = shape_regs_[op.rstride_bias];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(in_shape, module().shape_reg(op.rshape_src));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(w_shape, module().shape_reg(op.rshape_kernel));
|
||||
try_var(w_strides, module().shape_reg(op.rstride_kernel));
|
||||
try_var(bias_strides, module().shape_reg(op.rstride_bias));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
if (op.datatype != dt_float32)
|
||||
return err(nncase_errc::datatype_mismatch);
|
||||
return kernels::conv2d(reinterpret_cast<const float *>(input), reinterpret_cast<const float *>(weights),
|
||||
reinterpret_cast<const float *>(bias), reinterpret_cast<float *>(output), in_shape, in_strides, w_shape, w_strides, bias_strides, out_strides,
|
||||
padding_h, padding_w, op.groups, op.stride_h, op.stride_w, op.dilation_h, op.dilation_w, { op.fused_clamp_low, op.fused_clamp_high }, kernel_context());
|
||||
padding_h, padding_w, op.groups, op.stride_h, op.stride_w, op.dilation_h, op.dilation_w, { op.fused_clamp_low, op.fused_clamp_high }, module().kernel_context());
|
||||
}
|
||||
|
|
|
@ -12,20 +12,20 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_convert_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_convert_op_t &op) noexcept
|
||||
{
|
||||
try_var(output, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
auto &shape = shape_regs_[op.rshape_src];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(shape, module().shape_reg(op.rshape_src));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
return kernels::convert(op.in_datatype, op.dst_datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, kernel_context());
|
||||
return kernels::convert(op.in_datatype, op.dst_datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, module().kernel_context());
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
|
@ -21,13 +21,13 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_copy_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_copy_op_t &op) noexcept
|
||||
{
|
||||
try_var(output, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
auto &shape = shape_regs_[op.rshape];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(shape, module().shape_reg(op.rshape));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
return kernels::copy(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, kernel_context());
|
||||
return kernels::copy(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, module().kernel_context());
|
||||
}
|
||||
|
|
|
@ -12,24 +12,24 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_dequantize_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_dequantize_op_t &op) noexcept
|
||||
{
|
||||
try_var(bias, stack_.pop());
|
||||
try_var(scale, stack_.pop());
|
||||
try_var(output, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
|
||||
auto &shape = shape_regs_[op.rshape_src];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(shape, module().shape_reg(op.rshape_src));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
return kernels::dequantize(op.in_datatype, op.dst_datatype, reinterpret_cast<const gsl::byte *>(input),
|
||||
reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, scale.as_r4(), bias.as_r4(), kernel_context());
|
||||
reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, scale.as_r4(), bias.as_r4(), module().kernel_context());
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
|
@ -21,17 +21,17 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_gather_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_gather_op_t &op) noexcept
|
||||
{
|
||||
try_var(indices, pop_addr());
|
||||
try_var(output, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
|
||||
auto &in_shape = shape_regs_[op.rshape_src];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &out_shape = shape_regs_[op.rshape_dest];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
auto &indices_shape = shape_regs_[op.rshape_indices];
|
||||
try_var(in_shape, module().shape_reg(op.rshape_src));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(out_shape, module().shape_reg(op.rshape_dest));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
try_var(indices_shape, module().shape_reg(op.rshape_indices));
|
||||
|
||||
return kernels::gather(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), in_shape, out_shape,
|
||||
in_strides, out_strides, reinterpret_cast<const int32_t *>(indices), indices_shape, op.axis);
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
|
@ -21,17 +21,17 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_gather_nd_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_gather_nd_op_t &op) noexcept
|
||||
{
|
||||
try_var(indices, pop_addr());
|
||||
try_var(output, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
|
||||
auto &in_shape = shape_regs_[op.rshape_src];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &out_shape = shape_regs_[op.rshape_dest];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
auto &indices_shape = shape_regs_[op.rshape_indices];
|
||||
try_var(in_shape, module().shape_reg(op.rshape_src));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(out_shape, module().shape_reg(op.rshape_dest));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
try_var(indices_shape, module().shape_reg(op.rshape_indices));
|
||||
|
||||
return kernels::gather_nd(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<gsl::byte *>(output), in_shape, out_shape,
|
||||
in_strides, out_strides, reinterpret_cast<const int32_t *>(indices), indices_shape, op.batch_dims);
|
||||
|
|
|
@ -12,23 +12,23 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
|
||||
using namespace nncase;
|
||||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_lut1d_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_lut1d_op_t &op) noexcept
|
||||
{
|
||||
try_var(max_value, pop_scalar(op.datatype));
|
||||
try_var(min_value, pop_scalar(op.datatype));
|
||||
try_var(output, pop_addr());
|
||||
try_var(table, pop_addr());
|
||||
try_var(input, pop_addr());
|
||||
auto &shape = shape_regs_[op.rshape_src];
|
||||
auto &in_strides = shape_regs_[op.rstride_src];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(shape, module().shape_reg(op.rshape_src));
|
||||
try_var(in_strides, module().shape_reg(op.rstride_src));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
return kernels::lut1d(op.datatype, reinterpret_cast<const gsl::byte *>(input), reinterpret_cast<const gsl::byte *>(table),
|
||||
reinterpret_cast<gsl::byte *>(output), shape, in_strides, out_strides, min_value, max_value);
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "../runtime_module.h"
|
||||
#include "../runtime_function.h"
|
||||
#include <nncase/kernels/tensor_compute.h>
|
||||
#include <nncase/runtime/interpreter.h>
|
||||
#include <nncase/runtime/runtime_op_utility.h>
|
||||
|
@ -22,7 +22,7 @@ using namespace nncase;
|
|||
using namespace nncase::runtime;
|
||||
using namespace nncase::runtime::stackvm;
|
||||
|
||||
result<void> stackvm_runtime_module::visit(const tensor_onehot_op_t &op) noexcept
|
||||
result<void> stackvm_runtime_function::visit(const tensor_onehot_op_t &op) noexcept
|
||||
{
|
||||
try_var(output, pop_addr());
|
||||
try_var(off_value, pop_addr());
|
||||
|
@ -30,11 +30,11 @@ result<void> stackvm_runtime_module::visit(const tensor_onehot_op_t &op) noexcep
|
|||
try_var(depth, pop_addr());
|
||||
try_var(indices, pop_addr());
|
||||
|
||||
auto &indices_shape = shape_regs_[op.rshape_indices];
|
||||
auto &out_shape = shape_regs_[op.rshape_dest];
|
||||
auto &out_strides = shape_regs_[op.rstride_dest];
|
||||
try_var(indices_shape, module().shape_reg(op.rshape_indices));
|
||||
try_var(out_shape, module().shape_reg(op.rshape_dest));
|
||||
try_var(out_strides, module().shape_reg(op.rstride_dest));
|
||||
|
||||
return kernels::onehot(op.datatype, reinterpret_cast<const int32_t *>(indices), reinterpret_cast<gsl::byte *>(output),
|
||||
indices_shape, out_shape, out_strides, reinterpret_cast<gsl::byte *>(depth), reinterpret_cast<gsl::byte *>(off_value),
|
||||
reinterpret_cast<gsl::byte *>(on_value), op.axis, op.onehot_mode);
|
||||
reinterpret_cast<gsl::byte *>(on_value), op.axis, op.onehot_mode, module().kernel_context());
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue