[nncase] Update runtime

pull/107/head
sunnycase 2019-11-27 10:38:34 +08:00
parent a3ac928968
commit 1f4842a763
6 changed files with 158 additions and 23 deletions

View File

@ -13,15 +13,15 @@
* limitations under the License.
*/
#pragma once
#include <filesystem>
#include <boost/filesystem.hpp>
#include <fstream>
#include <vector>
namespace nncase
{
inline std::vector<uint8_t> read_file(const std::filesystem::path &filename)
inline std::vector<uint8_t> read_file(const boost::filesystem::path &filename)
{
std::ifstream infile(filename, std::ios::binary | std::ios::in);
std::ifstream infile(filename.c_str(), std::ios::binary | std::ios::in);
if (!infile.good())
throw std::runtime_error("Cannot open file: " + filename.string());

View File

@ -219,6 +219,70 @@ namespace kernels
}
}
/// Reference float implementation of a grouped, dilated 2-D transposed convolution.
///
/// The indexing below treats `input` and `output` as dense 4-D tensors ordered
/// [batch, channel, height, width], and `weights` as
/// [groups][out_channels / groups][in_channels / groups][filter_h][filter_w].
/// Every input element is scattered into the output window it contributes to.
///
/// @param input            Input tensor, read sequentially in the order above.
/// @param output           Output tensor; fully overwritten.
/// @param weights          Filter weights in the layout described above.
/// @param bias             One value per output channel (indexed up to out_shape[1] - 1);
///                         it is the starting value of every element of that channel.
/// @param in_shape         Input shape; indices 0..3 are used as batch, channels, height, width.
/// @param groups           Number of convolution groups; channels are split evenly by integer division.
/// @param out_shape        Output shape, same index meaning as in_shape.
/// @param filter_h,filter_w        Filter extent.
/// @param stride_h,stride_w        Step between the output origins of neighbouring input elements.
/// @param dilation_h,dilation_w    Spacing between filter taps in the output.
/// @param padding_h,padding_w      Only the `before` member is read; it shifts the output origin.
/// @param fused_activation Range handed to details::apply_activation for every output element,
///                         skipped entirely when it equals value_range<float>::full().
inline void conv2d_transpose(const float *input, float *output, const float *weights, const float *bias, const runtime_shape_t &in_shape,
    int32_t groups, const runtime_shape_t &out_shape, int32_t filter_h, int32_t filter_w, int32_t stride_h, int32_t stride_w, int32_t dilation_h, int32_t dilation_w,
    const padding &padding_h, const padding &padding_w, const value_range<float> &fused_activation)
{
    const auto output_size = kernels::details::compute_size(out_shape);
    // Zero the whole buffer first so every element is defined even if it is never seeded or hit below.
    std::fill(output, output + output_size, 0.f);

    const auto g_ic = in_shape[1] / groups;
    const auto g_oc = out_shape[1] / groups;
    const size_t out_plane_size = (size_t)out_shape[2] * out_shape[3];
    const size_t filter_size = (size_t)filter_h * filter_w;

    for (int32_t batch = 0; batch < in_shape[0]; batch++)
    {
        float *out_batch_p = output + (size_t)batch * out_shape[1] * out_plane_size;

        for (int32_t g = 0; g < groups; g++)
        {
            float *out_group_p = out_batch_p + (size_t)g * g_oc * out_plane_size;
            const float *w_group_p = weights + (size_t)g * g_oc * g_ic * filter_size;

            // Seed each output channel with its bias before accumulating. Previously the bias
            // was loaded into an unused local and never applied. With an all-zero bias this is
            // identical to the old zero fill.
            for (int32_t oc = 0; oc < g_oc; oc++)
            {
                float *out_c_p = out_group_p + (size_t)oc * out_plane_size;
                std::fill(out_c_p, out_c_p + out_plane_size, bias[g * g_oc + oc]);
            }

            for (int32_t ic = 0; ic < g_ic; ic++)
            {
                for (int32_t iy = 0; iy < in_shape[2]; iy++)
                {
                    for (int32_t ix = 0; ix < in_shape[3]; ix++)
                    {
                        // Top-left output coordinate touched by filter tap (0, 0) for this input element.
                        const int32_t out_y_origin = (iy * stride_h) - padding_h.before;
                        const int32_t out_x_origin = (ix * stride_w) - padding_w.before;
                        // Clip the tap range so that origin + dilation * k stays inside the output plane.
                        const int32_t filter_y_start = std::max(0, (-out_y_origin + dilation_h - 1) / dilation_h);
                        const int32_t filter_y_end = std::min(filter_h, (out_shape[2] - out_y_origin + dilation_h - 1) / dilation_h);
                        const int32_t filter_x_start = std::max(0, (-out_x_origin + dilation_w - 1) / dilation_w);
                        const int32_t filter_x_end = std::min(filter_w, (out_shape[3] - out_x_origin + dilation_w - 1) / dilation_w);
                        // The loop nest visits the input in memory order, so a running pointer suffices.
                        const float in_v = *input++;

                        for (int32_t oc = 0; oc < g_oc; oc++)
                        {
                            float *out_c_p = out_group_p + (size_t)oc * out_plane_size;
                            const float *w_oc_p = w_group_p + (size_t)oc * g_ic * filter_size;
                            const float *w_ic_p = w_oc_p + (size_t)ic * filter_size;

                            for (int32_t ky = filter_y_start; ky < filter_y_end; ky++)
                            {
                                for (int32_t kx = filter_x_start; kx < filter_x_end; kx++)
                                {
                                    const int32_t out_y = out_y_origin + dilation_h * ky;
                                    const int32_t out_x = out_x_origin + dilation_w * kx;
                                    const float w = w_ic_p[ky * filter_w + kx];
                                    out_c_p[out_y * out_shape[3] + out_x] += in_v * w;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    if (fused_activation != value_range<float>::full())
    {
        for (size_t i = 0; i < output_size; i++)
            output[i] = details::apply_activation(output[i], fused_activation);
    }
}
template <class TQ>
void dequantize(const TQ *input, float *output, size_t count, const quant_param_t &param)
{

View File

@ -75,11 +75,11 @@ namespace runtime
auto layer_pos = writer.position();
writer.position(layer_pos + std::streamoff(sizeof(layer)));
layer.kernel_pool_type_cfg.data.bwsx_base_addr = (uint32_t)writer.align_position(8);
layer.kernel_pool_type_cfg.data.bwsx_base_addr = (uint32_t)writer.align_position(256);
writer.write_array(batch_norm);
layer.kernel_calc_type_cfg.data.active_addr = (uint32_t)writer.align_position(256);
writer.write(*activation);
layer.kernel_load_cfg.data.para_start_addr = (uint32_t)writer.align_position(128);
layer.kernel_load_cfg.data.para_start_addr = (uint32_t)writer.align_position(256);
writer.write_array(weights);
auto end_pos = writer.position();

View File

@ -215,6 +215,66 @@ namespace runtime
}
};
// Serialized parameter block for the conv2d_transpose runtime op.
// deserialize() and serialize() must visit the fields in exactly the same order,
// because that order is the on-disk layout of this op's body.
struct conv2d_transpose_options
{
// Memory ranges of the input and output tensors.
memory_range input;
memory_range output;
// Tensor shapes; index 1 is used below as the channel count.
runtime_shape_t in_shape;
runtime_shape_t out_shape;
int32_t groups;
padding padding_h;
padding padding_w;
int32_t filter_h;
int32_t filter_w;
int32_t stride_h;
int32_t stride_w;
int32_t dilation_h;
int32_t dilation_w;
value_range<float> fused_activation;
// Trailing variable-length payloads, held as spans rather than owned buffers.
// NOTE(review): presumably these point into the reader's underlying buffer
// after deserialize(), so that buffer must outlive this object - confirm.
xtl::span<const float> weights;
xtl::span<const float> bias;
// Reads the fixed-size fields, then the weights and bias arrays whose lengths
// are derived from the fields that were just read.
void deserialize(span_reader &reader)
{
reader.read(input);
reader.read(output);
reader.read(in_shape);
reader.read(out_shape);
reader.read(groups);
reader.read(padding_h);
reader.read(padding_w);
reader.read(filter_h);
reader.read(filter_w);
reader.read(stride_h);
reader.read(stride_w);
reader.read(dilation_h);
reader.read(dilation_w);
reader.read(fused_activation);
// Weight count: out_channels * in_channels / groups * filter_h * filter_w.
// The size_t cast on the first factor makes the whole product size_t arithmetic.
reader.read_span(weights, (size_t)out_shape[1] * in_shape[1] / groups * filter_h * filter_w);
// One bias value per output channel.
reader.read_span(bias, out_shape[1]);
}
// Writes every field in the same order deserialize() reads them.
void serialize(binary_writer &writer) const
{
writer.write(input);
writer.write(output);
writer.write(in_shape);
writer.write(out_shape);
writer.write(groups);
writer.write(padding_h);
writer.write(padding_w);
writer.write(filter_h);
writer.write(filter_w);
writer.write(stride_h);
writer.write(stride_w);
writer.write(dilation_h);
writer.write(dilation_w);
writer.write(fused_activation);
// The array lengths are not stored; deserialize() recomputes them from the shapes.
writer.write_array(weights);
writer.write_array(bias);
}
};
struct dequantize_options : public simple_node_body<dequantize_options>
{
memory_range input;

View File

@ -1,22 +1,24 @@
BEGINE_DEFINE_TARGET(neutral)
DEFINE_NEUTRAL_RUNTIME_OP(binary, Binary, 0x0)
DEFINE_NEUTRAL_RUNTIME_OP(concat, Concat, 0x1)
DEFINE_NEUTRAL_RUNTIME_OP(conv2d, Conv2D, 0x2)
DEFINE_NEUTRAL_RUNTIME_OP(dequantize, Dequantize, 0x3)
DEFINE_NEUTRAL_RUNTIME_OP(matmul, MatMul, 0x4)
DEFINE_NEUTRAL_RUNTIME_OP(pad, Pad, 0x5)
DEFINE_NEUTRAL_RUNTIME_OP(quantize, Quantize, 0x6)
DEFINE_NEUTRAL_RUNTIME_OP(reduce, Reduce, 0x7)
DEFINE_NEUTRAL_RUNTIME_OP(reduce_window2d, ReduceWindow2D, 0x8)
DEFINE_NEUTRAL_RUNTIME_OP(memory_copy, MemoryCopy, 0x9)
DEFINE_NEUTRAL_RUNTIME_OP(resize_image, ResizeImage, 0x0A)
DEFINE_NEUTRAL_RUNTIME_OP(softmax, Softmax, 0x0B)
DEFINE_NEUTRAL_RUNTIME_OP(transpose, Transpose, 0x0C)
DEFINE_NEUTRAL_RUNTIME_OP(strided_slice, StridedSlice, 0x0D)
DEFINE_NEUTRAL_RUNTIME_OP(unary, Unary, 0x0E)
DEFINE_NEUTRAL_RUNTIME_OP(quantized_conv2d, QuantizedConv2D, 0x0F)
DEFINE_NEUTRAL_RUNTIME_OP(quantized_matmul, QuantizedMatMul, 0x10)
DEFINE_NEUTRAL_RUNTIME_OP(quantized_binary, QuantizedBinary, 0x11)
DEFINE_NEUTRAL_RUNTIME_OP(binary, Binary, 0x0)
DEFINE_NEUTRAL_RUNTIME_OP(concat, Concat, 0x1)
DEFINE_NEUTRAL_RUNTIME_OP(conv2d, Conv2D, 0x2)
DEFINE_NEUTRAL_RUNTIME_OP(dequantize, Dequantize, 0x3)
DEFINE_NEUTRAL_RUNTIME_OP(matmul, MatMul, 0x4)
DEFINE_NEUTRAL_RUNTIME_OP(pad, Pad, 0x5)
DEFINE_NEUTRAL_RUNTIME_OP(quantize, Quantize, 0x6)
DEFINE_NEUTRAL_RUNTIME_OP(reduce, Reduce, 0x7)
DEFINE_NEUTRAL_RUNTIME_OP(reduce_window2d, ReduceWindow2D, 0x8)
DEFINE_NEUTRAL_RUNTIME_OP(memory_copy, MemoryCopy, 0x9)
DEFINE_NEUTRAL_RUNTIME_OP(resize_image, ResizeImage, 0x0A)
DEFINE_NEUTRAL_RUNTIME_OP(softmax, Softmax, 0x0B)
DEFINE_NEUTRAL_RUNTIME_OP(transpose, Transpose, 0x0C)
DEFINE_NEUTRAL_RUNTIME_OP(strided_slice, StridedSlice, 0x0D)
DEFINE_NEUTRAL_RUNTIME_OP(unary, Unary, 0x0E)
DEFINE_NEUTRAL_RUNTIME_OP(quantized_conv2d, QuantizedConv2D, 0x0F)
DEFINE_NEUTRAL_RUNTIME_OP(quantized_matmul, QuantizedMatMul, 0x10)
DEFINE_NEUTRAL_RUNTIME_OP(quantized_binary, QuantizedBinary, 0x11)
// DEFINE_NEUTRAL_RUNTIME_OP(table_lookup1d, TableLookup1D, 0x12)
// The display name must identify this op; it had been copy-pasted as QuantizedBinary (already used by 0x11).
DEFINE_NEUTRAL_RUNTIME_OP(conv2d_transpose, Conv2DTranspose, 0x13)
END_DEFINE_TARGET()
// CPU

View File

@ -153,6 +153,15 @@ namespace runtime
return kcr_done;
}
// Interpreter entry point for the conv2d_transpose op: resolves the input and
// output memory ranges to float buffers, forwards everything to the neutral
// reference kernel, and reports completion. `step` is not used here.
kernel_call_result conv2d_transpose(conv2d_transpose_options &options, interpreter_t &interpreter, interpreter_step_t step)
{
    auto src = interpreter.memory_at<float>(options.input);
    auto dest = interpreter.memory_at<float>(options.output);

    kernels::neutral::conv2d_transpose(
        src.data(),
        dest.data(),
        options.weights.data(),
        options.bias.data(),
        options.in_shape,
        options.groups,
        options.out_shape,
        options.filter_h,
        options.filter_w,
        options.stride_h,
        options.stride_w,
        options.dilation_h,
        options.dilation_w,
        options.padding_h,
        options.padding_w,
        options.fused_activation);

    return kcr_done;
}
kernel_call_result dequantize(dequantize_options &options, interpreter_t &interpreter, interpreter_step_t step)
{
auto input = interpreter.memory_at<uint8_t>(options.input);