// File: kendryte-standalone-sdk/lib/nncase/v1/include/nncase/runtime/runtime_module.h
// (Repository-viewer metadata: 95 lines, 3.6 KiB, C++.)

/* Copyright 2019-2020 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "model.h"
#include "result.h"
#include "runtime_tensor.h"
BEGIN_NS_NNCASE_RUNTIME
class interpreter;
struct NNCASE_API runtime_module_init_context
{
virtual bool is_section_pinned() const noexcept = 0;
virtual interpreter &interp() noexcept = 0;
virtual const module_header &header() noexcept = 0;
virtual gsl::span<const gsl::byte> section(const char *name) noexcept = 0;
};
class NNCASE_API runtime_module
{
private:
struct inout_tensor_info
{
runtime_shape_t shape;
runtime_shape_t strides;
memory_range range;
runtime_tensor bind_tensor;
runtime_tensor staging_tensor;
runtime_tensor device_tensor;
};
public:
static result<std::unique_ptr<runtime_module>> create(const module_type_t &type);
runtime_module() = default;
runtime_module(runtime_module &) = delete;
virtual ~runtime_module() = default;
result<void> initialize(const module_header &header, interpreter &interp) noexcept;
virtual result<void> initialize_inter_modules(interpreter &interp) noexcept;
const module_type_t &type() const noexcept;
interpreter &interp() const noexcept { return *interp_; }
uint32_t mempools_size() const noexcept;
const mempool_desc &mempool(size_t index) const noexcept;
mempool_desc mempool(memory_location_t location) const noexcept;
uint32_t inputs_size() const noexcept;
const runtime_shape_t &input_shape(size_t index) const noexcept;
const memory_range &input_desc(size_t index) const noexcept;
result<runtime_tensor> input_tensor(size_t index) noexcept;
result<void> input_tensor(size_t index, runtime_tensor tensor) noexcept;
uint32_t outputs_size() const noexcept;
const runtime_shape_t &output_shape(size_t index) const noexcept;
const memory_range &output_desc(size_t index) const noexcept;
result<runtime_tensor> output_tensor(size_t index) noexcept;
result<void> output_tensor(size_t index, runtime_tensor tensor) noexcept;
result<void> run() noexcept;
protected:
virtual result<void> initialize_core(runtime_module_init_context &context) noexcept = 0;
virtual result<runtime_tensor> allocate_input_tensor(size_t index) noexcept = 0;
virtual result<runtime_tensor> allocate_output_tensor(size_t index) noexcept = 0;
virtual result<void> validate_input_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
virtual result<void> validate_output_tensor(size_t index, runtime_tensor tensor) noexcept = 0;
result<runtime_tensor> device_input_tensor(size_t index) noexcept;
result<runtime_tensor> device_output_tensor(size_t index) noexcept;
virtual result<void> run_core() noexcept = 0;
private:
module_header header_;
std::vector<mempool_desc> mempools_;
std::vector<inout_tensor_info> input_tensors_;
std::vector<inout_tensor_info> output_tensors_;
interpreter *interp_ = nullptr;
};
END_NS_NNCASE_RUNTIME