Feature/add preprocess (#940)

* fix type

* add preprocess

* fix resizeimage shape check

* Apply code-format changes

* fix warning

* fix warning

* update json && fix test runner bug

* remove unused file

* revert packages.lock.json

---------

Co-authored-by: yanghaoqi <yanghaoqi_intern@canaan-creative.com>
Co-authored-by: curioyang <curioyang@users.noreply.github.com>
Co-authored-by: 郑启航 <597323109@qq.com>
Curio Yang 2023-05-30 17:20:28 +08:00 committed by GitHub
parent 55e9643662
commit f0fac5f706
28 changed files with 733 additions and 158 deletions
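For orientation, a minimal sketch of driving the new options from the Python API (assuming a wheel built from this commit; attribute names follow the nncase/__init__.py and tests/preprocess_utils.py hunks below, and the mean/std values here are only illustrative ImageNet numbers):

    import nncase

    compile_options = nncase.CompileOptions()
    compile_options.preprocess = True              # enable the new AddPreProcess pass
    compile_options.input_type = "uint8"           # dtype of the raw frame fed at runtime
    compile_options.input_shape = [1, 224, 224, 3]
    compile_options.input_layout = "NHWC"          # layout of the raw frame
    compile_options.output_layout = "NHWC"
    compile_options.input_range = [0.0, 1.0]       # float range the uint8 input maps onto
    compile_options.swapRB = False                 # e.g. BGR camera feeding an RGB model
    compile_options.letterbox_value = 114.0        # pad value for aspect-preserving resize
    compile_options.mean = [0.485, 0.456, 0.406]   # per-channel normalization
    compile_options.std = [0.229, 0.224, 0.225]
    compiler = nncase.Compiler(compile_options)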

@@ -33,7 +33,7 @@ using namespace nncase;
using namespace nncase::runtime;
template <typename T> datatype_t from_dtype() {
-    if (std::is_same_v<T, float>)
+    if (std::is_same_v<T, uint8_t>)
return dt_uint8;
else if (std::is_same_v<T, uint16_t>)
return dt_uint16;

@@ -223,6 +223,21 @@ class Compiler:
dump_flags = _nncase.DumpFlags(dump_flags | _nncase.DumpFlags.CodeGen)
self._compile_options.dump_flags = dump_flags
self._compile_options.dump_dir = compile_options.dump_dir
self._compile_options.preprocess = compile_options.preprocess
self._compile_options.input_layout = compile_options.input_layout
self._compile_options.output_layout = compile_options.output_layout
if compile_options.input_type == "uint8":
self._compile_options.input_type = _nncase.InputType.Uint8
elif compile_options.input_type == "int8":
self._compile_options.input_type = _nncase.InputType.Int8
if compile_options.input_type == "float32":
self._compile_options.input_type = _nncase.InputType.Float32
self._compile_options.input_shape = str(compile_options.input_shape)[1:-1]
self._compile_options.input_range = str(compile_options.input_range)[1:-1]
self._compile_options.swapRB = compile_options.swapRB
self._compile_options.letterbox_value = compile_options.letterbox_value
self._compile_options.mean = str(compile_options.mean)[1:-1]
self._compile_options.std = str(compile_options.std)[1:-1]
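Note the str(list)[1:-1] idiom above: list-valued options cross the binding boundary as comma-separated strings, which the C# side parses back (see StringToArrayInt32/StringToArrayFloat below). For example:

    str([1, 224, 224, 3])[1:-1]   # -> '1, 224, 224, 3'
    str([0.0, 1.0])[1:-1]         # -> '0.0, 1.0'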
def _import_module(self, model_content: bytes | io.RawIOBase) -> None:
stream = io.BytesIO(model_content) if isinstance(model_content, bytes) else model_content
@@ -278,6 +293,16 @@ class ClCompileOptions():
OutputFile: str
ModelQuantMode: int
QuantizeOptions: ClQuantizeOptions
SwapRB: bool
InputRange: List[float]
InputShape: List[int]
InputType: str
Mean: List[float]
Std: List[float]
PreProcess: bool
InputLayout: str
OutputLayout: str
LetterBoxValue: float
class CompileOptions:

@@ -79,6 +79,11 @@ PYBIND11_MODULE(_nncase, m) {
.value("Int8", nncase_qt_int8)
.value("Int16", nncase_qt_int16);
py::enum_<nncase_input_type_t>(m, "InputType")
.value("Uint8", nncase_it_uint8)
.value("Int8", nncase_it_int8)
.value("Float32", nncase_it_float32);
py::enum_<nncase_finetune_weights_method_t>(m, "FineTuneWeightsMethod")
.value("NoFineTuneWeights", nncase_no_finetune_weights)
.value("UseSquant", nncase_finetune_weights_squant)
@@ -99,7 +104,39 @@ PYBIND11_MODULE(_nncase, m) {
.def_property("quantize_options",
py::overload_cast<>(&compile_options::quantize_options),
py::overload_cast<const quantize_options &>(
&compile_options::quantize_options));
&compile_options::quantize_options))
.def_property("preprocess",
py::overload_cast<>(&compile_options::preprocess),
py::overload_cast<bool>(&compile_options::preprocess))
.def_property(
"input_layout", py::overload_cast<>(&compile_options::input_layout),
py::overload_cast<std::string_view>(&compile_options::input_layout))
.def_property("output_layout",
py::overload_cast<>(&compile_options::output_layout),
py::overload_cast<std::string_view>(
&compile_options::output_layout))
.def_property("input_type",
py::overload_cast<>(&compile_options::input_type),
py::overload_cast<nncase_input_type_t>(
&compile_options::input_type))
.def_property(
"input_shape", py::overload_cast<>(&compile_options::input_shape),
py::overload_cast<std::string_view>(&compile_options::input_shape))
.def_property(
"input_range", py::overload_cast<>(&compile_options::input_range),
py::overload_cast<std::string_view>(&compile_options::input_range))
.def_property("swapRB", py::overload_cast<>(&compile_options::swapRB),
py::overload_cast<bool>(&compile_options::swapRB))
.def_property(
"letterbox_value",
py::overload_cast<>(&compile_options::letterbox_value),
py::overload_cast<float>(&compile_options::letterbox_value))
.def_property(
"mean", py::overload_cast<>(&compile_options::mean),
py::overload_cast<std::string_view>(&compile_options::mean))
.def_property(
"std", py::overload_cast<>(&compile_options::std),
py::overload_cast<std::string_view>(&compile_options::std));
py::class_<target>(m, "Target")
.def(py::init<std::string_view>())

@@ -69,6 +69,12 @@ typedef enum {
nncase_dump_flags_codegen = 1 << 10
} nncase_dump_flags_t;
typedef enum {
nncase_it_uint8 = 0,
nncase_it_int8 = 1,
nncase_it_float32 = 2
} nncase_input_type_t;
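The numeric values of this enum must stay in lockstep with the pybind InputType registration above and the C# InputType enum below, because the value crosses the ABI as a plain byte. A Python-side mirror (illustrative only):

    from enum import IntEnum

    class InputType(IntEnum):  # mirrors nncase_input_type_t
        Uint8 = 0
        Int8 = 1
        Float32 = 2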
typedef struct {
void (*add_ref)(nncase_stream_handle_t handle);
void (*release)(nncase_stream_handle_t handle);
@@ -114,6 +120,30 @@ typedef struct {
void (*compile_options_set_quantize_options)(
clr_object_handle_t compile_options,
clr_object_handle_t quantize_options);
void (*compile_options_set_preprocess)(clr_object_handle_t compile_options,
bool preprocess);
void (*compile_options_set_input_layout)(
clr_object_handle_t compile_options, const char *input_layout,
size_t input_layout_length);
void (*compile_options_set_output_layout)(
clr_object_handle_t compile_options, const char *output_layout,
size_t output_layout_length);
void (*compile_options_set_input_type)(clr_object_handle_t compile_options,
nncase_input_type_t input_type);
void (*compile_options_set_input_shape)(clr_object_handle_t compile_options,
const char *input_shape,
size_t input_shape_length);
void (*compile_options_set_input_range)(clr_object_handle_t compile_options,
const char *input_range,
size_t input_range_length);
void (*compile_options_set_swapRB)(clr_object_handle_t compile_options,
bool swapRB);
void (*compile_options_set_letterbox_value)(
clr_object_handle_t compile_options, float letterbox_value);
void (*compile_options_set_mean)(clr_object_handle_t compile_options,
const char *mean, size_t mean_length);
void (*compile_options_set_std)(clr_object_handle_t compile_options,
const char *std, size_t std_length);
clr_object_handle_t (*compile_session_create)(
clr_object_handle_t target, clr_object_handle_t compile_options);
clr_object_handle_t (*compile_session_get_compiler)(
@@ -391,6 +421,63 @@ class compile_options : public clr_object_base {
nncase_clr_api()->compile_options_set_quantize_options(obj_.get(),
value.get());
}
bool preprocess() { return true; }
void preprocess(bool value) {
nncase_clr_api()->compile_options_set_preprocess(obj_.get(), value);
}
std::string input_layout() { return ""; }
void input_layout(std::string_view value) {
nncase_clr_api()->compile_options_set_input_layout(
obj_.get(), value.data(), value.length());
}
std::string output_layout() { return ""; }
void output_layout(std::string_view value) {
nncase_clr_api()->compile_options_set_output_layout(
obj_.get(), value.data(), value.length());
}
nncase_input_type_t input_type() { return nncase_it_float32; }
void input_type(nncase_input_type_t value) {
nncase_clr_api()->compile_options_set_input_type(obj_.get(), value);
}
std::string input_shape() { return ""; }
void input_shape(std::string_view value) {
nncase_clr_api()->compile_options_set_input_shape(
obj_.get(), value.data(), value.length());
}
std::string input_range() { return ""; }
void input_range(std::string_view value) {
nncase_clr_api()->compile_options_set_input_range(
obj_.get(), value.data(), value.length());
}
bool swapRB() { return false; }
void swapRB(bool value) {
nncase_clr_api()->compile_options_set_swapRB(obj_.get(), value);
}
float letterbox_value() { return 0.f; }
void letterbox_value(float value) {
nncase_clr_api()->compile_options_set_letterbox_value(obj_.get(),
value);
}
std::string mean() { return ""; }
void mean(std::string_view value) {
nncase_clr_api()->compile_options_set_mean(obj_.get(), value.data(),
value.length());
}
std::string std() { return ""; }
void std(std::string_view value) {
nncase_clr_api()->compile_options_set_std(obj_.get(), value.data(),
value.length());
}
};
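The getters here are stubs returning fixed defaults; the wrapper is effectively write-only, forwarding every setter through the C API table. From Python the same contract looks like this (a sketch against the _nncase module bound above):

    opts = _nncase.CompileOptions()
    opts.preprocess = True
    opts.input_type = _nncase.InputType.Uint8
    opts.input_shape = "1, 224, 224, 3"   # list options travel as comma-separated strings
    opts.input_range = "0, 1"
    opts.mean = "0, 0, 0"
    opts.std = "1, 1, 1"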
class target : public clr_object_base {

@@ -12,6 +12,7 @@ using Nncase.Hosting;
using Nncase.IR;
using Nncase.Passes;
using Nncase.Passes.Rules.Lower;
using Nncase.Passes.Rules.Neutral;
using Nncase.Passes.Transforms;
using Nncase.Quantization;
using Nncase.Utilities;
@@ -47,8 +48,11 @@ internal class Compiler : ICompiler
_dumpper.DumpModule(module, "IRImport");
}
var preprocess_option = _compileSession.CompileOptions;
await RunPassAsync(pmg => BroadcastOutputNamesAfterImportPass(pmg), "BroadcastOutputNamesAfterImport");
await RunPassAsync(pmg => pmg.Add<ShapeInferPass>(), "ShapeInferAfterImport");
await RunPassAsync(pmg => pmg.Add<AddPreProcess>(), "AddPreProcessAfterImport");
var inferSucc = CompilerServices.InferenceType(module.Entry!);
if (!inferSucc)

@@ -48,6 +48,16 @@ public unsafe struct CApiMT
public delegate* unmanaged<IntPtr, DumpFlags> CompileOptionsGetDumpFlagsPtr;
public delegate* unmanaged<IntPtr, DumpFlags, void> CompileOptionsSetDumpFlagsPtr;
public delegate* unmanaged<IntPtr, IntPtr, void> CompileOptionsSetQuantizeOptionsPtr;
public delegate* unmanaged<IntPtr, byte, void> CompileOptionsSetPreProcessPtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetInputLayoutPtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetOutputLayoutPtr;
public delegate* unmanaged<IntPtr, byte, void> CompileOptionsSetInputTypePtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetInputShapePtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetInputRangePtr;
public delegate* unmanaged<IntPtr, byte, void> CompileOptionsSetSwapRBPtr;
public delegate* unmanaged<IntPtr, float, void> CompileOptionsSetLetterBoxValuePtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetMeanPtr;
public delegate* unmanaged<IntPtr, byte*, nuint, void> CompileOptionsSetStdPtr;
public delegate* unmanaged<IntPtr, IntPtr, IntPtr> CompileSessionCreatePtr;
public delegate* unmanaged<IntPtr, IntPtr> CompileSessionGetCompilerPtr;
public delegate* unmanaged<void> CompilerInitializePtr;
@@ -98,6 +108,16 @@ public static unsafe class CApi
mt->CompileOptionsSetDumpFlagsPtr = &CompileOptionsSetDumpFlags;
mt->CompileOptionsGetDumpFlagsPtr = &CompileOptionsGetDumpFlags;
mt->CompileOptionsSetQuantizeOptionsPtr = &CompileOptionsSetQuantizeOptions;
mt->CompileOptionsSetPreProcessPtr = &CompileOptionsSetPreProcess;
mt->CompileOptionsSetInputLayoutPtr = &CompileOptionsSetInputLayout;
mt->CompileOptionsSetOutputLayoutPtr = &CompileOptionsSetOutputLayout;
mt->CompileOptionsSetInputTypePtr = &CompileOptionsSetInputType;
mt->CompileOptionsSetInputShapePtr = &CompileOptionsSetInputShape;
mt->CompileOptionsSetInputRangePtr = &CompileOptionsSetInputRange;
mt->CompileOptionsSetSwapRBPtr = &CompileOptionsSetSwapRB;
mt->CompileOptionsSetLetterBoxValuePtr = &CompileOptionsSetLetterBoxValue;
mt->CompileOptionsSetMeanPtr = &CompileOptionsSetMean;
mt->CompileOptionsSetStdPtr = &CompileOptionsSetStd;
mt->CompileSessionCreatePtr = &CompileSessionCreate;
mt->CompileSessionGetCompilerPtr = &CompileSessionGetCompiler;
mt->CompilerInitializePtr = &CompilerInitialize;
@@ -233,6 +253,100 @@ public static unsafe class CApi
Get<CompileOptions>(compileOptionsHandle).QuantizeOptions = Get<QuantizeOptions>(quantizeOptionsHandle);
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetPreProcess(IntPtr compileOptionsHandle, byte preProcess)
{
switch (preProcess)
{
case 0:
Get<CompileOptions>(compileOptionsHandle).PreProcess = false;
break;
case 1:
Get<CompileOptions>(compileOptionsHandle).PreProcess = true;
break;
default:
throw new ArgumentException("Invalid PreProcess Flag");
}
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetInputLayout(IntPtr compileOptionsHandle, byte* inputLayout, nuint inputLayoutLength)
{
Get<CompileOptions>(compileOptionsHandle).InputLayout = ToString(inputLayout, inputLayoutLength);
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetOutputLayout(IntPtr compileOptionsHandle, byte* outputLayout, nuint outputLayoutLength)
{
Get<CompileOptions>(compileOptionsHandle).OutputLayout = ToString(outputLayout, outputLayoutLength);
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetInputType(IntPtr compileOptionsHandle, byte inputType)
{
// Get<CompileOptions>(compileOptionsHandle).InputType = inputType;
switch (inputType)
{
case 0:
Get<CompileOptions>(compileOptionsHandle).InputType = InputType.Uint8;
break;
case 1:
Get<CompileOptions>(compileOptionsHandle).InputType = InputType.Int8;
break;
case 2:
Get<CompileOptions>(compileOptionsHandle).InputType = InputType.Float32;
break;
default:
throw new ArgumentException("Invalid InputType Flag");
}
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetInputShape(IntPtr compileOptionsHandle, byte* inputShapeValue, nuint inputShapeValueLength)
{
Get<CompileOptions>(compileOptionsHandle).InputShape = StringToArrayInt32(ToString(inputShapeValue, inputShapeValueLength));
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetInputRange(IntPtr compileOptionsHandle, byte* inputRangeValue, nuint inputRangeValueLength)
{
Get<CompileOptions>(compileOptionsHandle).InputRange = StringToArrayFloat(ToString(inputRangeValue, inputRangeValueLength));
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetSwapRB(IntPtr compileOptionsHandle, byte swapRBValue)
{
switch (swapRBValue)
{
case 0:
Get<CompileOptions>(compileOptionsHandle).SwapRB = false;
break;
case 1:
Get<CompileOptions>(compileOptionsHandle).SwapRB = true;
break;
default:
throw new ArgumentException("Invalid SwapRB Flag");
}
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetLetterBoxValue(IntPtr compileOptionsHandle, float letterBoxValue)
{
Get<CompileOptions>(compileOptionsHandle).LetterBoxValue = letterBoxValue;
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetMean(IntPtr compileOptionsHandle, byte* meanValue, nuint meanValueLength)
{
Get<CompileOptions>(compileOptionsHandle).Mean = StringToArrayFloat(ToString(meanValue, meanValueLength));
}
[UnmanagedCallersOnly]
private static void CompileOptionsSetStd(IntPtr compileOptionsHandle, byte* stdValue, nuint stdValueLength)
{
Get<CompileOptions>(compileOptionsHandle).Std = StringToArrayFloat(ToString(stdValue, stdValueLength));
}
[UnmanagedCallersOnly]
private static IntPtr CompileSessionCreate(IntPtr targetHandle, IntPtr compileOptionsHandle)
{
@@ -519,6 +633,18 @@ public static unsafe class CApi
private static string ToString(byte* bytes, nuint length) =>
Encoding.UTF8.GetString(bytes, (int)length);
private static int[] StringToArrayInt32(string value)
{
var data = value.Replace(" ", string.Empty, StringComparison.OrdinalIgnoreCase).Split(",");
return Array.ConvertAll(data, int.Parse);
}
private static float[] StringToArrayFloat(string value)
{
var data = value.Replace(" ", string.Empty, StringComparison.OrdinalIgnoreCase).Split(',');
return Array.ConvertAll(data, float.Parse);
}
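These helpers undo the str(list)[1:-1] encoding performed in nncase/__init__.py. The round-trip, expressed in Python for reference:

    def string_to_floats(value: str) -> list:
        # mirrors StringToArrayFloat: drop spaces, split on ',', parse each piece
        return [float(x) for x in value.replace(" ", "").split(",")]

    assert string_to_floats("0.485, 0.456, 0.406") == [0.485, 0.456, 0.406]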
private class CCalibrationDatasetProvider : ICalibrationDatasetProvider
{
public CCalibrationDatasetProvider(IAsyncEnumerable<IReadOnlyDictionary<Var, IValue>> samples, int samplesCount)

@@ -7,6 +7,24 @@ using Nncase.Quantization;
namespace Nncase;
public enum InputType : int
{
/// <summary>
/// uint8.
/// </summary>
Uint8,
/// <summary>
/// int8.
/// </summary>
Int8,
/// <summary>
/// float32.
/// </summary>
Float32,
}
/// <summary>
/// Compile options.
/// </summary>
@@ -36,4 +54,59 @@ public sealed record CompileOptions
/// Gets or sets quant options.
/// </summary>
public QuantizeOptions QuantizeOptions { get; set; } = QuantizeOptions.CreateNoQuant();
/// <summary>
/// Gets or sets a value indicating whether preprocessing is enabled.
/// </summary>
public bool PreProcess { get; set; }
/// <summary>
/// Gets or sets the input layout.
/// </summary>
public string InputLayout { get; set; } = string.Empty;
/// <summary>
/// Gets or sets the output layout.
/// </summary>
public string OutputLayout { get; set; } = string.Empty;
/// <summary>
/// Gets or sets the input type.
/// </summary>
public InputType InputType { get; set; } = InputType.Float32;
/// <summary>
/// Gets or sets the input shape.
/// </summary>
public int[] InputShape { get; set; } = Array.Empty<int>();
/// <summary>
/// Gets or sets the input range.
/// </summary>
public float[] InputRange { get; set; } = Array.Empty<float>();
/// <summary>
/// Gets or sets a value indicating whether the red and blue channels are swapped (swapRB).
/// </summary>
public bool SwapRB { get; set; }
/// <summary>
/// Gets or sets the letterbox_value.
/// </summary>
public float LetterBoxValue { get; set; }
/// <summary>
/// Gets or sets the mean.
/// </summary>
public float[] Mean { get; set; } = Array.Empty<float>();
/// <summary>
/// Gets or sets the std.
/// </summary>
public float[] Std { get; set; } = Array.Empty<float>();
/// <summary>
/// Gets or sets the model layout.
/// </summary>
public string ModelLayout { get; set; } = string.Empty;
}

@@ -26,6 +26,7 @@ public static class Importers
/// <returns>Imported IR module.</returns>
public static IRModule ImportTFLite(Stream tflite, CompileSession compileSession)
{
compileSession.CompileOptions.ModelLayout = "NHWC";
var model = new byte[tflite.Length];
tflite.Read(model);
var importer = new TFLiteImporter(model, compileSession);
@@ -40,6 +41,7 @@ public static class Importers
/// <returns>Imported IR module.</returns>
public static IRModule ImportOnnx(Stream onnx, CompileSession compileSession)
{
compileSession.CompileOptions.ModelLayout = "NCHW";
var importer = new OnnxImporter(onnx, compileSession);
return importer.Import();
}
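ModelLayout records the canonical layout of the importing framework (TFLite: NHWC, ONNX: NCHW) so the preprocess pass knows whether a final transpose is required. The relationship in numpy terms:

    import numpy as np

    frame = np.zeros((1, 224, 224, 3), dtype=np.uint8)   # NHWC, as a camera delivers it
    nchw = np.transpose(frame, (0, 3, 1, 2))             # NHWC -> NCHW for an ONNX model
    nhwc = np.transpose(nchw, (0, 2, 3, 1))              # and back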

@@ -0,0 +1,148 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading.Tasks;
using Nncase.IR;
using Nncase.IR.Imaging;
using Nncase.IR.Math;
using Nncase.Passes;
using Nncase.PatternMatch;
using OrtKISharp;
using static Nncase.IR.F.Math;
using static Nncase.IR.F.NN;
using static Nncase.IR.F.Tensors;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
using Pad = Nncase.IR.NN.Pad;
namespace Nncase.Passes.Rules.Neutral;
/// <summary>
/// Add preprocess in model.
/// </summary>
[RuleGenerator]
public sealed class AddPreProcess : ModulePass
{
protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext options)
{
var preProcess = CompileSession.CompileOptions.PreProcess;
var inputLayout = CompileSession.CompileOptions.InputLayout;
var inputType = CompileSession.CompileOptions.InputType;
var inputShape = CompileSession.CompileOptions.InputShape;
var inputRange = CompileSession.CompileOptions.InputRange;
var swapRB = CompileSession.CompileOptions.SwapRB;
var letterBoxValue = CompileSession.CompileOptions.LetterBoxValue;
var mean = CompileSession.CompileOptions.Mean;
var std = CompileSession.CompileOptions.Std;
var modelLayout = CompileSession.CompileOptions.ModelLayout;
var entry = (IR.Function)module.Entry!;
var newType = new[] { DataTypes.UInt8, DataTypes.Int8, DataTypes.Float32 };
if (!preProcess)
{
return Task.FromResult(module);
}
var a = new Var(new TensorType(newType[(int)inputType], inputShape));
foreach (var input in entry.Parameters)
{
Expr newInput = a;
var oldShape = input.CheckedShape;
int n, c, h, w;
if (inputLayout == "NHWC")
{
(n, h, w, c) = (inputShape[0], inputShape[1], inputShape[2], inputShape[3]);
}
else
{
(n, c, h, w) = (inputShape[0], inputShape[1], inputShape[2], inputShape[3]);
}
// Convert new input to NCHW
if (inputLayout == "NHWC")
{
newInput = Transpose(newInput, new int[4] { 0, 3, 1, 2 });
}
// SwapRB
if (swapRB)
{
var axes = new int[4] { 0, 1, 2, 3 };
var strides = new int[4] { 1, 1, 1, 1 };
newInput = Concat(
new IR.Tuple(new[] { Slice(newInput, new int[4] { 0, 2, 0, 0 }, new int[4] { n, 3, h, w }, axes, strides),
Slice(newInput, new int[4] { 0, 1, 0, 0 }, new int[4] { n, 2, h, w }, axes, strides),
Slice(newInput, new int[4] { 0, 0, 0, 0 }, new int[4] { n, 1, h, w }, axes, strides), }),
1);
// TODO: fix slice neg strides shape inference
// newInput = Slice(newInput, new int[] {n, c, h, w },new[] { 0, 0, 0, 0 }, axes, strides);
}
// Dequantize to float
if (inputType != InputType.Float32)
{
var qP = QuantParamOf(QuantMode.UnsignedMode, new[] { inputRange[0], inputRange[1] }, 8);
newInput = Dequantize(newInput, qP, DataTypes.Float32);
}
// Letterbox
{
int modelH, modelW;
if (modelLayout != "NCHW")
{
(modelH, modelW) = (oldShape[1].FixedValue, oldShape[2].FixedValue);
}
else
{
(modelH, modelW) = (oldShape[2].FixedValue, oldShape[3].FixedValue);
}
var ratio = Math.Min(modelH / (float)h, modelW / (float)w);
var pads = Tensor.From<int>(new[] { 0, 0, 0, 0, 0, 0, 0, 0 }, new Shape(new[] { 4, 2 }));
var resizeH = Math.Round(h * ratio);
var resizeW = Math.Round(w * ratio);
var padH = modelH - resizeH;
var padW = modelW - resizeW;
var resizeShape = new int[] { n, c, (int)resizeH, (int)resizeW };
pads[2, 0] = (int)Math.Round((padH / 2) - 0.1);
pads[2, 1] = (int)padH - (int)Math.Round((padH / 2) - 0.1);
pads[3, 0] = (int)Math.Round((padW / 2) - 0.1);
pads[3, 1] = (int)padW - (int)Math.Round((padW / 2) - 0.1);
newInput = IR.F.NN.Pad(IR.F.Imaging.ResizeImage(ImageResizeMode.Bilinear, newInput, float.NaN, resizeShape, ImageResizeTransformationMode.HalfPixel), pads, PadMode.Constant, letterBoxValue);
}
// Normalization
if (mean.Length != 0)
{
newInput = (newInput - Tensor.From(mean, new[] { 1, 3, 1, 1 })) / Tensor.From(std, new[] { 1, 3, 1, 1 });
// newInput = Binary(BinaryOp.Div, Binary(BinaryOp.Sub, newInput, Tensor.From(mean, new []{1,3,1,1})), Const.FromTensor(std) );
}
// Convert to model layout
if (modelLayout == "NHWC")
{
newInput = Transpose(newInput, new[] { 0, 2, 3, 1 });
}
var y = new Passes.Mutators.Substitutor(expr => object.ReferenceEquals(expr, input) ? newInput : null).Rewrite(entry.Body);
var x = (Function)new Passes.Mutators.Substitutor(expr => object.ReferenceEquals(expr, input) ? a : null).Rewrite(entry);
}
return Task.FromResult(module);
}
}
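Reading the pass top to bottom, the inserted subgraph is: transpose to NCHW (when the new input is NHWC) -> swapRB via three slices and a concat -> dequantize using quant params derived from input_range -> aspect-preserving resize plus constant pad (letterbox) -> per-channel normalization -> transpose to the model layout. A numpy sketch of the same chain for a uint8 NHWC frame, assuming a zero-point-free affine mapping of [0, 255] onto input_range (function and argument names are illustrative):

    import cv2
    import numpy as np

    def preprocess_frame(frame, model_hw, input_range=(0.0, 1.0), swap_rb=False,
                         mean=None, std=None, letterbox_value=114.0):
        """frame: uint8 NHWC batch; returns a float32 NHWC tensor sized to model_hw."""
        x = frame[:, :, :, ::-1] if swap_rb else frame          # swapRB
        lo, hi = input_range                                    # dequantize uint8 -> float
        x = x.astype(np.float32) * ((hi - lo) / 255.0) + lo
        n, h, w, c = x.shape
        model_h, model_w = model_hw
        ratio = min(model_h / h, model_w / w)                   # letterbox: keep aspect ratio
        rh, rw = round(h * ratio), round(w * ratio)
        resized = np.stack([cv2.resize(img, (rw, rh), interpolation=cv2.INTER_LINEAR)
                            for img in x])
        top = round((model_h - rh) / 2 - 0.1)                   # same rounding as the pass
        left = round((model_w - rw) / 2 - 0.1)
        out = np.full((n, model_h, model_w, c), letterbox_value, dtype=np.float32)
        out[:, top:top + rh, left:left + rw, :] = resized
        if mean is not None:                                    # per-channel normalization
            out = (out - np.asarray(mean, np.float32)) / np.asarray(std, np.float32)
        return out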

@@ -188,7 +188,7 @@ public class UnitTestTypeInfer : UnitTypeInferBase
ImageResizeMode.NearestNeighbor,
IR.F.Random.Uniform(DataTypes.Float32, 0, 2, 1, new[] { 1, 3, 34, 67 }),
float.NaN,
-            Const.FromShape(new[] { 32, 48 }));
+            Const.FromShape(new[] { 1, 3, 32, 48 }));
Assert.True(CompilerServices.InferenceType(resize));
Assert.True(HasShape(new[] { 1, 3, 32, 48 }).MatchLeaf(resize.CheckedType!));
@@ -196,7 +196,7 @@ public class UnitTestTypeInfer : UnitTypeInferBase
ImageResizeMode.NearestNeighbor,
IR.F.Random.Uniform(DataTypes.Float32, 0, 2, 1, new[] { 3, 34, 67 }),
float.NaN,
-            Const.FromShape(new[] { 32, 48 }));
+            Const.FromShape(new[] { 32, 48, 67 }));
Assert.True(CompilerServices.InferenceType(resize2));
Assert.True(HasShape(new[] { 32, 48, 67 }).MatchLeaf(resize2.CheckedType!));

@@ -42,7 +42,7 @@ class CaffeTestRunner(TestRunner):
for input in self.inputs:
caffe_model.blobs[input['name']].data[...] = self.transform_input(
self.data_pre_process(input['data']), "float32", "CPU")
self.data_pre_process(input['data']), "float32", "CPU")[0]
outputs = caffe_model.forward()

@@ -13,9 +13,6 @@ case:  # the case config should be multi-level
- name: swapRB
values:
- false
-  - name: input_shape
-    values:
-      - [1,224,224,3]
- name: mean
values:
- [0,0,0]
@@ -28,12 +25,18 @@ case:  # the case config should be multi-level
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,224,224,3]
- name: input_layout
values:
- NHWC
- name: output_layout
values:
- NHWC
+  - name: model_layout
+    values:
+      - NHWC
- name: letterbox_value
values:
- 0.
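The matrix above is expanded by the test runner into one preprocess dict per value combination; with the single values shown it collapses to one case, roughly as below (keys as consumed by tests/preprocess_utils.py; values not visible in this hunk are placeholders):

    preprocess = {
        'preprocess': False, 'swapRB': False, 'mean': [0, 0, 0], 'std': [1, 1, 1],
        'input_range': [0, 255], 'input_type': 'uint8', 'input_shape': [1, 224, 224, 3],
        'input_layout': 'NHWC', 'output_layout': 'NHWC', 'model_layout': 'NHWC',
        'letterbox_value': 0.,
    }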

@@ -3,6 +3,7 @@ import os
import nncase
import numpy as np
import test_utils
import preprocess_utils
class Evaluator:
@@ -25,20 +26,21 @@ class Evaluator:
compile_options.dump_dir = eval_dir
compile_options.dump_asm = cfg.compile_opt.dump_asm
compile_options.dump_ir = cfg.compile_opt.dump_ir
compile_options = preprocess_utils.update_compile_options(compile_options, preprocess)
self.compiler = nncase.Compiler(compile_options)
self.import_model(self.compiler, model_content, import_options)
self.set_quant_opt(cfg, kwargs, preprocess, self.compiler)
evaluator = self.compiler.create_evaluator(3)
-        self.set_inputs(evaluator)
+        self.set_inputs(evaluator, preprocess)
evaluator.run()
eval_output_paths = self.dump_outputs(eval_dir, evaluator)
return eval_output_paths
-    def set_inputs(self, evaluator):
-        for i in range(len(self.inputs)):
+    def set_inputs(self, evaluator, preprocess):
+        for idx, i in enumerate(self.inputs):
             input_tensor: nncase.RuntimeTensor = nncase.RuntimeTensor.from_numpy(
-                self.transform_input(self.data_pre_process(self.inputs[i]['data'])[0], "float32", "CPU"))
-            evaluator.set_input_tensor(i, input_tensor)
+                self.transform_input((i['data']), preprocess['input_type'], "infer")[0])
+            evaluator.set_input_tensor(idx, input_tensor)
def dump_outputs(self, eval_dir, evaluator):
eval_output_paths = []

@@ -23,42 +23,45 @@ def _make_module(v_shape):
class BinaryModule(torch.nn.Module):
def __init__(self):
super(BinaryModule, self).__init__()
-        self.v = torch.from_numpy(np.random.rand(*v_shape).astype(np.float32))
+        # self.v = torch.from_numpy(np.random.rand(*v_shape).astype(np.float32))
+        self.v = torch.from_numpy(np.ones(v_shape).astype(np.float32))
def forward(self, x):
outs = []
outs.append(torch.add(x, self.v))
-        outs.append(torch.mul(x, self.v))
-        outs.append(torch.sub(x, self.v))
-        outs.append(torch.max(x, self.v))
-        outs.append(torch.div(x, self.v))
-        outs.append(torch.min(x, self.v))
-        outs.append(torch.fmod(x, self.v))
+        # outs.append(torch.mul(x, self.v))
+        # outs.append(torch.sub(x, self.v))
+        # outs.append(torch.max(x, self.v))
+        # outs.append(torch.div(x, self.v))
+        # outs.append(torch.min(x, self.v))
+        # outs.append(torch.fmod(x, self.v))
return outs
return BinaryModule()
lhs_shapes = [
-    [3],
-    [64, 3],
-    [3, 64, 3],
-    [8, 3, 64, 3]
+    # [3],
+    # [64, 3],
+    # [3, 64, 3],
+    # [8, 3, 64, 3]
+    [1, 3, 24, 24]
]
rhs_shapes = [
-    [1],
-    [3],
-    [1, 3],
-    [64, 1],
-    [64, 3],
-    [3, 64, 1],
-    [3, 64, 3],
-    [8, 3, 64, 1],
-    [8, 3, 64, 3],
-    [8, 3, 1, 3],
-    [8, 1, 64, 3],
-    [1, 3, 64, 1]
+    # [1],
+    # [3],
+    # [1, 3],
+    # [64, 1],
+    # [64, 3],
+    # [3, 64, 1],
+    # [3, 64, 3],
+    # [8, 3, 64, 1],
+    # [8, 3, 64, 3],
+    # [8, 3, 1, 3],
+    # [8, 1, 64, 3],
+    # [1, 3, 64, 1]
+    [1, 3, 24, 24]
]

@@ -46,7 +46,7 @@ rhs_shapes = [
def test_exchannel(lhs_shape, rhs_shape, request):
module = _make_module(rhs_shape)
overwrite_cfg = """
case:
preprocess_opt:
- name: preprocess
values:
@@ -54,9 +54,6 @@ case:
- name: swapRB
values:
- true
-  - name: input_shape
-    values:
-      - [1,3,224,224]
- name: mean
values:
- [0,0,0]
@@ -69,13 +66,19 @@ case:
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,3,224,224]
- name: input_layout
values:
- NCHW
- name: output_layout
values:
- NCHW
-  - name: letter_value
+  - name: model_layout
+    values:
+      - NCHW
+  - name: letterbox_value
values:
- 0.
"""

@@ -46,7 +46,7 @@ rhs_shapes = [
def test_letterbox(lhs_shape, rhs_shape, request):
module = _make_module(rhs_shape)
overwrite_cfg = """
case:
preprocess_opt:
- name: preprocess
values:
@@ -54,9 +54,6 @@ case:
- name: swapRB
values:
- false
-  - name: input_shape
-    values:
-      - [1,224,224,3]
- name: mean
values:
- [0,0,0]
@@ -69,14 +66,19 @@ case:
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,3,224,224]
- name: input_layout
values:
-      - NHWC
+      - NCHW
- name: output_layout
values:
-      - NHWC
+      - NCHW
-  - name: letter_value
+  - name: model_layout
+    values:
+      - NCHW
+  - name: letterbox_value
values:
- 0.
"""

@@ -48,7 +48,7 @@ rhs_shapes = [
def test_letterbox(lhs_shape, rhs_shape, request):
module = _make_module(rhs_shape)
overwrite_cfg = """
case:
preprocess_opt:
- name: preprocess
values:
@@ -56,9 +56,6 @@ case:
- name: swapRB
values:
- false
-  - name: input_shape
-    values:
-      - [1,3,224,224]
- name: mean
values:
- [0,0,0]
@@ -71,13 +68,19 @@ case:
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,3,224,224]
- name: input_layout
values:
- NCHW
- name: output_layout
values:
- NCHW
-  - name: letter_value
+  - name: model_layout
+    values:
+      - NCHW
+  - name: letterbox_value
values:
- 114.
"""

@@ -46,7 +46,7 @@ rhs_shapes = [
def test_letterbox(lhs_shape, rhs_shape, request):
module = _make_module(rhs_shape)
overwrite_cfg = """
case:
preprocess_opt:
- name: preprocess
values:
@@ -54,9 +54,6 @@ case:
- name: swapRB
values:
- false
-  - name: input_shape
-    values:
-      - [1,3,224,224]
- name: mean
values:
- [123,114,109]
@@ -69,13 +66,19 @@ case:
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,3,224,224]
- name: input_layout
values:
- NCHW
- name: output_layout
values:
- NCHW
-  - name: letter_value
+  - name: model_layout
+    values:
+      - NCHW
+  - name: letterbox_value
values:
- 0.
"""

@@ -30,35 +30,37 @@ def _make_module(in_shape, v_shape):
def __call__(self, x):
outs = []
outs.append(x + self.v)
-            outs.append(x - self.v)
-            outs.append(x * self.v)
-            outs.append(self.v / (2.0 + x))
-            outs.append(tf.minimum(x, self.v))
-            outs.append(tf.maximum(x, self.v))
+            # outs.append(x - self.v)
+            # outs.append(x * self.v)
+            # outs.append(self.v / (2.0 + x))
+            # outs.append(tf.minimum(x, self.v))
+            # outs.append(tf.maximum(x, self.v))
return outs
return BinaryModule()
lhs_shapes = [
-    [3],
-    [64, 3],
-    [3, 64, 3],
-    [8, 3, 64, 3]
+    # [3],
+    # [64, 3],
+    # [3, 64, 3],
+    # [8, 3, 64, 3]
+    [1, 24, 24, 3]
]
rhs_shapes = [
-    [1],
-    [3],
-    [1, 3],
-    [64, 1],
-    [64, 3],
-    [3, 64, 1],
-    [3, 64, 3],
-    [8, 3, 64, 1],
-    [8, 3, 64, 3],
-    [8, 3, 1, 3],
-    [8, 1, 64, 3],
-    [1, 3, 64, 1]
+    # [1],
+    # [3],
+    # [1, 3],
+    # [64, 1],
+    # [64, 3],
+    # [3, 64, 1],
+    # [3, 64, 3],
+    # [8, 3, 64, 1],
+    # [8, 3, 64, 3],
+    # [8, 3, 1, 3],
+    # [8, 1, 64, 3],
+    # [1, 3, 64, 1]
+    [1, 24, 24, 3]
]

@@ -49,7 +49,7 @@ rhs_shapes = [
def test_exchannel(lhs_shape, rhs_shape, request):
module = _make_module(lhs_shape, rhs_shape)
overwrite_cfg = """
case:
preprocess_opt:
- name: preprocess
values:
@@ -57,9 +57,6 @@ case:
- name: swapRB
values:
- true
-  - name: input_shape
-    values:
-      - [1,224,224,3]
- name: mean
values:
- [0,0,0]
@@ -72,13 +69,19 @@ case:
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,224,224,3]
- name: input_layout
values:
- NHWC
- name: output_layout
values:
- NHWC
-  - name: letter_value
+  - name: model_layout
+    values:
+      - NHWC
+  - name: letterbox_value
values:
- 0.
"""

@@ -49,7 +49,7 @@ rhs_shapes = [
def test_layout(lhs_shape, rhs_shape, request):
module = _make_module(lhs_shape, rhs_shape)
overwrite_cfg = """
case:
preprocess_opt:
- name: preprocess
values:
@@ -57,9 +57,6 @@ case:
- name: swapRB
values:
- false
-  - name: input_shape
-    values:
-      - [1,3,224,224]
- name: mean
values:
- [0,0,0]
@@ -72,14 +69,19 @@ case:
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,3,224,224]
- name: input_layout
values:
- NCHW
- name: output_layout
values:
-      - NCHW
+      - NHWC
-  - name: letter_value
+  - name: model_layout
+    values:
+      - NHWC
+  - name: letterbox_value
values:
- 0.
"""

@@ -51,7 +51,7 @@ rhs_shapes = [
def test_letterbox(lhs_shape, rhs_shape, request):
module = _make_module(lhs_shape, rhs_shape)
overwrite_cfg = """
case:
preprocess_opt:
- name: preprocess
values:
@@ -59,9 +59,6 @@ case:
- name: swapRB
values:
- false
-  - name: input_shape
-    values:
-      - [1,224,224,3]
- name: mean
values:
- [0,0,0]
@@ -74,13 +71,19 @@ case:
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,224,224,3]
- name: input_layout
values:
- NHWC
- name: output_layout
values:
- NHWC
-  - name: letter_value
+  - name: model_layout
+    values:
+      - NHWC
+  - name: letterbox_value
values:
- 114.
"""

@@ -49,7 +49,7 @@ rhs_shapes = [
def test_norm(lhs_shape, rhs_shape, request):
module = _make_module(lhs_shape, rhs_shape)
overwrite_cfg = """
case:
preprocess_opt:
- name: preprocess
values:
@@ -57,9 +57,6 @@ case:
- name: swapRB
values:
- false
-  - name: input_shape
-    values:
-      - [1,224,224,3]
- name: mean
values:
- [123,114,109]
@@ -72,13 +69,19 @@ case:
- name: input_type
values:
- uint8
+  - name: input_shape
+    values:
+      - [1,224,224,3]
- name: input_layout
values:
- NHWC
- name: output_layout
values:
- NHWC
-  - name: letter_value
+  - name: model_layout
+    values:
+      - NHWC
+  - name: letterbox_value
values:
- 0.
"""

@@ -65,16 +65,16 @@ class Inference:
return compile_options
def set_infer_input(self, preprocess, case_dir, sim):
-        for i in range(len(self.inputs)):
+        for idx, value in enumerate(self.inputs):
             data = self.transform_input(
-                self.inputs[i]['data'], preprocess['input_type'], "infer")[0]
+                value['data'], preprocess['input_type'], "infer")[0]
             dtype = preprocess['input_type']
             if preprocess['preprocess'] and dtype != 'float32':
                 if not test_utils.in_ci():
-                    data.tofile(os.path.join(case_dir, f'input_{i}_{dtype}.bin'))
-                    self.totxtfile(os.path.join(case_dir, f'input_{i}_{dtype}.txt'), data)
+                    data.tofile(os.path.join(case_dir, f'input_{idx}_{dtype}.bin'))
+                    self.totxtfile(os.path.join(case_dir, f'input_{idx}_{dtype}.txt'), data)
-            sim.set_input_tensor(i, nncase.RuntimeTensor.from_numpy(data))
+            sim.set_input_tensor(idx, nncase.RuntimeTensor.from_numpy(data))
def dump_infer_output(self, infer_dir, preprocess, sim):
infer_output_paths = []

@@ -192,9 +192,16 @@ class OnnxTestRunner(TestRunner):
sess = ort.InferenceSession(model_file)
input_dict = {}
-        for input in self.inputs:
-            input_dict[input['name']] = self.transform_input(
+        for i, input in enumerate(self.inputs):
+            new_value = self.transform_input(
                 self.data_pre_process(input['data']), "float32", "CPU")[0]
+            input_dict[input['name']] = new_value
if self.pre_process[0]['preprocess']:
bin_file = os.path.join(case_dir, f'frame_input_{i}.bin')
text_file = os.path.join(case_dir, f'frame_input_{i}.txt')
new_value[0].tofile(bin_file)
if not test_utils.in_ci():
self.totxtfile(text_file, new_value)
outputs = sess.run(None, input_dict)
i = 0

tests/preprocess_utils.py (new file)

@@ -0,0 +1,17 @@
def update_compile_options(compile_options, preprocess):
'''
Update compile_options with the given preprocess options.
'''
compile_options.preprocess = preprocess['preprocess']
compile_options.input_layout = preprocess['input_layout']
compile_options.output_layout = preprocess['output_layout']
compile_options.input_type = preprocess['input_type']
compile_options.input_shape = preprocess['input_shape']
compile_options.input_range = preprocess['input_range']
compile_options.swapRB = preprocess['swapRB']
compile_options.letterbox_value = preprocess['letterbox_value']
compile_options.mean = preprocess['mean']
compile_options.std = preprocess['std']
return compile_options
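Typical use from the runners (see evaluator.py above): build the base options, then overlay the per-case preprocess dict. A sketch, with dump_dir purely illustrative:

    import nncase
    import preprocess_utils

    compile_options = nncase.CompileOptions()
    compile_options.dump_dir = "tmp"
    compile_options = preprocess_utils.update_compile_options(compile_options, preprocess)
    compiler = nncase.Compiler(compile_options)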

@@ -247,25 +247,25 @@ class TestRunner(Evaluator, Inference, metaclass=ABCMeta):
def transform_input(self, values: List[np.ndarray], type: str, stage: str) -> List[np.ndarray]:
new_values = []
        for value in values:
-            new_value = value
-            if(len(value.shape) == 4 and self.pre_process[0]['preprocess']):
+            new_value = copy.deepcopy(value)
+            if(len(new_value.shape) == 4 and self.pre_process[0]['preprocess']):
                 if stage == "CPU":
                     # onnx \ caffe
                     if (self.model_type == "onnx" or self.model_type == "caffe"):
-                        new_value = np.transpose(value, [0, 3, 1, 2])
+                        new_value = np.transpose(new_value, [0, 3, 1, 2])
                 if type == 'float32':
-                    new_value = value.astype(np.float32)
+                    new_value = new_value.astype(np.float32)
                 elif type == 'uint8':
-                    if value.dtype == np.float32:
-                        new_value = (value * 255).astype(np.uint8)
+                    if new_value.dtype == np.float32:
+                        new_value = (new_value * 255).astype(np.uint8)
                 elif type == 'int8':
-                    if value.dtype == np.float32:
-                        new_value = (value * 255 - 128).astype(np.int8)
+                    if new_value.dtype == np.float32:
+                        new_value = (new_value * 255 - 128).astype(np.int8)
                 else:
                     raise TypeError(" Not support type for quant input")
-            new_values.append(value)
-        return values
+            new_values.append(new_value)
+        return np.array(new_values)
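In effect transform_input moves NHWC test data into the model's CPU layout and/or fakes the runtime quantization. For a uint8- or int8-input model at the "infer" stage, the behavior above amounts to:

    x = np.random.rand(1, 224, 224, 3).astype(np.float32)   # generated sample in [0, 1)
    u8 = (x * 255).astype(np.uint8)                         # ('uint8', 'infer') result
    i8 = (x * 255 - 128).astype(np.int8)                    # ('int8', 'infer') result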
def get_process_config(self, config):
# preprocess flag
@@ -300,6 +300,7 @@ class TestRunner(Evaluator, Inference, metaclass=ABCMeta):
# get layout
process_layout = {}
process_layout['input_layout'] = config['input_layout']
process_layout['model_layout'] = config['model_layout']
self.pre_process.append(preprocess_flag)
self.pre_process.append(process_deq)
@@ -314,71 +315,70 @@ class TestRunner(Evaluator, Inference, metaclass=ABCMeta):
new_value = copy.deepcopy(value)
if self.pre_process[0]['preprocess'] and len(value.shape) == 4:
if self.pre_process[-1]['input_layout'] == 'NCHW':
-                new_value = np.transpose(value, [0, 2, 3, 1])
-            if self.pre_process[3]['input_type'] == "uint8":
-                new_value = value * 255.
-            # elif self.cfg.case.compile_opt.kwargs['input_type'] == "int8":
-            #     data *= 255.
-            #     data -= 128.
-            for item in self.pre_process:
+                new_value = np.transpose(new_value, [0, 2, 3, 1])
+            for item in self.pre_process[1:]:
# dequantize
if 'range' in item.keys() and 'input_type' in item.keys():
Q_max, Q_min = 0, 0
if item['input_type'] == 'uint8':
Q_max, Q_min = 255, 0
# elif item['input_type'] == 'int8':
# Q_max, Q_min = 127, -128
else:
continue
scale = (item['range'][1] - item['range'][0]) / (Q_max - Q_min)
bias = round((item['range'][1] * Q_min - item['range'][0] *
Q_max) / (item['range'][1] - item['range'][0]))
-                new_value = value * scale
+                new_value = new_value * scale
new_value = new_value - bias
# swapRB
if 'swapRB' in item.keys():
-                if value.shape[-1] != 3:
+                if new_value.shape[-1] != 3:
                     assert("Please confirm your input channel is 3.")
                 if item['swapRB'] == True:
-                    new_value = value[:, :, :, ::-1]
+                    new_value = new_value[:, :, :, ::-1]
new_value = np.array(new_value)
# LetterBox
if 'input_range' in item.keys() and 'input_shape' in item.keys() and 'model_shape' in item.keys():
if item['input_shape'] != []:
model_shape: List = []
if self.model_type == "onnx" or self.model_type == "caffe":
model_shape = [1, item['model_shape'][2],
if self.model_type in ["onnx", "caffe"] and self.pre_process[-1]['model_layout'] != 'NHWC':
model_shape = [item['model_shape'][0], item['model_shape'][2],
item['model_shape'][3], item['model_shape'][1]]
else:
model_shape = item['model_shape']
if model_shape[1] != value.shape[1] or model_shape[2] != value.shape[2]:
in_h, in_w = value.shape[1], value.shape[2]
if model_shape[1] != new_value.shape[1] or model_shape[2] != new_value.shape[2]:
in_h, in_w = new_value.shape[1], new_value.shape[2]
model_h, model_w = model_shape[1], model_shape[2]
ratio = min(model_h / in_h, model_w / in_w)
-                        resize_shape = value.shape[0], round(
-                            in_h * ratio), round(in_w * ratio), 3
-                        resize_data = cv2.resize(value[0], (resize_shape[2],
-                                                 resize_shape[1]), interpolation=cv2.INTER_LINEAR)
-                        dh = model_shape[1] - resize_shape[1]
-                        dw = model_shape[2] - resize_shape[2]
-                        dh /= 2
-                        dw /= 2
-                        resize_data = np.array(resize_data, dtype=np.float32)
-                        new_value = cv2.copyMakeBorder(resize_data, round(dh - 0.1), round(model_h - resize_shape[1] - round(dh - 0.1)), round(dw - 0.1), round(
-                            model_w - resize_shape[2] - round(dw - 0.1)), cv2.BORDER_CONSTANT, value=(item['letterbox_value'], item['letterbox_value'], item['letterbox_value']))
-                        new_value = np.array(new_value, dtype=np.float32)
-                        new_value = np.expand_dims(new_value, 0)
+                        resize_shape = new_value.shape[0], round(
+                            in_h * ratio), round(in_w * ratio), 3
+                        resize_data = np.random.rand(*model_shape)
+                        for batch_data in new_value:
+                            tmp = cv2.resize(
+                                batch_data, (resize_shape[2], resize_shape[1]), interpolation=cv2.INTER_LINEAR)
+                            dh = model_shape[1] - resize_shape[1]
+                            dw = model_shape[2] - resize_shape[2]
+                            dh /= 2
+                            dw /= 2
+                            tmp = np.array(tmp, dtype=np.float32)
+                            tmp = cv2.copyMakeBorder(tmp, round(dh - 0.1), round(model_h - resize_shape[1] - round(dh - 0.1)), round(dw - 0.1), round(
+                                model_w - resize_shape[2] - round(dw - 0.1)), cv2.BORDER_CONSTANT, value=(item['letterbox_value'], item['letterbox_value'], item['letterbox_value']))
+                            tmp = np.expand_dims(tmp, 0)
+                            print("resize_data.shape = ", resize_data.shape)
+                            print("tmp.shape = ", tmp.shape)
+                            resize_data = np.concatenate([resize_data, tmp], axis=0)
+                        new_value = np.array(resize_data[1:], dtype=np.float32)
# Normalize
if 'norm' in item.keys():
-                for i in range(value.shape[-1]):
+                for i in range(new_value.shape[-1]):
                     k = i
-                    if value.shape[-1] > 3:
+                    if new_value.shape[-1] > 3:
                         k = 0
-                    new_value[:, :, :, i] = (value[:, :, :, i] - float(item['norm']['mean'][k])) / \
+                    new_value[:, :, :, i] = (new_value[:, :, :, i] - float(item['norm']['mean'][k])) / \
float(item['norm']['std'][k])
else:
assert("Please confirm your input shape and model shape is 4D!")
@@ -510,8 +510,9 @@ class TestRunner(Evaluator, Inference, metaclass=ABCMeta):
# compiler.dump_range_options(dump_range_options)
if kwargs['ptq']:
ptq_options = nncase.PTQTensorOptions()
-            ptq_options.set_tensor_data([self.transform_input(
-                sample['data'], preprocess['input_type'], "infer") for sample in self.calibs])
+            data = [self.transform_input(
+                sample['data'], preprocess['input_type'], "infer") for sample in self.calibs]
+            ptq_options.set_tensor_data(data)
ptq_options.samples_count = cfg.generate_calibs.numbers
ptq_options.calibrate_method = cfg.compile_opt.calibrate_method
ptq_options.quant_type = cfg.compile_opt.quant_type
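Calibration samples are now routed through the same transform_input as inference inputs before being handed to PTQ, so a uint8-input model calibrates on data shaped like its real inputs. End to end this corresponds to the following sketch (assuming the usual compiler.use_ptq entry point; the local variable names are illustrative):

    ptq_options = nncase.PTQTensorOptions()
    ptq_options.samples_count = len(calib_samples)
    ptq_options.set_tensor_data([transform(s) for s in calib_samples])
    compiler.use_ptq(ptq_options)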
@@ -558,6 +559,9 @@ class TestRunner(Evaluator, Inference, metaclass=ABCMeta):
compile_options.output_layout = cfg['output_layout']
for k, v in cfg.items():
# TODO: support model with unusual layout e.g.: onnx->NHWC
if k == "model_layout":
continue
e = '"'
exec(f"compile_options.{k} = {e + v + e if isinstance(v, str) else v}")
return import_options, compile_options
@@ -599,16 +603,21 @@
os.mkdir(os.path.join(case_dir, name))
for idx, input in enumerate(inputs):
samples = []
-            shape = copy.deepcopy(input['model_shape'])
+            # if preprocess_opt['preprocess'] and preprocess_opt['input_shape'] != [] and len(preprocess_opt['input_shape']) == 4:
+            #     shape = copy.deepcopy(preprocess_opt['input_shape'])
+            # else:
+            #     shape = copy.deepcopy(input['model_shape'])
+            shape = []
+            dtype = input['dtype']
if preprocess_opt['preprocess'] and preprocess_opt['input_shape'] != []:
shape = copy.deepcopy(preprocess_opt['input_shape'])
if preprocess_opt['input_type'] == "uint8":
dtype = np.uint8
elif preprocess_opt['input_type'] == "float32":
dtype = np.float32
else:
shape = copy.deepcopy(input['model_shape'])
if shape[0] != cfg.batch_size:
shape[0] *= cfg.batch_size
for n in range(cfg.numbers):
-                data = DataFactory[cfg.name](shape, input['dtype'], n,
+                data = DataFactory[cfg.name](shape, dtype, n,
cfg.batch_size, idx, **cfg.kwargs)
if not test_utils.in_ci():
path_list.append(

@@ -78,8 +78,16 @@ class TfliteTestRunner(TestRunner):
def cpu_infer(self, case_dir: str, model_file: bytes, type: str):
interp = tf.lite.Interpreter(model_path=model_file)
interp.allocate_tensors()
-        for input in self.inputs:
-            interp.set_tensor(input["index"], self.data_pre_process(input['data'])[0])
+        for idx, value in enumerate(self.inputs):
+            new_value = self.transform_input(
+                self.data_pre_process(value['data']), "float32", "CPU")[0]
+            interp.set_tensor(value["index"], new_value)
if self.pre_process[0]['preprocess']:
bin_file = os.path.join(case_dir, f'frame_input_{idx}.bin')
text_file = os.path.join(case_dir, f'frame_input_{idx}.txt')
new_value[0].tofile(bin_file)
if not test_utils.in_ci():
self.totxtfile(text_file, new_value)
interp.invoke()