nncase/include/nncase/ir/op_utils.h

/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "node.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>
#include <span>
#include <stdexcept>
#include <vector>
#include <xtensor/xtensor.hpp>
namespace nncase::ir
{
inline shape_t get_transposed_shape(const shape_t &input_shape, const axis_t &perm)
{
    shape_t new_shape(input_shape.size());
    for (size_t i = 0; i < new_shape.size(); i++)
        new_shape[i] = input_shape[perm[i]];
    return new_shape;
}
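
// get_transposed_shape, illustrative usage (values assumed for demonstration):
// transposing an NCHW shape with perm { 0, 2, 3, 1 } yields the NHWC shape.
//   shape_t in { 1, 3, 224, 224 };
//   axis_t perm { 0, 2, 3, 1 };
//   get_transposed_shape(in, perm); // { 1, 224, 224, 3 }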

inline size_t get_windowed_output_size(int32_t size, int32_t filter, int32_t stride, int32_t dilation, bool same, bool ceil_mode = false)
{
    auto effective_filter_size = (filter - 1) * dilation + 1;
    if (same)
        return (size_t(size) + stride - 1) / stride;
    if (!ceil_mode)
        return (size_t(size) - effective_filter_size + stride) / stride;
    return static_cast<size_t>(std::ceil(static_cast<float>(size_t(size) - effective_filter_size + stride) / stride));
}
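
// get_windowed_output_size, worked examples (assumed values):
// filter = 3, stride = 2, dilation = 1, so effective_filter_size = 3.
//   size = 5, same = true:                (5 + 2 - 1) / 2 = 3
//   size = 5, same = false:               (5 - 3 + 2) / 2 = 2
//   size = 6, same = false, ceil_mode on: ceil((6 - 3 + 2) / 2.0) = 3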

inline padding get_windowed_padding(int32_t input_size, int32_t output_size, int32_t filter, int32_t stride, int32_t dilation)
{
    auto effective_filter_size = (filter - 1) * dilation + 1;
    int padding = std::max(0, (output_size - 1) * stride + effective_filter_size - input_size);
    return { padding / 2, padding - padding / 2 };
}

inline padding get_windowed_padding(int32_t input_size, int32_t filter, int32_t stride, int32_t dilation, bool same)
{
    auto output_size = get_windowed_output_size(input_size, filter, stride, dilation, same);
    return get_windowed_padding(input_size, (int32_t)output_size, filter, stride, dilation);
}
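
// get_windowed_padding, worked example (assumed values): input_size = 5,
// filter = 3, stride = 2, dilation = 1, same = true. The SAME output size is
// 3, so total padding = max(0, (3 - 1) * 2 + 3 - 5) = 2, split as { 1, 1 }.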

inline constexpr size_t get_bytes(datatype_t type)
{
    switch (type)
    {
#define DEFINE_DATATYPE(id, t, name, value) \
    case (dt_##id):                         \
        return sizeof(t);
#include <nncase/runtime/datatypes.def>
#undef DEFINE_DATATYPE
    default:
        throw std::invalid_argument("Invalid datatype");
    }
}

inline size_t get_bytes(datatype_t type, const shape_t &shape)
{
    return xt::compute_size(shape) * get_bytes(type);
}
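
// get_bytes, illustrative usage (assuming dt_float32 maps to a 4-byte float
// in datatypes.def):
//   get_bytes(dt_float32);                      // 4
//   get_bytes(dt_float32, shape_t { 1, 3, 4 }); // 12 * 4 = 48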

inline nncase::ir::shape_t to_strides(const nncase::ir::shape_t &shape)
{
    nncase::ir::shape_t strides(shape.size());
    xt::compute_strides(shape, xt::layout_type::row_major, strides);
    return strides;
}
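
// to_strides, illustrative usage (assumed shape): row-major strides, measured
// in elements.
//   to_strides(shape_t { 2, 3, 4 }); // { 12, 4, 1 }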

inline int32_t normalize_axis(const shape_t &input_shape, int32_t axis)
{
    return axis < 0 ? (int32_t)input_shape.size() + axis : axis;
}

inline axis_t normalize_axis(const shape_t &input_shape, const axis_t &axis)
{
    axis_t new_axis = axis;
    for (auto &a : new_axis)
    {
        if (a < 0)
            a = (int32_t)input_shape.size() + a;
    }
    return new_axis;
}

inline axis_t normalize_reduce_axis(const shape_t &input_shape, const axis_t &axis)
{
    axis_t new_axis = normalize_axis(input_shape, axis);
    std::sort(new_axis.begin(), new_axis.end());
    return new_axis;
}
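
// normalize_reduce_axis, illustrative usage (assumed rank-4 shape): negative
// axes count from the end, and the result is sorted.
//   normalize_reduce_axis(shape_t { 2, 3, 4, 5 }, axis_t { -1, 1 }); // { 1, 3 }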

inline shape_t get_reduced_shape(const shape_t &input_shape, const axis_t &axis, bool keep_dims)
{
    if (!std::is_sorted(axis.begin(), axis.end()))
        throw std::invalid_argument("axis must be sorted");

    shape_t shape;
    for (size_t i = 0; i < input_shape.size(); i++)
    {
        if (std::find(axis.begin(), axis.end(), i) == axis.end())
            shape.push_back(input_shape[i]);
        else if (keep_dims)
            shape.push_back(1);
    }

    if (shape.empty())
        shape.push_back(1);
    return shape;
}
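
// get_reduced_shape, illustrative usage (assumed values):
//   get_reduced_shape(shape_t { 2, 3, 4, 5 }, axis_t { 1, 3 }, false); // { 2, 4 }
//   get_reduced_shape(shape_t { 2, 3, 4, 5 }, axis_t { 1, 3 }, true);  // { 2, 1, 4, 1 }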

inline shape_t normalize_reshape(const shape_t &in_shape, const axis_t &new_shape)
{
    shape_t result(new_shape.size());

    size_t shape_size = 1;
    std::optional<size_t> non_det_id;
    for (size_t i = 0; i < new_shape.size(); i++)
    {
        auto v = new_shape[i];
        if (v == -1)
        {
            if (non_det_id)
                throw std::runtime_error("Reshape can have at most one non-determined dimension");
            non_det_id = i;
        }
        else
        {
            shape_size *= v;
            result[i] = (size_t)v;
        }
    }

    if (non_det_id)
        result[*non_det_id] = xt::compute_size(in_shape) / shape_size;
    return result;
}
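
// normalize_reshape, illustrative usage (assumed values): the -1 dimension is
// inferred from the remaining element count, 24 / 4 = 6.
//   normalize_reshape(shape_t { 2, 3, 4 }, axis_t { -1, 4 }); // { 6, 4 }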

inline shape_t get_binary_output_shape(const shape_t &input_a_shape, const shape_t &input_b_shape)
{
    shape_t out_shape;

    const auto dest_dims = (int32_t)std::max(input_a_shape.size(), input_b_shape.size());
    const auto in_a_ext = dest_dims - (int32_t)input_a_shape.size();
    const auto in_b_ext = dest_dims - (int32_t)input_b_shape.size();

    for (int32_t i = 0; i < dest_dims; i++)
    {
        const auto in_a_dim = i - in_a_ext;
        const auto in_b_dim = i - in_b_ext;

        const auto in_a = in_a_dim < 0 ? 1 : input_a_shape[in_a_dim];
        const auto in_b = in_b_dim < 0 ? 1 : input_b_shape[in_b_dim];
        if (in_a == in_b)
            out_shape.push_back(in_a);
        else if (in_a == 1)
            out_shape.push_back(in_b);
        else if (in_b == 1)
            out_shape.push_back(in_a);
        else
            throw std::invalid_argument("inputs are not compatible for broadcast");
    }

    return out_shape;
}
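
// get_binary_output_shape, illustrative usage (assumed values): shapes are
// right-aligned and size-1 dimensions broadcast, numpy-style.
//   get_binary_output_shape(shape_t { 3, 1, 5 }, shape_t { 4, 1 }); // { 3, 4, 5 }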

inline std::vector<shape_t> get_input_shapes(std::span<input_connector *const> inputs)
{
    std::vector<shape_t> shapes;
    shapes.reserve(inputs.size());
    for (auto in : inputs)
        shapes.emplace_back(in->shape());
    return shapes;
}

inline shape_t get_concated_shape(std::span<shape_t> input_shapes, size_t axis)
{
    if (input_shapes.empty())
        throw std::invalid_argument("there must be at least one input");

    auto concated_shape = input_shapes[0];
    for (size_t i = 1; i < input_shapes.size(); i++)
    {
        auto &cur_shape = input_shapes[i];
        if (concated_shape.size() != cur_shape.size())
            throw std::invalid_argument("inputs must have the same rank");

        for (size_t j = 0; j < concated_shape.size(); j++)
        {
            if (j == axis)
                concated_shape[j] += cur_shape[j];
            else if (cur_shape[j] != concated_shape[j])
                throw std::invalid_argument("inputs are not compatible for concat");
        }
    }

    return concated_shape;
}
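
// get_concated_shape, illustrative usage (assumed values): only the concat
// axis may differ between inputs.
//   shape_t shapes[] = { { 2, 3, 4 }, { 2, 5, 4 } };
//   get_concated_shape(shapes, 1); // { 2, 8, 4 }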

inline void get_concat_params(const shape_t &out_shape, size_t elem_size, size_t axis, uint64_t &inner_size, uint64_t &outer_size)
{
    inner_size = elem_size;
    outer_size = 1;
    for (size_t i = 0; i < out_shape.size(); i++)
    {
        if (i > axis)
            inner_size *= out_shape[i];
        else if (i < axis)
            outer_size *= out_shape[i];
    }
}
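
// get_concat_params, worked example (assumed values): out_shape = { 2, 8, 4 },
// elem_size = 4, axis = 1 gives inner_size = 4 * 4 = 16 (bytes per slice after
// the axis) and outer_size = 2 (product of dims before the axis).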

inline shape_t get_padded_shape(const shape_t &in_shape, const xt::svector<padding> &paddings)
{
    auto new_shape = in_shape;
    for (size_t i = 0; i < in_shape.size(); i++)
        new_shape[i] = size_t(int32_t(new_shape[i]) + paddings[i].sum() + (new_shape[i] - 1) * paddings[i].interior);
    return new_shape;
}
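
// get_padded_shape, worked example (assumed values, interior = 0): padding
// { 1, 1 } on both spatial dims of { 1, 3, 4, 4 } yields { 1, 3, 6, 6 }.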

inline shape_t get_resize_image_shape(const shape_t &in_shape, const std::array<int32_t, 2> &new_size)
{
    auto new_shape = in_shape;
    auto r = new_shape.rbegin();
    *r++ = new_size[1];
    *r++ = new_size[0];
    return new_shape;
}
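
// get_resize_image_shape, illustrative usage (assumed values): new_size is
// { height, width } and replaces the two innermost dimensions.
//   get_resize_image_shape(shape_t { 1, 3, 100, 200 }, { 50, 75 }); // { 1, 3, 50, 75 }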

inline axis_t normalize_strided_slice_begin(const shape_t &in_shape, const axis_t &begin, const axis_t &strides, int32_t begin_mask)
{
    axis_t new_begin(strides.size());
    for (size_t i = 0; i < new_begin.size(); i++)
    {
        auto stride = strides[i];
        assert(stride);
        new_begin[i] = (begin_mask & (1 << i)) != 0
            ? stride > 0 ? 0 : (int32_t)in_shape[i] - 1
            : (begin[i] >= 0 ? begin[i] : (int32_t)in_shape[i] + begin[i]);
    }

    return new_begin;
}

inline axis_t normalize_strided_slice_end(const shape_t &in_shape, [[maybe_unused]] const axis_t &begin, const axis_t &end, const axis_t &strides, int32_t end_mask)
{
    axis_t new_end(strides.size());
    for (size_t i = 0; i < new_end.size(); i++)
    {
        auto stride = strides[i];
        new_end[i] = (end_mask & (1 << i)) != 0
            ? stride > 0 ? (int32_t)in_shape[i] : -1
            : (end[i] >= 0 ? end[i] : (int32_t)in_shape[i] + end[i]);
    }

    return new_end;
}

inline shape_t get_strided_slice_output_shape(const axis_t &begin, const axis_t &end, const axis_t &strides, int32_t ellipsis_mask, int32_t new_axis_mask)
{
    if (ellipsis_mask)
        throw std::invalid_argument("Non-zero ellipsis_mask is not supported");
    if (new_axis_mask)
        throw std::invalid_argument("Non-zero new_axis_mask is not supported");

    shape_t new_shape;
    for (size_t i = 0; i < strides.size(); i++)
    {
        auto stride = strides[i];
        auto begin_val = begin[i];
        auto end_val = end[i];
        auto dim = (int)std::ceil((end_val - begin_val) / (float)stride);
        new_shape.push_back(dim);
    }

    return new_shape.size() ? new_shape : shape_t { 1 };
}
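
// get_strided_slice_output_shape, worked example (assumed values, both masks
// zero): begin = { 0 }, end = { 5 }, strides = { 2 } gives
// ceil((5 - 0) / 2.0) = 3 elements (indices 0, 2, 4).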

inline bool is_copy_slice(const axis_t &strides)
{
    return std::all_of(strides.begin(), strides.end(), [](int32_t stride) { return stride == 1; });
}

inline bool is_simple_slice(const axis_t &begin, const axis_t &end, const axis_t &strides, const shape_t &input_shape)
{
    if (!is_copy_slice(strides))
        return false;

    // A stride-1 slice is "simple" (selects one contiguous memory region)
    // when at most one dimension is actually narrowed and every dimension
    // before it has size 1.
    bool is_simple_slice = true;
    bool allow_not_equal = true;
    for (size_t i = 0; i < begin.size(); i++)
    {
        if (begin[i] != 0
            || end[i] != input_shape[i])
        {
            // This dimension is narrowed; only one non-trivial dimension may be.
            if (allow_not_equal)
            {
                allow_not_equal = false;
            }
            else
            {
                is_simple_slice = false;
                break;
            }
        }
        else if (input_shape[i] != 1)
        {
            // A full, non-size-1 dimension: narrowing any later dimension
            // would make the slice non-contiguous.
            allow_not_equal = false;
        }
    }

    return is_simple_slice;
}
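
// is_simple_slice, illustrative usage (assumed values):
//   is_simple_slice({ 0, 0 }, { 2, 4 }, { 1, 1 }, { 5, 4 }); // true: only dim 0 narrows
//   is_simple_slice({ 0, 0 }, { 2, 2 }, { 1, 1 }, { 5, 4 }); // false: dims 0 and 1 narrow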

// Returns true when the two shapes differ only by leading size-1 dimensions,
// i.e. the reshape is a squeeze/expand_dims at axis 0 and a bitcast suffices.
inline bool is_axis0_squeeze_or_expand_dim_bitcast(const shape_t &in_shape, const shape_t &out_shape)
{
    auto in_begin = std::find_if_not(in_shape.begin(), in_shape.end(), [](size_t dim) { return dim == 1; });
    auto out_begin = std::find_if_not(out_shape.begin(), out_shape.end(), [](size_t dim) { return dim == 1; });
    return std::distance(in_begin, in_shape.end()) == std::distance(out_begin, out_shape.end())
        && std::equal(in_begin, in_shape.end(), out_begin);
}

// Reinterprets a span of T as a span of U; the byte size must divide evenly.
template <class U, class T>
std::span<U> as_span(const std::span<T> &src) noexcept
{
    assert(src.size_bytes() % sizeof(U) == 0);
    return std::span<U>(reinterpret_cast<U *>(src.data()), src.size_bytes() / sizeof(U));
}
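
// as_span, illustrative usage (assumed buffer): reinterpreting raw bytes as
// floats; the source span must be non-const for the cast to compile.
//   std::array<std::byte, 8> buf {};
//   as_span<float>(std::span<std::byte> { buf }); // span of 2 floats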
}

namespace xt
{
inline nncase::ir::shape_t operator+(const nncase::ir::shape_t &lhs, const nncase::ir::shape_t &rhs)
{
    using namespace nncase::ir;
    if (lhs.size() != rhs.size())
        throw std::invalid_argument("Shape's rank mismatch");

    shape_t ret = lhs;
    for (size_t i = 0; i < lhs.size(); i++)
        ret[i] += rhs[i];
    return ret;
}

inline nncase::ir::shape_t &operator+=(nncase::ir::shape_t &lhs, const nncase::ir::shape_t &rhs)
{
    using namespace nncase::ir;
    if (lhs.size() != rhs.size())
        throw std::invalid_argument("Shape's rank mismatch");

    for (size_t i = 0; i < lhs.size(); i++)
        lhs[i] += rhs[i];
    return lhs;
}
}