From f2d97520abc455f2d719047c6ea0deb4767c74db Mon Sep 17 00:00:00 2001
From: chenjianping
Date: Sat, 8 Aug 2020 16:18:00 +0800
Subject: [PATCH] support infershape when running graph

---
 mindspore/lite/include/context.h | 6 ++
 mindspore/lite/include/errorcode.h | 5 +-
 mindspore/lite/src/executor.cc | 5 --
 mindspore/lite/src/kernel_factory.cc | 2 +-
 mindspore/lite/src/kernel_registry.h | 1 -
 mindspore/lite/src/lite_kernel.h | 38 +++++++--
 mindspore/lite/src/lite_session.cc | 1 +
 mindspore/lite/src/ops/cast.cc | 2 +-
 mindspore/lite/src/ops/ops.h | 3 +
 mindspore/lite/src/ops/reshape.cc | 15 ++--
 .../kernel/arm/base/arg_min_max_base.cc | 21 +++--
 .../kernel/arm/base/arg_min_max_base.h | 6 +-
 .../kernel/arm/base/batch_to_space_base.cc | 8 +-
 .../kernel/arm/base/batch_to_space_base.h | 5 +-
 .../runtime/kernel/arm/base/concat_base.cc | 28 ++++---
 .../src/runtime/kernel/arm/base/concat_base.h | 7 +-
 .../kernel/arm/base/convolution_base.h | 5 +-
 .../src/runtime/kernel/arm/base/crop_base.cc | 24 +++---
 .../src/runtime/kernel/arm/base/crop_base.h | 6 +-
 .../kernel/arm/base/depth_to_space_base.cc | 14 ++--
 .../kernel/arm/base/depth_to_space_base.h | 5 +-
 .../kernel/arm/base/fullconnection_base.cc | 10 ++-
 .../kernel/arm/base/fullconnection_base.h | 5 +-
 .../runtime/kernel/arm/base/matmul_base.cc | 7 +-
 .../src/runtime/kernel/arm/base/matmul_base.h | 5 +-
 .../lite/src/runtime/kernel/arm/base/pad.cc | 8 +-
 .../runtime/kernel/arm/base/pooling_base.cc | 12 ++-
 .../runtime/kernel/arm/base/pooling_base.h | 5 +-
 .../src/runtime/kernel/arm/base/prelu_base.cc | 4 +-
 .../src/runtime/kernel/arm/base/prelu_base.h | 5 +-
 .../src/runtime/kernel/arm/base/prior_box.cc | 9 ++-
 .../src/runtime/kernel/arm/base/prior_box.h | 5 +-
 .../kernel/arm/base/quant_dtype_cast.cc | 13 ++-
 .../kernel/arm/base/quant_dtype_cast.h | 5 +-
 .../runtime/kernel/arm/base/reshape_base.cc | 12 +--
 .../runtime/kernel/arm/base/reshape_base.h | 6 +-
 .../runtime/kernel/arm/base/softmax_base.cc | 8 +-
 .../runtime/kernel/arm/base/softmax_base.h | 5 +-
 .../src/runtime/kernel/arm/base/split_base.cc | 12 +--
 .../src/runtime/kernel/arm/base/split_base.h | 5 +-
 .../runtime/kernel/arm/base/squeeze_base.cc | 4 +-
 .../runtime/kernel/arm/base/squeeze_base.h | 5 +-
 .../runtime/kernel/arm/base/strided_slice.cc | 12 ++-
 .../runtime/kernel/arm/base/strided_slice.h | 5 +-
 .../kernel/arm/fp16/convolution_3x3_fp16.cc | 14 +++-
 .../kernel/arm/fp16/convolution_3x3_fp16.h | 6 +-
 .../arm/fp16/convolution_depthwise_fp16.cc | 23 ++++--
 .../arm/fp16/convolution_depthwise_fp16.h | 5 +-
 .../kernel/arm/fp16/convolution_fp16.cc | 22 +++--
 .../kernel/arm/fp16/convolution_fp16.h | 5 +-
 .../arm/fp16/deconvolution_depthwise_fp16.cc | 22 +++--
 .../arm/fp16/deconvolution_depthwise_fp16.h | 5 +-
 .../src/runtime/kernel/arm/fp32/activation.cc | 10 ++-
 .../src/runtime/kernel/arm/fp32/activation.h | 5 +-
 .../kernel/arm/fp32/activation_grad.cc | 20 ++---
 .../runtime/kernel/arm/fp32/activation_grad.h | 5 +-
 .../lite/src/runtime/kernel/arm/fp32/addn.cc | 19 +++--
 .../lite/src/runtime/kernel/arm/fp32/addn.h | 6 +-
 .../src/runtime/kernel/arm/fp32/argminmax.cc | 7 +-
 .../src/runtime/kernel/arm/fp32/argminmax.h | 6 +-
 .../src/runtime/kernel/arm/fp32/arithmetic.cc | 13 ++-
 .../src/runtime/kernel/arm/fp32/arithmetic.h | 5 +-
 .../kernel/arm/fp32/arithmetic_grad.cc | 5 +-
 .../runtime/kernel/arm/fp32/arithmetic_grad.h | 5 +-
 .../kernel/arm/fp32/arithmetic_self.cc | 16 +++-
 .../runtime/kernel/arm/fp32/arithmetic_self.h | 9 +--
 .../runtime/kernel/arm/fp32/batch_to_space.cc | 5 ++
 .../runtime/kernel/arm/fp32/batch_to_space.h | 6 +-
 .../src/runtime/kernel/arm/fp32/batchnorm.cc | 9 ++-
 .../src/runtime/kernel/arm/fp32/batchnorm.h | 5 +-
 .../lite/src/runtime/kernel/arm/fp32/bias.cc | 14 +++-
 .../lite/src/runtime/kernel/arm/fp32/bias.h | 8 +-
 .../src/runtime/kernel/arm/fp32/bias_grad.cc | 15 ++--
 .../src/runtime/kernel/arm/fp32/bias_grad.h | 5 +-
 .../runtime/kernel/arm/fp32/bngrad_input.cc | 4 +-
 .../runtime/kernel/arm/fp32/bngrad_input.h | 5 +-
 .../runtime/kernel/arm/fp32/broadcast_to.cc | 13 ++-
 .../runtime/kernel/arm/fp32/broadcast_to.h | 10 +--
 .../lite/src/runtime/kernel/arm/fp32/cast.cc | 22 +++--
 .../lite/src/runtime/kernel/arm/fp32/cast.h | 16 ++--
 .../src/runtime/kernel/arm/fp32/concat.cc | 80 +++++++++++--------
 .../lite/src/runtime/kernel/arm/fp32/concat.h | 6 +-
 .../runtime/kernel/arm/fp32/convolution.cc | 23 ++++--
 .../src/runtime/kernel/arm/fp32/convolution.h | 5 +-
 .../kernel/arm/fp32/convolution_1x1.cc | 9 +++
 .../runtime/kernel/arm/fp32/convolution_1x1.h | 5 +-
 .../kernel/arm/fp32/convolution_3x3.cc | 9 +++
 .../runtime/kernel/arm/fp32/convolution_3x3.h | 5 +-
 .../kernel/arm/fp32/convolution_depthwise.cc | 19 ++++-
 .../kernel/arm/fp32/convolution_depthwise.h | 6 +-
 .../arm/fp32/convolution_depthwise_3x3.cc | 11 ++-
 .../arm/fp32/convolution_depthwise_3x3.h | 5 +-
 .../arm/fp32/convolution_grad_filter.cc | 5 +-
 .../kernel/arm/fp32/convolution_grad_filter.h | 5 +-
 .../kernel/arm/fp32/convolution_grad_input.cc | 6 +-
 .../kernel/arm/fp32/convolution_grad_input.h | 5 +-
 .../kernel/arm/fp32/convolution_winograd.cc | 9 +++
 .../kernel/arm/fp32/convolution_winograd.h | 5 +-
 .../lite/src/runtime/kernel/arm/fp32/crop.cc | 17 ++--
 .../lite/src/runtime/kernel/arm/fp32/crop.h | 5 +-
 .../runtime/kernel/arm/fp32/deconvolution.cc | 14 +++-
 .../runtime/kernel/arm/fp32/deconvolution.h | 5 +-
 .../arm/fp32/deconvolution_depthwise.cc | 14 +++-
 .../kernel/arm/fp32/deconvolution_depthwise.h | 5 +-
 .../runtime/kernel/arm/fp32/depth_to_space.cc | 9 ++-
 .../runtime/kernel/arm/fp32/depth_to_space.h | 10 +--
 .../kernel/arm/fp32/embedding_lookup.cc | 13 ++-
 .../kernel/arm/fp32/embedding_lookup.h | 5 +-
 .../src/runtime/kernel/arm/fp32/expandDims.cc | 16 +++-
 .../src/runtime/kernel/arm/fp32/expandDims.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/fill.cc | 16 +++-
 .../lite/src/runtime/kernel/arm/fp32/fill.h | 6 +-
 .../src/runtime/kernel/arm/fp32/flatten.cc | 14 +++-
 .../src/runtime/kernel/arm/fp32/flatten.h | 6 +-
 .../runtime/kernel/arm/fp32/fullconnection.cc | 9 +++
 .../runtime/kernel/arm/fp32/fullconnection.h | 6 +-
 .../kernel/arm/fp32/fused_batchnorm.cc | 14 +++-
 .../runtime/kernel/arm/fp32/fused_batchnorm.h | 6 +-
 .../src/runtime/kernel/arm/fp32/gather.cc | 10 ++-
 .../lite/src/runtime/kernel/arm/fp32/gather.h | 6 +-
 .../src/runtime/kernel/arm/fp32/gatherNd.cc | 16 +++-
 .../src/runtime/kernel/arm/fp32/gatherNd.h | 6 +-
 .../kernel/arm/fp32/local_response_norm.cc | 10 ++-
 .../kernel/arm/fp32/local_response_norm.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/lstm.cc | 14 +++-
 .../lite/src/runtime/kernel/arm/fp32/lstm.h | 5 +-
 .../src/runtime/kernel/arm/fp32/matmul.cc | 9 +++
 .../lite/src/runtime/kernel/arm/fp32/matmul.h | 5 +-
 .../src/runtime/kernel/arm/fp32/nchw2nhwc.cc | 10 ++-
 .../src/runtime/kernel/arm/fp32/nchw2nhwc.h | 6 +-
 .../src/runtime/kernel/arm/fp32/nhwc2nchw.cc | 10 ++-
 .../src/runtime/kernel/arm/fp32/nhwc2nchw.h | 6 +-
 .../src/runtime/kernel/arm/fp32/one_hot.cc | 13 ++-
 .../src/runtime/kernel/arm/fp32/one_hot.h | 5 +-
 .../runtime/kernel/arm/fp32/opt_momentum.cc | 11 ++-
 .../runtime/kernel/arm/fp32/opt_momentum.h | 5 +-
 .../lite/src/runtime/kernel/arm/fp32/pad.cc | 9 +++
 .../lite/src/runtime/kernel/arm/fp32/pad.h | 5 +-
 .../src/runtime/kernel/arm/fp32/pooling.cc | 9 +++
 .../src/runtime/kernel/arm/fp32/pooling.h | 6 +-
 .../runtime/kernel/arm/fp32/pooling_grad.cc | 4 +-
 .../runtime/kernel/arm/fp32/pooling_grad.h | 5 +-
 .../lite/src/runtime/kernel/arm/fp32/power.cc | 12 ++-
 .../lite/src/runtime/kernel/arm/fp32/power.h | 5 +-
 .../src/runtime/kernel/arm/fp32/power_grad.cc | 4 +-
 .../src/runtime/kernel/arm/fp32/power_grad.h | 13 +--
 .../lite/src/runtime/kernel/arm/fp32/prelu.cc | 10 ++-
 .../lite/src/runtime/kernel/arm/fp32/prelu.h | 7 +-
 .../lite/src/runtime/kernel/arm/fp32/range.cc | 10 ++-
 .../lite/src/runtime/kernel/arm/fp32/range.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/rank.cc | 12 ++-
 .../lite/src/runtime/kernel/arm/fp32/rank.h | 6 +-
 .../src/runtime/kernel/arm/fp32/reduce.cc | 21 +++--
 .../lite/src/runtime/kernel/arm/fp32/reduce.h | 5 +-
 .../src/runtime/kernel/arm/fp32/reshape.cc | 6 +-
 .../src/runtime/kernel/arm/fp32/reshape.h | 6 +-
 .../src/runtime/kernel/arm/fp32/resize.cc | 14 +++-
 .../lite/src/runtime/kernel/arm/fp32/resize.h | 6 +-
 .../src/runtime/kernel/arm/fp32/reverse.cc | 16 +++-
 .../src/runtime/kernel/arm/fp32/reverse.h | 6 +-
 .../kernel/arm/fp32/reverse_sequence.cc | 14 +++-
 .../kernel/arm/fp32/reverse_sequence.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/scale.cc | 15 +++-
 .../lite/src/runtime/kernel/arm/fp32/scale.h | 9 ++-
 .../src/runtime/kernel/arm/fp32/scatter_nd.cc | 26 +++---
 .../src/runtime/kernel/arm/fp32/scatter_nd.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/shape.cc | 16 ++--
 .../lite/src/runtime/kernel/arm/fp32/shape.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/slice.cc | 15 +++-
 .../lite/src/runtime/kernel/arm/fp32/slice.h | 9 +--
 .../src/runtime/kernel/arm/fp32/softmax.cc | 9 +++
 .../src/runtime/kernel/arm/fp32/softmax.h | 5 +-
 .../runtime/kernel/arm/fp32/space_to_batch.cc | 14 +++-
 .../runtime/kernel/arm/fp32/space_to_batch.h | 5 +-
 .../runtime/kernel/arm/fp32/space_to_depth.cc | 17 +++-
 .../runtime/kernel/arm/fp32/space_to_depth.h | 5 +-
 ...parse_softmax_cross_entropy_with_logits.cc | 10 ++-
 ...sparse_softmax_cross_entropy_with_logits.h | 6 +-
 .../kernel/arm/fp32/sparse_to_dense.cc | 12 ++-
 .../runtime/kernel/arm/fp32/sparse_to_dense.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/split.cc | 11 ++-
 .../lite/src/runtime/kernel/arm/fp32/split.h | 5 +-
 .../src/runtime/kernel/arm/fp32/squeeze.cc | 16 ++--
 .../src/runtime/kernel/arm/fp32/squeeze.h | 7 +-
 .../lite/src/runtime/kernel/arm/fp32/stack.cc | 13 ++-
 .../lite/src/runtime/kernel/arm/fp32/stack.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/tile.cc | 15 +++-
 .../lite/src/runtime/kernel/arm/fp32/tile.h | 6 +-
 .../lite/src/runtime/kernel/arm/fp32/topk.cc | 15 +++-
 .../lite/src/runtime/kernel/arm/fp32/topk.h | 5 +-
 .../src/runtime/kernel/arm/fp32/transpose.cc | 21 +++--
 .../src/runtime/kernel/arm/fp32/transpose.h | 7 +-
 .../src/runtime/kernel/arm/fp32/unique.cc | 12 ++-
 .../lite/src/runtime/kernel/arm/fp32/unique.h | 6 +-
 .../src/runtime/kernel/arm/fp32/unsqueeze.cc | 20 +++--
 .../src/runtime/kernel/arm/fp32/unsqueeze.h | 6 +-
 .../src/runtime/kernel/arm/fp32/unstack.cc | 15 +++-
 .../src/runtime/kernel/arm/fp32/unstack.h | 10 +--
 .../lite/src/runtime/kernel/arm/fp32/where.cc | 11 ++-
 .../lite/src/runtime/kernel/arm/fp32/where.h | 6 +-
 .../src/runtime/kernel/arm/fp32/zeroslike.cc | 10 ++-
 .../src/runtime/kernel/arm/fp32/zeroslike.h | 6 +-
 .../src/runtime/kernel/arm/int8/activation.cc | 10 +--
 .../src/runtime/kernel/arm/int8/add_int8.cc | 14 +++-
 .../src/runtime/kernel/arm/int8/add_int8.h | 5 +-
 .../runtime/kernel/arm/int8/argminmax_int8.cc | 8 +-
 .../runtime/kernel/arm/int8/argminmax_int8.h | 5 +-
 .../kernel/arm/int8/arithmetic_int8.cc | 4 +-
 .../runtime/kernel/arm/int8/arithmetic_int8.h | 5 +-
 .../kernel/arm/int8/arithmetic_self_int8.cc | 16 +++-
 .../kernel/arm/int8/arithmetic_self_int8.h | 11 ++-
 .../kernel/arm/int8/batch_to_space_int8.cc | 5 ++
 .../kernel/arm/int8/batch_to_space_int8.h | 5 +-
 .../runtime/kernel/arm/int8/bias_add_int8.cc | 15 +++-
 .../runtime/kernel/arm/int8/bias_add_int8.h | 6 +-
 .../runtime/kernel/arm/int8/concat_int8.cc | 12 ++-
 .../src/runtime/kernel/arm/int8/concat_int8.h | 6 +-
 .../kernel/arm/int8/convolution_3x3_int8.cc | 9 +++
 .../kernel/arm/int8/convolution_3x3_int8.h | 6 +-
 .../arm/int8/convolution_depthwise_int8.cc | 16 +++-
 .../arm/int8/convolution_depthwise_int8.h | 5 +-
 .../kernel/arm/int8/convolution_int8.cc | 15 +++-
 .../kernel/arm/int8/convolution_int8.h | 5 +-
 .../src/runtime/kernel/arm/int8/crop_int8.cc | 7 +-
 .../src/runtime/kernel/arm/int8/crop_int8.h | 5 +-
 .../arm/int8/deconvolution_depthwise_int8.cc | 16 +++-
 .../arm/int8/deconvolution_depthwise_int8.h | 5 +-
 .../kernel/arm/int8/deconvolution_int8.cc | 13 ++-
 .../kernel/arm/int8/deconvolution_int8.h | 6 +-
 .../kernel/arm/int8/depth_to_space_int8.cc | 7 +-
 .../kernel/arm/int8/depth_to_space_int8.h | 5 +-
 .../kernel/arm/int8/fullconnection_int8.cc | 9 +++
 .../kernel/arm/int8/fullconnection_int8.h | 5 +-
 .../runtime/kernel/arm/int8/hswish_int8.cc | 5 ++
 .../src/runtime/kernel/arm/int8/hswish_int8.h | 5 +-
 .../runtime/kernel/arm/int8/matmul_int8.cc | 9 +++
 .../src/runtime/kernel/arm/int8/matmul_int8.h | 5 +-
 .../src/runtime/kernel/arm/int8/mul_int8.cc | 14 +++-
 .../src/runtime/kernel/arm/int8/mul_int8.h | 7 +-
 .../src/runtime/kernel/arm/int8/pad_int8.cc | 9 +++
 .../src/runtime/kernel/arm/int8/pad_int8.h | 5 +-
 .../runtime/kernel/arm/int8/pooling_int8.cc | 9 +++
 .../runtime/kernel/arm/int8/pooling_int8.h | 5 +-
 .../src/runtime/kernel/arm/int8/prelu_int8.cc | 4 +
 .../src/runtime/kernel/arm/int8/prelu_int8.h | 5 +-
 .../src/runtime/kernel/arm/int8/relux_int8.cc | 10 +++
 .../src/runtime/kernel/arm/int8/relux_int8.h | 15 ++--
 .../runtime/kernel/arm/int8/reshape_int8.cc | 12 ++-
 .../runtime/kernel/arm/int8/reshape_int8.h | 6 +-
 .../runtime/kernel/arm/int8/sigmoid_int8.cc | 5 ++
 .../runtime/kernel/arm/int8/sigmoid_int8.h | 5 +-
 .../runtime/kernel/arm/int8/softmax_int8.cc | 9 +++
 .../runtime/kernel/arm/int8/softmax_int8.h | 5 +-
 .../src/runtime/kernel/arm/int8/split_int8.cc | 11 ++-
 .../src/runtime/kernel/arm/int8/split_int8.h | 5 +-
 .../runtime/kernel/arm/int8/squeeze_int8.cc | 11 ++-
 .../runtime/kernel/arm/int8/squeeze_int8.h | 5 +-
 .../src/runtime/kernel/arm/int8/topk_int8.cc | 14 +++-
 .../src/runtime/kernel/arm/int8/topk_int8.h | 5 +-
 .../runtime/kernel/arm/int8/unsqueeze_int8.cc | 15 +++-
 .../runtime/kernel/arm/int8/unsqueeze_int8.h | 5 +-
 .../runtime/kernel/arm/nnacl/common_func.cc | 8 +-
 .../runtime/kernel/arm/nnacl/common_func.h | 4 +-
 .../src/runtime/kernel/arm/nnacl/fp32/conv.cc | 8 +-
 .../kernel/opencl/kernel/arithmetic.cc | 2 +-
 .../runtime/kernel/opencl/kernel/concat.cc | 2 +-
 .../kernel/opencl/kernel/conv2d_transpose.cc | 3 +-
 .../kernel/opencl/kernel/conv2d_transpose.h | 5 +-
 .../kernel/opencl/kernel/convolution.cc | 2 +-
 .../kernel/opencl/kernel/depthwise_conv2d.cc | 3 +-
 .../runtime/kernel/opencl/kernel/matmul.cc | 2 +-
 .../src/runtime/kernel/opencl/kernel/matmul.h | 2 +-
 .../runtime/kernel/opencl/kernel/pooling2d.cc | 2 +-
 .../runtime/kernel/opencl/kernel/softmax.cc | 3 +-
 .../runtime/kernel/opencl/kernel/softmax.h | 2 +-
 .../runtime/kernel/opencl/kernel/transpose.cc | 2 +-
 .../src/runtime/kernel/opencl/opencl_kernel.h | 2 +-
 .../kernel/opencl/subgraph_opencl_kernel.h | 2 +-
 mindspore/lite/src/scheduler.cc | 22 +++--
 mindspore/lite/src/scheduler.h | 6 +-
 .../kernel/arm/common/strided_slice_tests.cc | 4 +-
 .../kernel/arm/fp32/activation_fp32_test.cc | 2 +-
 .../arm/fp32/arithmetic_grad_fp32_tests.cc | 26 +++---
 .../kernel/arm/fp32/batchnorm_fp32_tests.cc | 2 +-
 .../kernel/arm/fp32/bias_grad_fp32_tests.cc | 2 +-
 .../kernel/arm/fp32/conv1x1_fp32_tests.cc | 4 +-
 .../fp32/convolution_depthwise_fp32_tests.cc | 6 +-
 .../arm/fp32/convolution_grad_fp32_tests.cc | 12 +-
 .../arm/fp32/deconvolution_fp32_tests.cc | 8 +-
 .../arm/fp32/embedding_lookup_fp32_test.cc | 2 +-
 .../arm/fp32/fullconnection_fp32_tests.cc | 4 +-
 .../kernel/arm/fp32/lstm_fp32_tests.cc | 6 +-
 .../kernel/arm/fp32/matmul_fp32_tests.cc | 8 +-
 .../arm/fp32/pooling_grad_fp32_tests.cc | 2 +-
 .../kernel/arm/fp32/power_fp32_tests.cc | 6 +-
 .../arm/fp32/space_to_batch_fp32_tests.cc | 2 +-
 .../arm/fp32/space_to_depth_fp32_tests.cc | 2 +-
 .../kernel/arm/fp32/topk_fp32_tests.cc | 2 +-
 .../runtime/kernel/arm/int8/add_int8_tests.cc | 2 +-
 .../arm/int8/arithmetic_self_int8_tests.cc | 32 ++++----
 .../kernel/arm/int8/concat_int8_tests.cc | 6 +-
 .../kernel/arm/int8/crop_int8_tests.cc | 20 ++---
 .../kernel/arm/int8/deconv_int8_tests.cc | 4 +-
 .../arm/int8/fullconnection_int8_tests.cc | 4 +-
 .../kernel/arm/int8/hswish_int8_tests.cc | 2 +-
 .../kernel/arm/int8/matmul_int8_tests.cc | 2 +-
 .../runtime/kernel/arm/int8/mul_int8_tests.cc | 8 +-
 .../runtime/kernel/arm/int8/pad_int8_tests.cc | 6 +-
 .../kernel/arm/int8/prelu_int8_tests.cc | 2 +-
 .../kernel/arm/int8/quant_dtype_cast_tests.cc | 4 +-
 .../kernel/arm/int8/relux_int8_tests.cc | 4 +-
 .../kernel/arm/int8/reshape_int8_tests.cc | 4 +-
 .../kernel/arm/int8/sigmoid_int8_tests.cc | 2 +-
 .../kernel/arm/int8/softmax_int8_tests.cc | 2 +-
 .../kernel/arm/int8/split_int8_tests.cc | 6 +-
 .../kernel/arm/int8/squeeze_int8_tests.cc | 2 +-
 .../kernel/arm/int8/topk_int8_tests.cc | 2 +-
 .../kernel/arm/int8/unsqueeze_int8_tests.cc | 2 +-
 318 files changed, 1822 insertions(+), 911 deletions(-)

diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h
index 02b6cd04e2..2e5dd1944a 100644
--- a/mindspore/lite/include/context.h
+++ b/mindspore/lite/include/context.h
@@ -64,11 +64,17 @@ class MS_API Context {
   /// \brief Destructor of MindSpore Lite Context.
   virtual ~Context();
 
+  void InferShapeInterrupt() {
+    infer_shape_interrupt_ = true;
+  }
+
  public:
   DeviceContext device_ctx_{DT_CPU};
   int thread_num_ = 2; /**< thread number config for thread pool */
   std::shared_ptr<Allocator> allocator = nullptr;
   CpuBindMode cpu_bind_mode_ = MID_CPU;
+  bool infer_shape_interrupt_ = false;
+  bool running_ = false;
 };
 }  // namespace mindspore::lite
 #endif  // MINDSPORE_LITE_INCLUDE_CONTEXT_H_
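Note (illustrative sketch, not part of the patch): the two new Context flags are set at different stages — InferShapeInterrupt() during graph build when a shape cannot yet be resolved, and running_ by the session once the graph actually runs. BuildAndRun below is a hypothetical helper that only shows the handshake:

#include "include/context.h"

// Sketch only; BuildAndRun is hypothetical, the flag handshake is from this patch.
void BuildAndRun(mindspore::lite::Context *ctx) {
  ctx->InferShapeInterrupt();  // build time: a data-dependent shape was seen
  ctx->running_ = true;        // run time: set by LiteSession::RunGraph below
}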
diff --git a/mindspore/lite/include/errorcode.h b/mindspore/lite/include/errorcode.h
index 2cdd4659de..c7c8224e07 100644
--- a/mindspore/lite/include/errorcode.h
+++ b/mindspore/lite/include/errorcode.h
@@ -48,8 +48,11 @@ constexpr int RET_OP_EXECUTE_FAILURE = -304; /**< Failed to execution operator.
 
 /* Tensor error code, range: [-401,-500] */
 constexpr int RET_FORMAT_ERR = -401; /**< Failed to checking tensor format. */
+
+/* InferShape error code, range: [-501,-600] */
+constexpr int RET_INFER_ERR = -501;     /**< Failed to infer shape. */
+constexpr int RET_INFER_INVALID = -502; /**< Infer shape is invalid before runtime. */
 }  // namespace lite
 }  // namespace mindspore
 
 #endif  // MINDSPORE_LITE_INCLUDE_ERRORCODE_H_
-
diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc
index a82d7cd1ed..e4a2dc62a8 100644
--- a/mindspore/lite/src/executor.cc
+++ b/mindspore/lite/src/executor.cc
@@ -37,11 +37,6 @@ int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Tensor *> &outputs,
-    std::vector<tensor::Tensor *> outputs = kernel->GetOutputs();
-    for (auto *output : outputs) {
-      MS_ASSERT(nullptr != output);
-      output->MallocData();
-    }
     session::CallBackParam callbackParam;
     callbackParam.name_callback_param = kernel->Name();
     callbackParam.type_callback_param = kernel->type_str();
diff --git a/mindspore/lite/src/kernel_factory.cc b/mindspore/lite/src/kernel_factory.cc
index b835506bec..fea004db99 100644
--- a/mindspore/lite/src/kernel_factory.cc
+++ b/mindspore/lite/src/kernel_factory.cc
@@ -45,7 +45,7 @@ LiteKernel *KernelFactory::GetKernel(const std::vector<tensor::Tensor *> &inputs
   }
   auto creator = KernelRegistry::GetInstance()->GetCreator(key);
   if (creator != nullptr) {
-    auto kernel = creator(inputs, outputs, parameter, ctx, key);
+    auto kernel = creator(inputs, outputs, parameter, ctx, key, primitive);
     return kernel;
   }
   return nullptr;
diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h
index eab7d03a53..772833f775 100644
--- a/mindspore/lite/src/kernel_registry.h
+++ b/mindspore/lite/src/kernel_registry.h
@@ -45,7 +45,6 @@ class KernelRegistry {
   int device_type_length_;
   int data_type_length_;
   int op_type_length_;
-  std::mutex lock_;
 };
 
 class KernelRegistrar {
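Illustrative sketch (hypothetical helper, not in this patch) of how a graph-build step is meant to consume the new error codes together with the executor change above: RET_INFER_INVALID is a deferral, not a failure, so the build clears the primitive's infer flag and interrupts the context instead of aborting.

#include <vector>
#include "include/context.h"
#include "include/errorcode.h"
#include "src/ir/tensor.h"
#include "src/ops/ops.h"

// Hypothetical build-time helper showing the intended error-code contract.
int InferShapeOrDefer(mindspore::lite::Primitive *primitive, mindspore::lite::Context *ctx,
                      std::vector<mindspore::lite::tensor::Tensor *> &inputs,
                      std::vector<mindspore::lite::tensor::Tensor *> &outputs) {
  auto ret = primitive->InferShape(inputs, outputs);
  if (ret == mindspore::lite::RET_INFER_INVALID) {
    primitive->SetInferFlag(false);  // shapes stay unknown until runtime data arrives
    ctx->InferShapeInterrupt();      // downstream kernels will defer their Init()
    return mindspore::lite::RET_OK;
  }
  return ret;  // RET_OK, or RET_INFER_ERR for a genuine shape error
}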
diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h
index 1dba44c862..040068bf6c 100644
--- a/mindspore/lite/src/lite_kernel.h
+++ b/mindspore/lite/src/lite_kernel.h
@@ -25,6 +25,7 @@
 #include "include/context.h"
 #include "src/ir/tensor.h"
 #include "src/ops/ops.h"
+#include "include/errorcode.h"
 
 #ifdef ENABLE_FP16
 using FLOAT_t = float16_t;
@@ -34,6 +35,8 @@ using FLOAT_t = float;
 
 // using mindspore::kernel::AddressPtr;
 namespace mindspore::kernel {
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
 enum KERNEL_ARCH { kCPU, kGPU, kNPU, kKernelArch_MIN = kCPU, kKernelArch_MAX = kNPU };
 struct KernelKey {
   KERNEL_ARCH arch;
@@ -55,15 +58,30 @@ class LiteKernel {
  public:
   LiteKernel() = default;
   explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                      const std::vector<lite::tensor::Tensor *> &outputs)
-      : opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false) {
+                      const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                      const lite::Primitive *primitive)
+      : opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false), primitive_(primitive),
+        context_(ctx) {
     this->in_kernel_.clear();
     this->out_kernel_.clear();
   }
 
   virtual ~LiteKernel() { delete opParameter; }
 
-  virtual int Prepare() { return -1; }
+  virtual int Prepare() {
+    if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
+      (const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_);
+    }
+    if (need_reinit) {
+      Init();
+    }
+    auto &outputs = this->GetOutputs();
+    for (auto *output : outputs) {
+      MS_ASSERT(output != nullptr);
+      output->MallocData();
+    }
+    return RET_OK;
+  }
   virtual int Init() { return -1; }
   virtual int ReSize() { return -1; }
   virtual int Run() { return -1; }
@@ -103,16 +121,23 @@ class LiteKernel {
   void set_desc(const KernelKey kernel_key) { desc = kernel_key; }
 
+  void SetNeedReInit() {
+    need_reinit = true;
+  }
+
  protected:
   KernelKey desc;
   std::string name;
   OpParameter *opParameter = nullptr;
+  const lite::Primitive *primitive_;
+  const lite::Context *context_;
   // tensor will free in ~lite_session()
   std::vector<lite::tensor::Tensor *> inputs_;
   std::vector<lite::tensor::Tensor *> outputs_;
   std::vector<LiteKernel *> in_kernel_;
   std::vector<LiteKernel *> out_kernel_;
   bool train_mode;
+  bool need_reinit = false;
 };
 
 class SubGraphKernel : public LiteKernel {
@@ -121,8 +146,9 @@ class SubGraphKernel : public LiteKernel {
                  const std::vector<lite::tensor::Tensor *> &outputs,
                  const std::vector<LiteKernel *> &inKernels,
                  const std::vector<LiteKernel *> &outKernels,
-                 const std::vector<LiteKernel *> &nodes)
-      : LiteKernel(nullptr, inputs, outputs),
+                 const std::vector<LiteKernel *> &nodes, const lite::Context *ctx,
+                 const lite::Primitive *primitive)
+      : LiteKernel(nullptr, inputs, outputs, ctx, primitive),
         inputs_(inputs),
         outputs_(outputs),
         inkernels_(inKernels),
@@ -144,7 +170,7 @@ class SubGraphKernel : public LiteKernel {
 typedef LiteKernel *(*KernelCreator)(const std::vector<lite::tensor::Tensor *> &inputs,
                                      const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
-                                     const lite::Context *ctx, const KernelKey &desc);
+                                     const lite::Context *ctx, const KernelKey &desc, const lite::Primitive *primitive);
 
 class LiteKernelUtil {
  public:
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index ae529836ab..aa402415e7 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -168,6 +168,7 @@ std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const {
 int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) {
   MS_EXCEPTION_IF_NULL(this->context_);
   SetMaxWokerNum(context_->thread_num_);
+  context_->running_ = true;
   Executor executor;
   if (before == nullptr && after == nullptr) {
     return executor.Run(this->inputs, this->outputs, this->kernels, this->context_->allocator.get());
diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc
index 13de84ff5e..565f8de767 100644
--- a/mindspore/lite/src/ops/cast.cc
+++ b/mindspore/lite/src/ops/cast.cc
@@ -40,7 +40,7 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
                   << input->data_type();
     return RET_INPUT_TENSOR_ERROR;
   }
-  if (cast_prim->dstT() != kNumberTypeFloat || cast_prim->dstT() != kNumberTypeFloat32) {
+  if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) {
     MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT();
     return RET_INPUT_TENSOR_ERROR;
   }
diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h
index 302f085b9d..04133f9357 100644
--- a/mindspore/lite/src/ops/ops.h
+++ b/mindspore/lite/src/ops/ops.h
@@ -45,12 +45,15 @@ class Primitive {
   static Primitive *CreatePrimitive(schema::Primitive *primitive);
   virtual ~Primitive() {}
   const schema::Primitive *Value() const { return this->primitive; }
+  const bool GetInferFlag() const { return this->infer_flag_; }
+  void SetInferFlag(bool flag) { this->infer_flag_ = flag; }
   schema::PrimitiveType Type() const { return this->primitive->value_type(); }
   const void *Attribute() const { return this->primitive->value(); }
   virtual int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_);
 
  protected:
   schema::Primitive *primitive;
+  bool infer_flag_ = true;
 };
 
 class Conv2D : public Primitive {
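For reference (sketch, not part of the patch): every kernel creator now receives the lite::Primitive and forwards it, together with the context, to the LiteKernel base class. A minimal creator under the new KernelCreator signature, with a hypothetical FooCPUKernel:

// FooCPUKernel is hypothetical; the signature and error handling mirror this patch.
kernel::LiteKernel *CpuFooKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                        const std::vector<lite::tensor::Tensor *> &outputs,
                                        OpParameter *opParameter, const lite::Context *ctx,
                                        const kernel::KernelKey &desc, const lite::Primitive *primitive) {
  if (opParameter == nullptr) {
    return nullptr;
  }
  // Forward ctx and primitive so the base class can run InferShape in Prepare().
  auto *kernel = new (std::nothrow) FooCPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr || kernel->Init() != RET_OK) {
    delete kernel;
    return nullptr;
  }
  return kernel;
}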
diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc
index 683e3f404e..7285de3877 100644
--- a/mindspore/lite/src/ops/reshape.cc
+++ b/mindspore/lite/src/ops/reshape.cc
@@ -34,11 +34,11 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_
       inferIndex = i;
     } else {
       MS_LOG(ERROR) << "output shape should has no more than one dim which need infer";
-      return RET_ERROR;
+      return RET_INFER_ERR;
     }
   } else if (out_shape->at(i) < 0) {
     MS_LOG(ERROR) << "output shape dim should be non-negative";
-    return RET_ERROR;
+    return RET_INFER_ERR;
   } else if (out_shape->at(i) == 0) {
     out_shape->at(i) = in_tensor->shape().at(i);
     out_shapeSize *= out_shape->at(i);
@@ -49,7 +49,7 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_
   if (inferIndex == -1 && out_shapeSize != in_shape_size) {
     MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize
                   << " should be equal to input shapeSize: " << in_shape_size;
-    return RET_ERROR;
+    return RET_INFER_ERR;
   }
   if (inferIndex != -1) {
     out_shape->at(inferIndex) = in_shape_size / out_shapeSize;
@@ -88,7 +88,11 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
   std::vector<int> out_shape;
   if (inputs_.size() == kDoubleNum) {
     auto shape_tensor = inputs_.at(1);
-    size_t shape_size = shape_tensor->ElementsNum();
+    if (shape_tensor->Data() == nullptr) {
+      MS_LOG(INFO) << "Do infer shape in runtime.";
+      return RET_INFER_INVALID;
+    }
+    size_t shape_size = shape_tensor->shape().size();
     switch (shape_tensor->data_type()) {
       case kNumberTypeInt8: {
         auto data = reinterpret_cast<int8_t *>(shape_tensor->Data());
@@ -108,13 +112,14 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
         << shape_tensor->data_type();
-        return RET_ERROR;
+        return RET_INFER_ERR;
       }
     }
   } else if (inputs_.size() == kSingleNum) {
     std::copy(reshape_prim->shape()->begin(), reshape_prim->shape()->end(), std::back_inserter(out_shape));
   } else {
     MS_LOG(ERROR) << "inputs tensor size invalid.";
+    return RET_INFER_ERR;
   }
   auto ret = CalNewShape(inputs_.front(), &out_shape);
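Aside (illustrative, not in the patch): CalNewShape gives 0 and -1 dims special meaning — 0 copies the corresponding input dimension, and a single -1 is inferred from the remaining element count. A self-contained worked re-implementation of those two rules:

#include <vector>

// Demo of the 0 / -1 rules from Reshape::CalNewShape, assuming a valid request.
std::vector<int> CalNewShapeDemo(const std::vector<int> &in, std::vector<int> out) {
  int in_size = 1, out_size = 1, infer = -1;
  for (int d : in) in_size *= d;
  for (size_t i = 0; i < out.size(); ++i) {
    if (out[i] == -1) infer = static_cast<int>(i);        // at most one inferred dim
    else if (out[i] == 0) { out[i] = in[i]; out_size *= out[i]; }  // copy input dim
    else out_size *= out[i];
  }
  if (infer != -1) out[infer] = in_size / out_size;
  return out;  // e.g. in {2, 3, 4} with out {0, -1} -> {2, 12}
}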
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc
index 0e356a5315..b3b6d4cba0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc
@@ -24,14 +24,18 @@
 
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_PARAM_INVALID;
 using mindspore::lite::RET_FORMAT_ERR;
 using mindspore::lite::RET_OK;
+using mindspore::lite::RET_PARAM_INVALID;
 using mindspore::schema::PrimitiveType_ArgMax;
 using mindspore::schema::PrimitiveType_ArgMin;
 
 namespace mindspore::kernel {
 int ArgMinMaxBaseCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
   switch (opParameter->type_) {
     case PrimitiveType_ArgMax:
@@ -44,6 +48,7 @@ int ArgMinMaxBaseCPUKernel::Init() {
       MS_LOG(ERROR) << "Unexpected type " << opParameter->type_;
       return RET_ERROR;
   }
+
   auto in_shape = inputs_.at(0)->shape();
   auto dims_size = in_shape.size();
   int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_;
@@ -56,9 +61,9 @@ int ArgMinMaxBaseCPUKernel::Init() {
   param->topk_ = MSMIN(param->topk_, in_shape[axis]);
   if (param->topk_ > 1) {
     if (context_ != nullptr && context_->allocator != nullptr) {
-      param->arg_elements_
-        = reinterpret_cast<ArgElement *>(context_->allocator->Malloc(sizeof(ArgElement) * in_shape[axis]));
-      data_from_allocator_ = true;
+      param->arg_elements_ =
+        reinterpret_cast<ArgElement *>(context_->allocator->Malloc(sizeof(ArgElement) * in_shape[axis]));
+      data_from_allocator_ = true;
     } else {
       param->arg_elements_ = reinterpret_cast<ArgElement *>(malloc(sizeof(ArgElement) * in_shape[axis]));
     }
@@ -98,12 +103,12 @@ void ArgMinMaxBaseCPUKernel::FreeTmpMemory() {
 kernel::LiteKernel *CpuArgMinMaxInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *op_parameter, const lite::Context *ctx,
-                                                  const kernel::KernelKey &desc) {
+                                                  const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (op_parameter == nullptr) {
     MS_LOG(ERROR) << "Input op_parameter is nullptr!";
     return nullptr;
   }
-  auto kernel = new (std::nothrow) ArgMinMaxInt8CPUKernel(op_parameter, inputs, outputs, ctx);
+  auto kernel = new (std::nothrow) ArgMinMaxInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ArgMinMaxInt8CPUKernel fail!";
     return nullptr;
@@ -122,12 +127,12 @@ kernel::LiteKernel *CpuArgMinMaxInt8KernelCreator(const std::vector<lite::tensor
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *op_parameter, const lite::Context *ctx,
-                                                  const kernel::KernelKey &desc) {
+                                                  const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (op_parameter == nullptr) {
     MS_LOG(ERROR) << "Input op_parameter is nullptr!";
     return nullptr;
   }
-  auto kernel = new (std::nothrow) ArgMinMaxCPUKernel(op_parameter, inputs, outputs, ctx);
+  auto kernel = new (std::nothrow) ArgMinMaxCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ArgMinMaxCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h
index d5ad81f29d..771ecce902 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h
@@ -24,8 +24,9 @@ namespace mindspore::kernel {
 class ArgMinMaxBaseCPUKernel : public LiteKernel {
  public:
   ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                         const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), context_(ctx), data_from_allocator_(false) {
+                         const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                         const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {
     opParameter->thread_num_ = ctx->thread_num_;
   }
 
@@ -40,7 +41,6 @@ class ArgMinMaxBaseCPUKernel : public LiteKernel {
   void FreeTmpMemory();
 
  private:
-  const lite::Context *context_;
   bool data_from_allocator_;
 };
 }  // namespace mindspore::kernel
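The guard added at the top of ArgMinMaxBaseCPUKernel::Init() recurs in nearly every kernel touched by this patch. Sketched in isolation (AnyCPUKernel is hypothetical; the guard itself is from the patch):

#include "src/lite_kernel.h"

// Hypothetical kernel showing the deferral pattern this patch introduces.
class AnyCPUKernel : public mindspore::kernel::LiteKernel {
 public:
  using LiteKernel::LiteKernel;
  int Init() override {
    // Shapes are not known yet and the graph is not running: defer the
    // shape-dependent setup; Prepare() will call Init() again via need_reinit.
    if (context_->infer_shape_interrupt_ && !context_->running_) {
      SetNeedReInit();
      return mindspore::lite::RET_OK;
    }
    // shape-dependent initialisation would go here
    return mindspore::lite::RET_OK;
  }
};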
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc
index e816153411..dc320b049f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc
@@ -46,13 +46,13 @@ int BatchToSpaceBaseCPUKernel::Init() {
 kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                      const std::vector<lite::tensor::Tensor *> &outputs,
                                                      OpParameter *op_parameter, const lite::Context *ctx,
-                                                     const kernel::KernelKey &desc) {
+                                                     const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace);
   if (op_parameter == nullptr) {
     MS_LOG(ERROR) << "Input op_parameter is nullptr!";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) BatchToSpaceInt8CPUKernel(op_parameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) BatchToSpaceInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new BatchToSpaceInt8CPUKernel fail!";
     return nullptr;
@@ -71,13 +71,13 @@ kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::ten
                                                      const std::vector<lite::tensor::Tensor *> &outputs,
                                                      OpParameter *op_parameter, const lite::Context *ctx,
-                                                     const kernel::KernelKey &desc) {
+                                                     const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace);
   if (op_parameter == nullptr) {
     MS_LOG(ERROR) << "Input op_parameter is nullptr!";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) BatchToSpaceCPUKernel(op_parameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) BatchToSpaceCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new BatchToSpaceCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h
index 131b512e76..e8e6f83ac9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class BatchToSpaceBaseCPUKernel : public LiteKernel {
  public:
   BatchToSpaceBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs) {
+                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                            const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
     opParameter->thread_num_ = ctx->thread_num_;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc
index 24b0381d50..5674efc31c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/concat_base.cc
@@ -30,21 +30,24 @@ using mindspore::schema::PrimitiveType_Concat;
 
 namespace mindspore::kernel {
 int ConcatBaseCPUKernel::Init() {
-  auto axis = concat_param_->axis_;
-  axis_ = axis >= 0 ? axis : inputs_.front()->shape().size() + axis;
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
+  axis_ = concat_param_->axis_ >= 0 ? concat_param_->axis_ : inputs_.front()->shape().size() + concat_param_->axis_;
   return RET_OK;
 }
 
 kernel::LiteKernel *CpuConcatInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
-  auto *kernel = new(std::nothrow) ConcatInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ConcatInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
     return nullptr;
@@ -60,15 +63,15 @@ kernel::LiteKernel *CpuConcatInt8KernelCreator(const std::vector<lite::tensor::T
-                                               const std::vector<lite::tensor::Tensor *> &outputs,
-                                               OpParameter *opParameter, const Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const std::vector<lite::tensor::Tensor *> &outputs,
+                                               OpParameter *opParameter, const Context *ctx,
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
-  auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
     return nullptr;
@@ -84,15 +87,15 @@ kernel::LiteKernel *CpuConcatInt32KernelCreator(const std::vector<lite::tensor::
-                                                const std::vector<lite::tensor::Tensor *> &outputs,
-                                                OpParameter *opParameter, const Context *ctx,
-                                                const kernel::KernelKey &desc) {
+                                                const std::vector<lite::tensor::Tensor *> &outputs,
+                                                OpParameter *opParameter, const Context *ctx,
+                                                const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
-  auto *kernel = new(std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ConcatCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
     return nullptr;
@@ -111,4 +114,3 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Concat, CpuConcatInt8KernelCreat
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Concat, CpuConcatInt32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Concat, CpuConcatFp32KernelCreator)
 }  // namespace mindspore::kernel
-
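The rewritten Init() above normalises a negative Concat axis into [0, rank); the same one-liner, isolated for clarity (illustrative):

// Illustrative: for an input of rank 3, axis -1 maps to 3 + (-1) = 2 (last dim).
inline int NormalizeAxis(int axis, int rank) {
  return axis >= 0 ? axis : rank + axis;
}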
@@ -41,6 +42,7 @@ class ConcatBaseCPUKernel : public LiteKernel {
   int ReSize() override { return 0; }
   int Run() override { return 0; }
 
+ protected:
   int thread_count_;
   int axis_;
@@ -50,4 +52,3 @@ class ConcatBaseCPUKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONCAT_BASE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
index 89b53dfcad..a2845b6897 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
@@ -37,8 +37,9 @@ namespace mindspore::kernel {
 class ConvolutionBaseCPUKernel : public LiteKernel {
  public:
   ConvolutionBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                           const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                           const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                           const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
     opParameter->thread_num_ = ctx->thread_num_;
     conv_param_ = reinterpret_cast<ConvParameter *>(opParameter);
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc
index 9f66feb208..be27b1f9cc 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc
@@ -31,15 +31,15 @@ namespace mindspore::kernel {
 int CropBaseCPUKernel::Init() { return RET_OK; }
 
 kernel::LiteKernel *CpuCropInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-                                             const std::vector<lite::tensor::Tensor *> &outputs,
-                                             OpParameter *opParameter, const Context *ctx,
-                                             const kernel::KernelKey &desc) {
+                                             const std::vector<lite::tensor::Tensor *> &outputs,
+                                             OpParameter *opParameter, const Context *ctx,
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Crop);
-  auto *kernel = new (std::nothrow) CropInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) CropInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new CropCPUKernel fail!";
     return nullptr;
@@ -55,15 +55,15 @@ kernel::LiteKernel *CpuCropInt8KernelCreator(const std::vector<lite::tensor::Ten
-                                             const std::vector<lite::tensor::Tensor *> &outputs,
-                                             OpParameter *opParameter, const Context *ctx,
-                                             const kernel::KernelKey &desc) {
+                                             const std::vector<lite::tensor::Tensor *> &outputs,
+                                             OpParameter *opParameter, const Context *ctx,
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Crop);
-  auto *kernel = new (std::nothrow) CropCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) CropCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new CropCPUKernel fail!";
     return nullptr;
@@ -79,15 +79,15 @@ kernel::LiteKernel *CpuCropInt32KernelCreator(const std::vector<lite::tensor::Te
-                                              const std::vector<lite::tensor::Tensor *> &outputs,
-                                              OpParameter *opParameter, const Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const std::vector<lite::tensor::Tensor *> &outputs,
+                                              OpParameter *opParameter, const Context *ctx,
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Crop);
-  auto *kernel = new (std::nothrow) CropCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) CropCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new CropCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h
index f4ad763b5f..6e9a6843b8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h
@@ -27,8 +27,9 @@ namespace mindspore::kernel {
 class CropBaseCPUKernel : public LiteKernel {
  public:
   CropBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                    const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                    const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                    const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
     opParameter->thread_num_ = ctx->thread_num_;
   }
   ~CropBaseCPUKernel() = default;
@@ -39,7 +40,6 @@ class CropBaseCPUKernel : public LiteKernel {
 
  protected:
   int thread_count_;
-  const Context *ctx_;
 };
 
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc
index b18f7dc9bb..cd90425efa 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc
@@ -25,13 +25,17 @@
 
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_PARAM_INVALID;
 using mindspore::lite::RET_FORMAT_ERR;
 using mindspore::lite::RET_OK;
+using mindspore::lite::RET_PARAM_INVALID;
 using mindspore::schema::PrimitiveType_DepthToSpace;
 
 namespace mindspore::kernel {
 int DepthToSpaceBaseCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
     MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
     return RET_FORMAT_ERR;
@@ -62,13 +66,13 @@ int DepthToSpaceBaseCPUKernel::Init() {
 kernel::LiteKernel *CpuDepthToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                      const std::vector<lite::tensor::Tensor *> &outputs,
                                                      OpParameter *op_parameter, const lite::Context *ctx,
-                                                     const kernel::KernelKey &desc) {
+                                                     const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthToSpace);
   if (op_parameter == nullptr) {
     MS_LOG(ERROR) << "Input op_parameter is nullptr!";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) DepthToSpaceInt8CPUKernel(op_parameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) DepthToSpaceInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new BatchToSpaceInt8CPUKernel fail!";
     return nullptr;
@@ -87,13 +91,13 @@ kernel::LiteKernel *CpuDepthToSpaceInt8KernelCreator(const std::vector<lite::ten
                                                      const std::vector<lite::tensor::Tensor *> &outputs,
                                                      OpParameter *op_parameter, const lite::Context *ctx,
-                                                     const kernel::KernelKey &desc) {
+                                                     const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthToSpace);
   if (op_parameter == nullptr) {
     MS_LOG(ERROR) << "Input op_parameter is nullptr!";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) DepthToSpaceCPUKernel(op_parameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) DepthToSpaceCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new DepthToSpaceCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h
index 32974271aa..849934fb12 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class DepthToSpaceBaseCPUKernel : public LiteKernel {
  public:
   DepthToSpaceBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs) {
+                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                            const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
     opParameter->thread_num_ = ctx->thread_num_;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
index 4f74e94360..ecdf77cad4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
@@ -35,10 +35,11 @@ int FullconnectionBaseCPUKernel::Init() {
 kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                        const std::vector<lite::tensor::Tensor *> &outputs,
                                                        OpParameter *opParameter, const lite::Context *ctx,
-                                                       const kernel::KernelKey &desc) {
+                                                       const kernel::KernelKey &desc,
+                                                       const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
-  auto kernel = new (std::nothrow) FullconnectionInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto kernel = new (std::nothrow) FullconnectionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (!kernel) {
     MS_LOG(ERROR) << "kernel is nullptr.";
     return nullptr;
@@ -56,10 +57,11 @@ kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::t
                                                        const std::vector<lite::tensor::Tensor *> &outputs,
                                                        OpParameter *opParameter, const lite::Context *ctx,
-                                                       const kernel::KernelKey &desc) {
+                                                       const kernel::KernelKey &desc,
+                                                       const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
-  auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx);
+  auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (!kernel) {
     MS_LOG(ERROR) << "kernel is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h
index 3d29519370..a5a056a1ef 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h
@@ -28,8 +28,9 @@ namespace mindspore::kernel {
 class FullconnectionBaseCPUKernel : public LiteKernel {
  public:
   FullconnectionBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                              const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                              const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                              const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
     fc_param_ = reinterpret_cast<MatMulParameter *>(opParameter);
   }
   ~FullconnectionBaseCPUKernel() = default;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
index 56024abf33..af7ba4afab 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
@@ -28,7 +28,8 @@ using mindspore::schema::PrimitiveType_MatMul;
 namespace mindspore::kernel {
 kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                            const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter,
-                                           const lite::Context *ctx, const kernel::KernelKey &desc) {
+                                           const lite::Context *ctx, const kernel::KernelKey &desc,
+                                           const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
   auto input_tensor = inputs.at(kInputIndex);
@@ -37,7 +38,7 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h
--- a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h
@@ -28,8 +28,9 @@ class MatmulBaseCPUKernel : public LiteKernel {
   MatmulBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                      const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
     params_ = reinterpret_cast<MatMulParameter *>(opParameter);
   }
   ~MatmulBaseCPUKernel() = default;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pad.cc b/mindspore/lite/src/runtime/kernel/arm/base/pad.cc
index 723657b603..2e071706a7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/pad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/pad.cc
@@ -31,10 +31,10 @@ namespace mindspore::kernel {
 kernel::LiteKernel *CpuPadInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                             const std::vector<lite::tensor::Tensor *> &outputs,
                                             OpParameter *opParameter, const lite::Context *ctx,
-                                            const kernel::KernelKey &desc) {
+                                            const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Pad);
-  auto *kernel = new (std::nothrow) PadInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PadInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PadCPUKernel failed.";
     return nullptr;
@@ -52,10 +52,10 @@ kernel::LiteKernel *CpuPadInt8KernelCreator(const std::vector<lite::tensor::Tens
                                             const std::vector<lite::tensor::Tensor *> &outputs,
                                             OpParameter *opParameter, const lite::Context *ctx,
-                                            const kernel::KernelKey &desc) {
+                                            const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Pad);
-  auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PadCPUKernel failed.";
     return nullptr;
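Aside (reduced sketch, not the patch's exact code): CpuMatmulKernelCreator above looks at the input tensor before constructing a kernel; the body of its second hunk was lost in this copy, so the data-type dispatch below is an assumption based on the int8 and fp32 matmul kernels listed in the diffstat:

// Assumed dispatch pattern; MatmulInt8CPUKernel / MatmulCPUKernel are from this patch.
kernel::LiteKernel *CreateMatmulByDType(const std::vector<lite::tensor::Tensor *> &inputs,
                                        const std::vector<lite::tensor::Tensor *> &outputs,
                                        OpParameter *opParameter, const lite::Context *ctx,
                                        const lite::Primitive *primitive) {
  auto data_type = inputs.front()->data_type();
  if (data_type == kNumberTypeInt8) {
    return new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  }
  return new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
}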
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc
index 68dedfe351..45e7186b6e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc
@@ -56,6 +56,10 @@ void PoolingBaseCPUKernel::FreeQuantParam() {
 }
 
 int PoolingBaseCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   MS_ASSERT(inputs_.size() == 1);
   MS_ASSERT(outputs_.size() == 1);
   pooling_param_->thread_num_ = thread_count_;
@@ -78,13 +82,13 @@ int PoolingBaseCPUKernel::Init() {
 kernel::LiteKernel *CpuPoolingInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                 const std::vector<lite::tensor::Tensor *> &outputs,
                                                 OpParameter *opParameter, const Context *ctx,
-                                                const kernel::KernelKey &desc) {
+                                                const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Pooling);
-  auto *kernel = new (std::nothrow) PoolingInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PoolingInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PoolingInt8CPUKernel fail!";
     return nullptr;
@@ -102,13 +106,13 @@ kernel::LiteKernel *CpuPoolingInt8KernelCreator(const std::vector<lite::tensor::
                                                 const std::vector<lite::tensor::Tensor *> &outputs,
                                                 OpParameter *opParameter, const Context *ctx,
-                                                const kernel::KernelKey &desc) {
+                                                const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Pooling);
-  auto *kernel = new (std::nothrow) PoolingCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PoolingCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PoolingCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h
index a601db0bb3..4a56a44f9c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class PoolingBaseCPUKernel : public LiteKernel {
  public:
   PoolingBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                       const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
     pooling_param_ = reinterpret_cast<PoolingParameter *>(opParameter);
   }
   ~PoolingBaseCPUKernel() = default;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc
index f08cb56950..1c765d6de3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.cc
@@ -32,13 +32,13 @@ int PreluBaseCPUKernel::Init() { return RET_OK; }
 
 kernel::LiteKernel *CpuPreluInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Prelu);
-  auto *kernel = new (std::nothrow) PreluInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PreluInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PreluCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.h b/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.h
index 3b10023681..66c5cdc147 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/prelu_base.h
@@ -28,8 +28,9 @@ namespace mindspore::kernel {
 class PreluBaseCPUKernel : public LiteKernel {
  public:
   PreluBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                     const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                     const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
     opParameter->thread_num_ = ctx->thread_num_;
     prelu_param_ = reinterpret_cast<PreluParameter *>(opParameter);
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc
index e09ec8889c..680e62e329 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc
@@ -39,6 +39,11 @@ int PriorBoxCPUKernel::Init() {
     MS_LOG(ERROR) << "PriorBoxParameter nullptr";
     return RET_NULL_PTR;
   }
+
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   MS_ASSERT(inputs_.size() == kInputNum);
   MS_ASSERT(outputs_.size() == kOutputNum);
 
@@ -164,7 +169,7 @@ int PriorBoxCPUKernel::Run() {
 kernel::LiteKernel *CpuPriorBoxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const Context *ctx,
-                                             const kernel::KernelKey &desc) {
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
@@ -173,7 +178,7 @@ kernel::LiteKernel *CpuPriorBoxKernelCreator(const std::vector<lite::tensor::Ten
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h
--- a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h
@@ -28,8 +28,9 @@ class PriorBoxCPUKernel : public LiteKernel {
   PriorBoxCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                    const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                    const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                    const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
     prior_box_param_ = reinterpret_cast<PriorBoxParameter *>(opParameter);
   }
   ~PriorBoxCPUKernel() = default;
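QuantDTypeCastCPUKernel (next) splits the element count across threads and gives each task_id its own offset. The arithmetic, isolated into a self-contained helper (illustrative of the usual MindSpore Lite partitioning pattern, not copied from the kernel):

// Illustrative per-task partitioning: ceil-divide num_unit across thread_num tasks.
inline int ThreadOffset(int task_id, int num_unit, int thread_num) {
  int thread_n_num = thread_num < num_unit ? thread_num : num_unit;      // usable threads
  int thread_n_stride = (num_unit + thread_n_num - 1) / thread_n_num;    // UP_DIV
  return task_id * thread_n_stride;  // this task converts up to thread_n_stride units
}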
{ MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]"; @@ -124,12 +128,13 @@ int QuantDTypeCastCPUKernel::Run() { kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, + const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } - auto *kernel = new (std::nothrow) QuantDTypeCastCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) QuantDTypeCastCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new QuantDTypeCastCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h index 0ea72b2ddc..e7ca5edfc0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h @@ -24,8 +24,9 @@ namespace mindspore::kernel { class QuantDTypeCastCPUKernel : public LiteKernel { public: QuantDTypeCastCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {} ~QuantDTypeCastCPUKernel() = default; int Init() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc index f9712c9968..22f7f3e657 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc @@ -36,13 +36,13 @@ int ReshapeBaseCPUKernel::Init() { kernel::LiteKernel *CpuReshapeInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); - auto *kernel = new (std::nothrow) ReshapeInt8CPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) ReshapeInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new ReshapeInt8CPUKernel fail!"; return nullptr; @@ -60,13 +60,13 @@ kernel::LiteKernel *CpuReshapeInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); - auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new ReshapeCPUKernel fail!"; return nullptr; @@ -84,13 +84,13 @@ kernel::LiteKernel *CpuReshapeInt32KernelCreator(const std::vector &inputs, const 
std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Reshape); - auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) ReshapeCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new ReshapeCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h index 2c0ca7ea30..a9d2fb3ab9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.h @@ -27,8 +27,9 @@ namespace mindspore::kernel { class ReshapeBaseCPUKernel : public LiteKernel { public: ReshapeBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { reshape_param_ = reinterpret_cast(opParameter); } ~ReshapeBaseCPUKernel() = default; @@ -45,4 +46,3 @@ class ReshapeBaseCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_RESHAPE_BASE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc index f5f348e9ee..8029b715ac 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.cc @@ -53,13 +53,13 @@ int SoftmaxBaseCPUKernel::Init() { kernel::LiteKernel *CpuSoftmaxInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); - auto *kernel = new (std::nothrow) SoftmaxInt8CPUKernel(opParameter, inputs, outputs, ctx); + SoftmaxInt8CPUKernel *kernel = new (std::nothrow) SoftmaxInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; return nullptr; @@ -77,13 +77,13 @@ kernel::LiteKernel *CpuSoftmaxInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax); - auto *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs, ctx); + SoftmaxCPUKernel *kernel = new (std::nothrow) SoftmaxCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new SoftmaxCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h 
b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h index 4e8873aec0..c150b566a8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/softmax_base.h @@ -25,8 +25,9 @@ namespace mindspore::kernel { class SoftmaxBaseCPUKernel : public LiteKernel { public: SoftmaxBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { opParameter->thread_num_ = ctx->thread_num_; softmax_param_ = reinterpret_cast(opParameter); } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc index 7e04a8ab2d..4636c90a3e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc @@ -61,13 +61,13 @@ int SplitBaseCPUKernel::Init() { kernel::LiteKernel *CpuSplitInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Split); - auto *kernel = new (std::nothrow) SplitInt8CPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) SplitInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new SplitCPUKernel fail!"; return nullptr; @@ -85,13 +85,13 @@ kernel::LiteKernel *CpuSplitInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Split); - auto *kernel = new (std::nothrow) SplitCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) SplitCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new SplitCPUKernel fail!"; return nullptr; @@ -109,13 +109,13 @@ kernel::LiteKernel *CpuSplitInt32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Split); - auto *kernel = new (std::nothrow) SplitCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) SplitCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new SplitCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.h b/mindspore/lite/src/runtime/kernel/arm/base/split_base.h index 0f90604cfc..38d876cc1f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/split_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.h @@ -27,8 +27,9 
@@ namespace mindspore::kernel { class SplitBaseCPUKernel : public LiteKernel { public: SplitBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { param = reinterpret_cast(opParameter); } ~SplitBaseCPUKernel() = default; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc index ebc5107a6d..76dc57002a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.cc @@ -32,13 +32,13 @@ int SqueezeBaseCPUKernel::Init() { return RET_OK; } kernel::LiteKernel *CpuSqueezeInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "Input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Squeeze); - auto *kernel = new (std::nothrow) SqueezeInt8CPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) SqueezeInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new SqueezeCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.h b/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.h index 66f10f936e..ac203e70ed 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/squeeze_base.h @@ -28,8 +28,9 @@ namespace mindspore::kernel { class SqueezeBaseCPUKernel : public LiteKernel { public: SqueezeBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { opParameter->thread_num_ = ctx->thread_num_; } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc index 886bec2e68..9da679e86f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc @@ -42,12 +42,18 @@ int StridedSliceCPUKernel::Init() { int StridedSliceCPUKernel::ReSize() { return 0; } int StridedSliceCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } + auto input = inputs_.at(0); auto output = outputs_.at(0); MS_ASSERT(input); MS_ASSERT(output); - auto ret = DoStridedSlice(input->Data(), output->Data(), reinterpret_cast(opParameter)); + ret = DoStridedSlice(input->Data(), output->Data(), reinterpret_cast(opParameter)); if (ret != RET_OK) { MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; return RET_ERROR; @@ -58,13 +64,13 @@ int StridedSliceCPUKernel::Run() { kernel::LiteKernel *CpuStridedSliceKernelCreator(const std::vector &inputs, const std::vector 
&outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(desc.type == schema::PrimitiveType_StridedSlice); if (opParameter == nullptr) { MS_LOG(ERROR) << "opParameter null pointer dereferencing."; return nullptr; } - auto *kernel = new (std::nothrow) StridedSliceCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) StridedSliceCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "New kernel fails."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h index f6d8845ad1..66d063c7ca 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h @@ -25,8 +25,9 @@ namespace mindspore::kernel { class StridedSliceCPUKernel : public LiteKernel { public: StridedSliceCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {} ~StridedSliceCPUKernel() override = default; int Init() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc index 0f2e18111c..b294bb30db 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc @@ -183,10 +183,14 @@ void Convolution3x3FP16CPUKernel::ConfigInputOutput() { } int Convolution3x3FP16CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init failed."; - return RET_ERROR; + return ret; } ret = InitWeightBias(); if (ret != RET_OK) { @@ -228,7 +232,7 @@ int Convolution3x3FP16CPUKernel::ReSize() { auto ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init failed."; - return RET_ERROR; + return ret; } ret = InitTmpBuffer(); if (ret != RET_OK) { @@ -256,7 +260,11 @@ int Convolution3x3Fp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) } int Convolution3x3FP16CPUKernel::Run() { - // cast fp32 input data to fp16 + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input_tensor = inputs_.at(kInputIndex); auto ori_input_data = reinterpret_cast(input_tensor->Data()); for (int i = 0; i < input_tensor->ElementsNum(); ++i) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h index e6f8862e9b..8b5994be0a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h @@ -27,8 +27,9 @@ namespace mindspore::kernel { class Convolution3x3FP16CPUKernel : public ConvolutionBaseCPUKernel { public: Convolution3x3FP16CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : 
ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~Convolution3x3FP16CPUKernel() override { if (fp16_input_ != nullptr) { free(fp16_input_); @@ -78,4 +79,3 @@ void ProcessFilterFp16(float16_t *origin_weight, float16_t *dst_weight, ConvPara } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_3x3_FP16_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc index ec027deacd..522dd49d90 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc @@ -85,14 +85,20 @@ int ConvolutionDepthwiseFp16CPUKernel::InitWeightBias() { } int ConvolutionDepthwiseFp16CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } // conv base init - ConvolutionBaseCPUKernel::Init(); - + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } // init sliding_ window param sliding_ = new SlidingWindowParam; InitSlidingParam(sliding_, conv_param_, C8NUM); - auto ret = InitWeightBias(); + ret = InitWeightBias(); if (ret != 0) { MS_LOG(ERROR) << "Convolution depthwise fp16 InitWeightBias failed."; return RET_ERROR; @@ -138,6 +144,11 @@ int ConvDwFp16Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int ConvolutionDepthwiseFp16CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } if (conv_param_->input_channel_ != conv_param_->output_channel_) { MS_LOG(ERROR) << "Only support input channel equals output channel."; return RET_ERROR; @@ -149,7 +160,7 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() { PackNHWCFp32ToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_, conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); - auto ret = LiteBackendParallelLaunch(ConvDwFp16Run, this, conv_param_->thread_num_); + ret = LiteBackendParallelLaunch(ConvDwFp16Run, this, conv_param_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvDwFp16Run error: error_code[" << ret << "]"; return RET_ERROR; @@ -165,10 +176,10 @@ int ConvolutionDepthwiseFp16CPUKernel::Run() { kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); - auto kernel = new (std::nothrow) ConvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx); + auto kernel = new (std::nothrow) ConvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h index 5b81404e3e..d605ca5ba0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h @@ -26,8 +26,9 @@ namespace mindspore::kernel { 
class ConvolutionDepthwiseFp16CPUKernel : public ConvolutionBaseCPUKernel { public: ConvolutionDepthwiseFp16CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~ConvolutionDepthwiseFp16CPUKernel() override { delete sliding_; free(packed_weight_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index fc48c4d188..96ed5bc6fa 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -154,10 +154,14 @@ void ConvolutionFP16CPUKernel::ConfigInputOutput() { } int ConvolutionFP16CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { - MS_LOG(ERROR) << "ConvolutionBase init failed."; - return RET_ERROR; + MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret; + return ret; } ret = InitWeightBias(); if (ret != RET_OK) { @@ -193,7 +197,7 @@ int ConvolutionFP16CPUKernel::ReSize() { auto ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init failed."; - return RET_ERROR; + return ret; } ret = InitTmpBuffer(); if (ret != RET_OK) { @@ -220,7 +224,11 @@ int ConvolutionFp16Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int ConvolutionFP16CPUKernel::Run() { - // cast fp32 input data to fp16 + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input_tensor = inputs_.at(kInputIndex); auto ori_input_data = reinterpret_cast(input_tensor->Data()); for (int i = 0; i < input_tensor->ElementsNum(); ++i) { @@ -251,7 +259,7 @@ int ConvolutionFP16CPUKernel::Run() { kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); auto conv_param = reinterpret_cast(opParameter); @@ -267,7 +275,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vectoroutput_w_ = outputs.front()->Width(); kernel::LiteKernel *kernel = nullptr; if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { - kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); + kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); } else { bool use_winograd = false; int out_unit; @@ -275,7 +283,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~ConvolutionFP16CPUKernel() override { if (fp16_input_ != nullptr) { free(fp16_input_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc 
b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc index 4a0097b75c..0df66a963c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc @@ -99,12 +99,19 @@ int DeconvolutionDepthwiseFp16CPUKernel::InitWeightBias() { } int DeconvolutionDepthwiseFp16CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } sliding_ = new SlidingWindowParam; InitSlideParam(); // conv base init - ConvolutionBaseCPUKernel::Init(); + auto ret = ConvolutionBaseCPUKernel::Init(); + if (ret != RET_OK) { + return ret; + } - auto ret = InitWeightBias(); + ret = InitWeightBias(); if (ret != 0) { MS_LOG(ERROR) << "Deconvolution depthwise fp16 InitWeightBias failed."; return RET_ERROR; @@ -150,6 +157,11 @@ int DeconvDwFp16Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int DeconvolutionDepthwiseFp16CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } if (conv_param_->input_channel_ != conv_param_->output_channel_) { MS_LOG(ERROR) << "Only support input channel equals output channel."; return RET_ERROR; @@ -161,7 +173,7 @@ int DeconvolutionDepthwiseFp16CPUKernel::Run() { PackNHWCFp32ToNHWC8Fp16(input_addr, packed_input_, conv_param_->input_batch_, conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_); - auto ret = LiteBackendParallelLaunch(DeconvDwFp16Run, this, conv_param_->thread_num_); + ret = LiteBackendParallelLaunch(DeconvDwFp16Run, this, conv_param_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "DeconvDwFp16Run error: error_code[" << ret << "]"; return RET_ERROR; @@ -176,10 +188,10 @@ int DeconvolutionDepthwiseFp16CPUKernel::Run() { kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); - auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx); + auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h index 0e3a31682c..64807fa9d8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.h @@ -26,8 +26,9 @@ namespace mindspore::kernel { class DeconvolutionDepthwiseFp16CPUKernel : public ConvolutionBaseCPUKernel { public: DeconvolutionDepthwiseFp16CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~DeconvolutionDepthwiseFp16CPUKernel() override { delete sliding_; free(packed_weight_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc 
b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc index e4fda400a4..756fa678fb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc @@ -19,6 +19,7 @@ #include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" #include "include/errorcode.h" +#include "src/ops/ops.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -78,6 +79,11 @@ int ActivationRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int ActivationCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return ret; + } int error_code = LiteBackendParallelLaunch(ActivationRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]"; @@ -89,10 +95,10 @@ int ActivationCPUKernel::Run() { kernel::LiteKernel *CpuActivationFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Activation); - auto *kernel = new (std::nothrow) ActivationCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) ActivationCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h index 2e21629d79..3ecc6e9f62 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.h @@ -25,8 +25,9 @@ namespace mindspore::kernel { class ActivationCPUKernel : public LiteKernel { public: ActivationCPUKernel(OpParameter *param, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(param, inputs, outputs), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { type_ = (reinterpret_cast(param))->type_; alpha_ = (reinterpret_cast(param))->alpha_; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc index 279832aca2..df1b0f0392 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.cc @@ -20,8 +20,8 @@ #include "src/runtime/runtime_api.h" #include "include/errorcode.h" -using mindspore::lite::KernelRegistrar; using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; using mindspore::schema::ActivationGradType_HSWISH; @@ -32,8 +32,8 @@ using mindspore::schema::PrimitiveType_ActivationGrad; namespace mindspore::kernel { int ActivationGradCPUKernel::Init() { - outputs_[0]->set_shape(inputs_[0]->shape()); - return RET_OK; + outputs_[0]->set_shape(inputs_[0]->shape()); + return RET_OK; } int ActivationGradCPUKernel::ReSize() { return RET_OK; } @@ -58,7 +58,7 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { error_code = TanhGrad(yt_addr, input_addr, length, output_addr); } else if (type_ == 
schema::ActivationGradType_HSWISH) { error_code = HSwishGrad(yt_addr, input_addr, length, output_addr); - } else if (type_ == schema::ActivationGradType_HSIGMOID) { + } else if (type_ == schema::ActivationGradType_HSIGMOID) { error_code = HSigmoidGrad(yt_addr, input_addr, length, output_addr); } else { MS_LOG(ERROR) << "Activation type error"; @@ -90,17 +90,17 @@ int ActivationGradCPUKernel::Run() { } kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, + const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_ActivationGrad); - auto *kernel = new (std::nothrow) ActivationGradCPUKernel(opParameter, inputs, outputs); + auto *kernel = new (std::nothrow) ActivationGradCPUKernel(opParameter, inputs, outputs, ctx, primitive); MS_ASSERT(kernel != nullptr); auto ret = kernel->Init(); if (ret != RET_OK) { - MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ - << ", type: " + MS_LOG(ERROR) << "InferShape kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); } return kernel; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h index c3de590123..c57713b9a0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_grad.h @@ -27,8 +27,9 @@ namespace mindspore::kernel { class ActivationGradCPUKernel : public LiteKernel { public: explicit ActivationGradCPUKernel(OpParameter *param, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(param, inputs, outputs) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(param, inputs, outputs, ctx, primitive) { ActivationGradParameter *param_act_grad = reinterpret_cast(param); type_ = param_act_grad->type_; alpha_ = param_act_grad->alpha_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc index 226843c805..7ae87bcd5b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.cc @@ -36,12 +36,9 @@ int AddNLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) { auto kernel = reinterpret_cast(cdata); return kernel->AddNParallelRun(thread_id); } -} +} // namespace -int AddNCPUKernel::Init() { - elements_num_ = inputs_[0]->ElementsNum(); - return RET_OK; -} +int AddNCPUKernel::Init() { return RET_OK; } int AddNCPUKernel::ReSize() { return RET_OK; } @@ -58,6 +55,12 @@ int AddNCPUKernel::AddNParallelRun(int thread_id) { } int AddNCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return ret; + } + elements_num_ = inputs_[0]->ElementsNum(); auto input0_data = reinterpret_cast(inputs_[0]->Data()); auto input1_data = reinterpret_cast(inputs_[1]->Data()); auto output_data = reinterpret_cast(outputs_[0]->Data()); @@ -71,7 +74,7 @@ int AddNCPUKernel::Run() { in1_addr_ = input0_data; in2_addr_ = input1_data; out_addr_ = output_data; - int ret = LiteBackendParallelLaunch(AddNLaunch, this, opParameter->thread_num_); + ret = 
LiteBackendParallelLaunch(AddNLaunch, this, opParameter->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "addn launch fail!ret: " << ret; return RET_ERROR; @@ -91,7 +94,7 @@ int AddNCPUKernel::Run() { kernel::LiteKernel *CpuAddNFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *op_parameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (op_parameter == nullptr) { MS_LOG(ERROR) << "Input op_parameter is nullptr!"; return nullptr; @@ -102,7 +105,7 @@ kernel::LiteKernel *CpuAddNFp32KernelCreator(const std::vectorthread_num_ = ctx->thread_num_; - auto *kernel = new (std::nothrow) AddNCPUKernel(op_parameter, inputs, outputs); + auto *kernel = new (std::nothrow) AddNCPUKernel(op_parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new AddNCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h index 43d27fad02..31a2fecfa0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/addn.h @@ -21,18 +21,20 @@ #include "src/lite_kernel.h" #include "schema/model_generated.h" - namespace mindspore::kernel { class AddNCPUKernel : public LiteKernel { public: AddNCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) : LiteKernel(parameter, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~AddNCPUKernel() = default; int Init() override; int ReSize() override; int Run() override; int AddNParallelRun(int thread_id); + private: float *in1_addr_; float *in2_addr_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc index a89eb715fd..8ba1595ed9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc @@ -40,7 +40,12 @@ int ArgMinMaxCPUKernel::Init() { } int ArgMinMaxCPUKernel::Run() { - auto ret = ArgMinMaxBaseCPUKernel::Run(); + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << ret; + return ret; + } + ret = ArgMinMaxBaseCPUKernel::Run(); ArgMinMaxBaseCPUKernel::FreeTmpMemory(); return ret; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h index c6c4fb9be7..fdc8fbdd3a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h @@ -23,8 +23,9 @@ namespace mindspore::kernel { class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel { public: ArgMinMaxCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : ArgMinMaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : ArgMinMaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~ArgMinMaxCPUKernel() = default; @@ -35,4 +36,3 @@ class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc 
index 3d3d55f894..c8e01d387c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc @@ -41,6 +41,10 @@ ArithmeticCPUKernel::~ArithmeticCPUKernel() { } } int ArithmeticCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto element_num = outputs_[0]->ElementsNum(); tile_data0_ = new float[element_num]; @@ -92,6 +96,11 @@ int ArithmeticsRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int ArithmeticCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return ret; + } if (arithmeticParameter_->broadcasting_) { auto input_data0 = reinterpret_cast(inputs_[0]->Data()); auto input_data1 = reinterpret_cast(inputs_[1]->Data()); @@ -108,9 +117,9 @@ int ArithmeticCPUKernel::Run() { kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(parameter != nullptr); - auto kernel = new (std::nothrow) ArithmeticCPUKernel(parameter, inputs, outputs, ctx); + auto kernel = new (std::nothrow) ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h index 9fae6696c8..a24ee75372 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h @@ -48,8 +48,9 @@ class ArithmeticCPUKernel : public LiteKernel { public: ArithmeticCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { arithmeticParameter_ = reinterpret_cast(parameter); switch (parameter->type_) { case PrimitiveType_Mul: diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc index b3f9075bef..4c3fddbcbf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.cc @@ -261,12 +261,13 @@ int ArithmeticGradCPUKernel::Run() { kernel::LiteKernel *CpuArithmeticGradFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, + const lite::Primitive *primitive) { MS_EXCEPTION_IF_NULL(opParameter); if (opParameter == nullptr) { return nullptr; } - auto *kernel = new (std::nothrow) ArithmeticGradCPUKernel(opParameter, inputs, outputs); + auto *kernel = new (std::nothrow) ArithmeticGradCPUKernel(opParameter, inputs, outputs, ctx, primitive); MS_ASSERT(kernel != nullptr); auto ret = kernel->Init(); if (ret != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h index cea0c0e659..3dcd5811ec 100644 --- 
a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_grad.h @@ -37,8 +37,9 @@ class ArithmeticGradCPUKernel : public LiteKernel { public: explicit ArithmeticGradCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs), tile_data0(NULL), tile_data1(NULL), tile_data2(NULL) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), tile_data0(NULL), tile_data1(NULL), tile_data2(NULL) { switch (type()) { case PrimitiveType_MulGrad: arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul; // this will be adjusted in InferShape diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc index d65ef0d1c9..88c7e6cd0a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc @@ -27,6 +27,10 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { int ArithmeticSelfCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } int ret = ReSize(); return ret; } @@ -68,11 +72,16 @@ int ArithmeticSelfCPUKernel::DoArithmeticSelf(int task_id) { } int ArithmeticSelfCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return ret; + } auto input_tensor = inputs_.at(0); auto out_tensor = outputs_.at(0); in_ptr_ = reinterpret_cast(input_tensor->Data()); out_ptr_ = reinterpret_cast(out_tensor->Data()); - int ret = LiteBackendParallelLaunch(ArithmeticSelfRuns, this, thread_sz_count_); + ret = LiteBackendParallelLaunch(ArithmeticSelfRuns, this, thread_sz_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; return ret; @@ -83,13 +92,14 @@ int ArithmeticSelfCPUKernel::Run() { kernel::LiteKernel *CpuArithmeticSelfFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, + const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); if (opParameter == nullptr) { MS_LOG(ERROR) << "Creator failed, opParameter is nullptr!"; return nullptr; } - auto *kernel = new (std::nothrow) ArithmeticSelfCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) ArithmeticSelfCPUKernel(opParameter, inputs, outputs, ctx, primitive); MS_ASSERT(kernel != nullptr); auto ret = kernel->Init(); if (ret != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h index bcc56820db..d17a97a914 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h @@ -24,9 +24,9 @@ #include "schema/model_generated.h" #include "include/context.h" - using mindspore::lite::Context; using mindspore::schema::PrimitiveType_Abs; +using mindspore::schema::PrimitiveType_Ceil; using mindspore::schema::PrimitiveType_Cos; using mindspore::schema::PrimitiveType_Exp; using mindspore::schema::PrimitiveType_Floor; @@ -36,7 +36,6 @@ using mindspore::schema::PrimitiveType_Rsqrt; using mindspore::schema::PrimitiveType_Sin; using mindspore::schema::PrimitiveType_Sqrt; 
using mindspore::schema::PrimitiveType_Square; -using mindspore::schema::PrimitiveType_Ceil; namespace mindspore::kernel { class ArithmeticSelfCPUKernel : public LiteKernel { @@ -44,8 +43,9 @@ class ArithmeticSelfCPUKernel : public LiteKernel { public: explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { switch (parameter->type_) { case PrimitiveType_Abs: arithmeticSelf_run_ = ElementAbs; @@ -106,4 +106,3 @@ class ArithmeticSelfCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_SELF_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc index a24cae52de..79b86e4d6d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc @@ -28,6 +28,11 @@ int BatchToSpaceCPUKernel::Init() { } int BatchToSpaceCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } auto input = inputs_[0]; auto output = outputs_[0]; const float *input_data = reinterpret_cast(input->Data()); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h index 2933bee478..2ac09c455a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h @@ -22,8 +22,9 @@ namespace mindspore::kernel { class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel { public: BatchToSpaceCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : BatchToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : BatchToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~BatchToSpaceCPUKernel() = default; @@ -34,4 +35,3 @@ class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BATCH_TO_SPACE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc index 063fd4f6aa..a9645f0408 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc @@ -53,6 +53,11 @@ int BatchNormRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int BatchnormCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } in_addr_ = reinterpret_cast(inputs_.at(0)->Data()); mean_addr_ = reinterpret_cast(inputs_.at(1)->Data()); var_addr_ = reinterpret_cast(inputs_.at(2)->Data()); @@ -76,10 +81,10 @@ int BatchnormCPUKernel::Run() { kernel::LiteKernel *CpuBatchnormKernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey 
&desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_BatchNorm); - auto *kernel = new (std::nothrow) BatchnormCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) BatchnormCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new BatchNormCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h index e29aa20a2f..c3532b19ae 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.h @@ -28,8 +28,9 @@ namespace mindspore::kernel { class BatchnormCPUKernel : public LiteKernel { public: BatchnormCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { batchnorm_param_ = reinterpret_cast(parameter); } ~BatchnormCPUKernel() override { delete batchnorm_param_; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc index 9464b342a2..8d3d99d7b9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.cc @@ -31,6 +31,11 @@ namespace mindspore::kernel { int BiasCPUKernel::ReSize() { return RET_OK; } int BiasCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } auto in = reinterpret_cast(inputs_.at(0)->Data()); auto bias = reinterpret_cast(inputs_.at(1)->Data()); auto out = reinterpret_cast(outputs_.at(0)->Data()); @@ -44,6 +49,10 @@ int BiasCPUKernel::Run() { } int BiasCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto dims = inputs_[0]->shape(); MS_ASSERT(dims.size() <= 5); bias_param_->ndim_ = dims.size(); @@ -58,10 +67,11 @@ int BiasCPUKernel::Init() { kernel::LiteKernel *CpuBiasFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, - const lite::Context *ctx, const kernel::KernelKey &desc) { + const lite::Context *ctx, const kernel::KernelKey &desc, + const lite::Primitive *primitive) { MS_ASSERT(parameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_BiasAdd); - auto kernel = new (std::nothrow) BiasCPUKernel(parameter, inputs, outputs); + auto kernel = new (std::nothrow) BiasCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h index a4d88378fd..0282c668cb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias.h @@ -24,9 +24,10 @@ namespace mindspore::kernel { class BiasCPUKernel : public LiteKernel { public: BiasCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) { - bias_param_ = 
reinterpret_cast(parameter); + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + bias_param_ = reinterpret_cast(parameter); } ~BiasCPUKernel() override = default; @@ -40,4 +41,3 @@ class BiasCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BIAS_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc index e57fe298ab..0bc583343a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.cc @@ -20,12 +20,11 @@ #include "src/kernel_registry.h" #include "include/errorcode.h" - using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; -using mindspore::schema::PrimitiveType_BiasGrad; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_BiasGrad; namespace mindspore::kernel { int BiasGradCPUKernel::InferShape() { @@ -68,10 +67,14 @@ int BiasGradCPUKernel::Init() { return RET_OK; } - int BiasGradCPUKernel::ReSize() { return 0; } int BiasGradCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto in = reinterpret_cast(inputs_.at(0)->Data()); auto out = reinterpret_cast(outputs_.at(0)->Data()); // size_t data_size = inputs_.at(0)->ElementsNum(); @@ -91,14 +94,14 @@ int BiasGradCPUKernel::Run() { return RET_OK; } - kernel::LiteKernel *CpuBiasGradFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_BiasGrad); - auto *kernel = new (std::nothrow) BiasGradCPUKernel(reinterpret_cast(opParameter), inputs, outputs); + auto *kernel = + new (std::nothrow) BiasGradCPUKernel(reinterpret_cast(opParameter), inputs, outputs, ctx, primitive); MS_ASSERT(kernel != nullptr); auto ret = kernel->Init(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h index 797abfd162..ed652ab617 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bias_grad.h @@ -27,8 +27,9 @@ namespace mindspore::kernel { class BiasGradCPUKernel : public LiteKernel { public: explicit BiasGradCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { bias_param = reinterpret_cast(parameter); } ~BiasGradCPUKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc index 2a07167058..79bd775e70 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.cc @@ -96,12 +96,12 @@ int BNGradInputCPUKernel::Run() { kernel::LiteKernel *CpuBNGradInputFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + 
const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_BNGradInput); // parameter->name = opDef.name()->str().data(); // parameter->type = opDef.attr_type(); - auto *kernel = new (std::nothrow) BNGradInputCPUKernel(opParameter, inputs, outputs); + auto *kernel = new (std::nothrow) BNGradInputCPUKernel(opParameter, inputs, outputs, ctx, primitive); MS_ASSERT(kernel != nullptr); auto ret = kernel->Init(); if (RET_OK != ret) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h index e4e6d6e746..182257d5a7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/bngrad_input.h @@ -25,8 +25,9 @@ namespace mindspore::kernel { class BNGradInputCPUKernel : public LiteKernel { public: explicit BNGradInputCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~BNGradInputCPUKernel() override { delete workspace; } int Init() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc index ed5ffd9822..84e4366e3d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.cc @@ -27,6 +27,10 @@ using mindspore::schema::PrimitiveType_BroadcastTo; namespace mindspore::kernel { int BroadcastToCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto input_shape = inputs_[0]->shape(); for (size_t i = 0; i < input_shape.size(); ++i) { shape_info_.input_shape_[i] = input_shape[i]; @@ -42,6 +46,11 @@ int BroadcastToCPUKernel::Init() { } int BroadcastToCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; + return prepare_ret; + } auto input_data = reinterpret_cast(inputs_.at(0)->Data()); auto output_data = reinterpret_cast(outputs_.at(0)->Data()); @@ -51,13 +60,13 @@ int BroadcastToCPUKernel::Run() { kernel::LiteKernel *CpuBroadcastToFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *op_parameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (op_parameter == nullptr) { MS_LOG(ERROR) << "Input op_parameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_BroadcastTo); - auto *kernel = new (std::nothrow) BroadcastToCPUKernel(op_parameter, inputs, outputs); + auto *kernel = new (std::nothrow) BroadcastToCPUKernel(op_parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new BroadcastToCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h index cfb8969448..c0cefc522b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h @@ -25,18 +25,18 @@ namespace mindspore::kernel { class BroadcastToCPUKernel : public LiteKernel { public: BroadcastToCPUKernel(OpParameter *parameter, const 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h
index cfb8969448..c0cefc522b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/broadcast_to.h
@@ -25,18 +25,18 @@ namespace mindspore::kernel {
 class BroadcastToCPUKernel : public LiteKernel {
  public:
   BroadcastToCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                       const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                       const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~BroadcastToCPUKernel() = default;
 
   int Init() override;
-  int ReSize() override {
-    return 0;
-  }
+  int ReSize() override { return 0; }
   int Run() override;
+
  private:
   BroadcastShapeInfo shape_info_;
 };
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_BROADCAST_TO_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc
index b17bc4bb7a..efcf7f04fc 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc
@@ -30,9 +30,6 @@ using mindspore::schema::PrimitiveType_Cast;
 
 namespace mindspore::kernel {
 namespace {
-constexpr int kInputNum = 1;
-constexpr int kOutputNum = 1;
-const std::vector<TypeId> kSupportInputDataType = {kNumberTypeUInt8, kNumberTypeInt32};
 int CastRun(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
   if (cdata == nullptr) {
     MS_LOG(ERROR) << "input cdata is nullptr!";
@@ -44,12 +41,16 @@ int CastRun(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
 }  // namespace
 
 int CastCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   data_num_ = inputs_[0]->ElementsNum();
   if (data_num_ == 0) {
     return RET_OK;
   }
-  thread_num_ = MSMIN(thread_num_, data_num_);
-  stride_ = UP_DIV(data_num_, thread_num_);
+  opParameter->thread_num_ = MSMIN(opParameter->thread_num_, data_num_);
+  stride_ = UP_DIV(data_num_, opParameter->thread_num_);
   return RET_OK;
 }
 
@@ -77,16 +78,21 @@ int CastCPUKernel::DoCast(int thread_id) {
 }
 
 int CastCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   if (data_num_ == 0) {
     return RET_OK;
   }
-  return LiteBackendParallelLaunch(CastRun, this, thread_num_);
+  return LiteBackendParallelLaunch(CastRun, this, opParameter->thread_num_);
 }
 
 kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const lite::Context *ctx,
-                                             const kernel::KernelKey &desc) {
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
@@ -99,7 +105,7 @@ kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-  auto *kernel = new (std::nothrow) CastCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) CastCPUKernel(opParameter, inputs, outputs, ctx, primitive);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.h
@@ -24,21 +24,18 @@ namespace mindspore::kernel {
 class CastCPUKernel : public LiteKernel {
  public:
   CastCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs) {
-    if (ctx != nullptr) {
-      thread_num_ = ctx->thread_num_;
-    }
+                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    opParameter->thread_num_ = ctx->thread_num_;
   }
   ~CastCPUKernel() = default;
 
   int Init() override;
-  int ReSize() override {
-    return 0;
-  };
+  int ReSize() override { return 0; };
   int Run() override;
   int DoCast(int thread_id);
+
  private:
-  uint32_t thread_num_;
   uint32_t stride_;
   uint32_t data_num_;
 };
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CAST_H_
-
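The cast.cc change also moves the thread count off the kernel object and onto OpParameter: opParameter->thread_num_ is clamped to the element count with MSMIN, and the per-task stride comes from UP_DIV, so no task receives an empty range. A sketch of that partitioning arithmetic with the macros spelled out as functions (names illustrative):

#include <algorithm>
#include <cstdio>

// UP_DIV(a, b): ceiling division, used for the per-thread stride.
static int UpDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  int data_num = 10;    // elements to cast
  int thread_num = 4;   // ctx->thread_num_, stored on the parameter

  thread_num = std::min(thread_num, data_num);  // MSMIN: no idle threads
  int stride = UpDiv(data_num, thread_num);     // elements per task

  // Each task id handles [id * stride, min((id + 1) * stride, data_num)).
  for (int id = 0; id < thread_num; ++id) {
    int begin = id * stride;
    int end = std::min(begin + stride, data_num);
    std::printf("task %d: [%d, %d)\n", id, begin, end);
  }
  return 0;
}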
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc
index df3ba83c5c..2ba7ca8fe8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.cc
@@ -28,44 +28,54 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Concat;
 
 namespace mindspore::kernel {
-  int ConcatCPUKernel::Init() {
-    ConcatBaseCPUKernel::Init();
-    schema::Format input0_format = inputs_[0]->GetFormat();
-    bool need_convert_format = false;
-    for (size_t i = 1; i < inputs_.size(); ++i) {
-      if (inputs_[i]->GetFormat() != input0_format) {
-        need_convert_format = true;
-      }
-    }
-    if (!need_convert_format) {
-      outputs_[0]->SetFormat(input0_format);
-      return RET_OK;
-    }
-    MS_LOG(ERROR) << "All input format should be the same!";
-    return RET_ERROR;
+int ConcatCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
+  auto ret = ConcatBaseCPUKernel::Init();
+  if (ret != RET_OK) {
+    return ret;
+  }
+  schema::Format input0_format = inputs_[0]->GetFormat();
+  bool need_convert_format = false;
+  for (size_t i = 1; i < inputs_.size(); ++i) {
+    if (inputs_[i]->GetFormat() != input0_format) {
+      need_convert_format = true;
     }
+  }
+  if (!need_convert_format) {
+    outputs_[0]->SetFormat(input0_format);
+    return RET_OK;
+  }
+  MS_LOG(ERROR) << "All input format should be the same!";
+  return RET_ERROR;
+}
 
-  int ConcatCPUKernel::ReSize() { return RET_OK; }
+int ConcatCPUKernel::ReSize() { return RET_OK; }
 
-  int ConcatCPUKernel::Run() {
-    auto input_num = inputs_.size();
-    std::vector<void *> inputs_addr(input_num, nullptr);
-    std::vector<int *> inputs_output_shape(input_num + 1, nullptr);
+int ConcatCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
+  auto input_num = inputs_.size();
+  std::vector<void *> inputs_addr(input_num, nullptr);
+  std::vector<int *> inputs_output_shape(input_num + 1, nullptr);
 
-    std::vector<std::vector<int> > shapes;
-    for (size_t i = 0; i < input_num; ++i) {
-      inputs_addr[i] = inputs_[i]->Data();
-      shapes.push_back(inputs_[i]->shape());
-      inputs_output_shape[i] = shapes[i].data();
-    }
-    auto output_shape = outputs_.at(0)->shape();
-    inputs_output_shape[input_num] = output_shape.data();
-    auto output_addr = outputs_.at(0)->Data();
+  std::vector<std::vector<int>> shapes;
+  for (size_t i = 0; i < input_num; ++i) {
+    inputs_addr[i] = inputs_[i]->Data();
+    shapes.push_back(inputs_[i]->shape());
+    inputs_output_shape[i] = shapes[i].data();
+  }
+  auto output_shape = outputs_.at(0)->shape();
+  inputs_output_shape[input_num] = output_shape.data();
+  auto output_addr = outputs_.at(0)->Data();
 
-    Concat(reinterpret_cast<void **>(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(),
-           output_shape.size(), output_addr);
-    return RET_OK;
-  }
+  Concat(reinterpret_cast<void **>(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(),
+         output_shape.size(), output_addr);
+  return RET_OK;
+}
 }  // namespace mindspore::kernel
-
-
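ConcatCPUKernel::Run() flattens the C++ tensor list into raw arrays for the C backend: one void * per input, one int * per shape, with the output shape appended as the final entry. A toy axis-0 concat that performs the same packing (the real Concat handles arbitrary axes; this sketch assumes 2-D float inputs):

#include <cstring>
#include <iostream>
#include <vector>

// Toy axis-0 concat over packed pointers, mimicking the packing in Run().
void ConcatAxis0(void **inputs, int input_num, int **shapes, float *output) {
  size_t offset = 0;
  for (int i = 0; i < input_num; ++i) {
    size_t count = static_cast<size_t>(shapes[i][0]) * shapes[i][1];
    std::memcpy(output + offset, inputs[i], count * sizeof(float));
    offset += count;
  }
}

int main() {
  std::vector<float> a = {1, 2, 3, 4}, b = {5, 6, 7, 8};          // two 2x2 inputs
  std::vector<std::vector<int>> shapes = {{2, 2}, {2, 2}, {4, 2}};  // output shape appended last

  std::vector<void *> inputs_addr = {a.data(), b.data()};
  std::vector<int *> inputs_output_shape = {shapes[0].data(), shapes[1].data(), shapes[2].data()};

  std::vector<float> out(8);
  ConcatAxis0(inputs_addr.data(), 2, inputs_output_shape.data(), out.data());

  for (float v : out) std::cout << v << ' ';  // 1 2 3 4 5 6 7 8
  std::cout << '\n';
  return 0;
}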
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h
index 078921a53d..cafd6c84f7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/concat.h
@@ -28,8 +28,9 @@ namespace mindspore::kernel {
 class ConcatCPUKernel : public ConcatBaseCPUKernel {
  public:
   ConcatCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                  const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                  const lite::Primitive *primitive)
+      : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
 
   ~ConcatCPUKernel() = default;
 
@@ -42,4 +43,3 @@ class ConcatCPUKernel : public ConcatBaseCPUKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONCAT_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
index f30d2af194..1189d40da3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc
@@ -29,6 +29,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
+using mindspore::lite::RET_INFER_INVALID;
 using mindspore::schema::PrimitiveType_Conv2D;
 
 namespace mindspore::kernel {
@@ -136,6 +137,10 @@ void ConvolutionCPUKernel::ConfigInputOutput() {
 }
 
 int ConvolutionCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto ret = ConvolutionBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvolutionBase init failed.";
@@ -204,6 +209,11 @@ int ConvolutionImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ConvolutionCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input_tensor = inputs_.at(kInputIndex);
   auto ori_input_data = input_tensor->Data();
   int in_batch = conv_param_->input_batch_;
@@ -223,7 +233,7 @@ int ConvolutionCPUKernel::Run() {
 kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const Context *ctx,
-                                             const kernel::KernelKey &desc) {
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
@@ -245,20 +255,21 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
   auto ret = kernel->Init();
-  if (ret != RET_OK) {
+  if (ret != RET_OK && ret != RET_INFER_INVALID) {
     delete kernel;
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h
index 9e79609280..20642484c9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.h
@@ -27,8 +27,9 @@ namespace mindspore::kernel {
 class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
  public:
   ConvolutionCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                       const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ConvolutionCPUKernel() override {
     if (packed_input_ != nullptr) {
       free(packed_input_);
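The creator half of the protocol shows up in convolution.cc: factories now take the lite::Primitive and must treat RET_INFER_INVALID from Init() as deferred initialization rather than failure, so the kernel object survives until Prepare() can finish the job. A hedged sketch of that acceptance logic; the error-code values and helper names here are placeholders, not the ones in include/errorcode.h:

#include <iostream>
#include <memory>

// Illustrative error codes; the real values live in include/errorcode.h.
constexpr int kRetOk = 0;
constexpr int kRetInferInvalid = -4;  // placeholder value

struct Kernel {
  bool shapes_known;
  int Init() { return shapes_known ? kRetOk : kRetInferInvalid; }
};

// Factory: only a hard error kills the kernel; a deferred infer does not.
std::unique_ptr<Kernel> CreateKernel(bool shapes_known) {
  auto kernel = std::make_unique<Kernel>(Kernel{shapes_known});
  int ret = kernel->Init();
  if (ret != kRetOk && ret != kRetInferInvalid) {
    std::cerr << "Init kernel failed\n";
    return nullptr;  // delete kernel and bail, as in the patch
  }
  return kernel;  // kept alive even when Init() was deferred
}

int main() {
  auto deferred = CreateKernel(false);
  std::cout << (deferred ? "kept deferred kernel\n" : "dropped\n");
  return 0;
}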
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
index e2456bfdf4..beda5ba884 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.cc
@@ -136,6 +136,10 @@ void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) {
 }
 
 int Convolution1x1CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   ConvolutionBaseCPUKernel::Init();
   InitConv1x1MatmulParam();
 
@@ -178,6 +182,11 @@ int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int Convolution1x1CPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto src_in = reinterpret_cast<float *>(inputs_[0]->Data());
   auto src_out = reinterpret_cast<float *>(outputs_[0]->Data());
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h
index 6d5840e017..1753821147 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1.h
@@ -34,8 +34,9 @@ namespace mindspore::kernel {
 class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
  public:
   Convolution1x1CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                          const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {
+                          const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                          const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {
     matmul_param_ = new MatMulParameter();
   }
   ~Convolution1x1CPUKernel();
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc
index aa1f363010..9bd829bb67 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.cc
@@ -166,6 +166,10 @@ void Convolution3x3CPUKernel::ConfigInputOutput() {
 }
 
 int Convolution3x3CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto ret = ConvolutionBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvolutionBase init failed.";
@@ -237,6 +241,11 @@ int Convolution3x3Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int Convolution3x3CPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input_tensor = inputs_.at(kInputIndex);
   auto ori_input_data = input_tensor->Data();
   int in_batch = conv_param_->input_batch_;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h
index 90dbac5a6f..f0aa53ca79 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_3x3.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel {
  public:
   Convolution3x3CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                          const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                          const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                          const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~Convolution3x3CPUKernel() override {
     if (transformed_filter_addr_ != nullptr) {
       free(transformed_filter_addr_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc
index 6eec3f91aa..01901d01fe 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc
@@ -25,6 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
+using mindspore::lite::RET_INFER_INVALID;
 using mindspore::schema::PrimitiveType_DepthwiseConv2D;
 
 namespace mindspore::kernel {
@@ -86,6 +87,10 @@ int ConvolutionDepthwiseCPUKernel::InitBuffer() {
 }
 
 int ConvolutionDepthwiseCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   // conv base init
   ConvolutionBaseCPUKernel::Init();
 
@@ -144,6 +149,11 @@ int ConvDwRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ConvolutionDepthwiseCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
   if (conv_param_->input_channel_ != conv_param_->output_channel_) {
     MS_LOG(ERROR) << "Only support input channel equals output channel.";
     return RET_ERROR;
@@ -164,7 +174,7 @@ int ConvolutionDepthwiseCPUKernel::Run() {
     packed_output_ = output_addr;
   }
 
-  auto ret = LiteBackendParallelLaunch(ConvDwRun, this, conv_param_->thread_num_);
+  ret = LiteBackendParallelLaunch(ConvDwRun, this, conv_param_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvDwRun error: error_code[" << ret << "]";
     return RET_ERROR;
@@ -180,11 +190,11 @@ int ConvolutionDepthwiseCPUKernel::Run() {
 kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
   kernel::LiteKernel *kernel;
-  kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
+  kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   // auto param = reinterpret_cast<ConvParameter *>(opParameter);
   // if (param->kernel_h_ == 3 && param->kernel_w_ == 3 && param->stride_h_ == 1 && param->stride_w_ == 1 &&
   //   param->dilation_h_ == 1 && param->dilation_w_ == 1) {
@@ -192,12 +202,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
   auto ret = kernel->Init();
-  if (ret != RET_OK) {
+  if (ret != RET_OK && ret != RET_INFER_INVALID) {
     delete kernel;
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_
                   << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h
index e0d742c6ae..08706ac050 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel {
  public:
   ConvolutionDepthwiseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                                const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ConvolutionDepthwiseCPUKernel() override {
     delete sliding_;
     free(packed_weight_);
@@ -55,4 +56,3 @@ class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc
index 6f2322255b..798cd72a04 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.cc
@@ -100,6 +100,10 @@ int ConvolutionDepthwise3x3CPUKernel::InitBuffer() {
 }
 
 int ConvolutionDepthwise3x3CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   // conv base init
   ConvolutionBaseCPUKernel::Init();
 
@@ -164,6 +168,11 @@ int ConvDw3x3Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ConvolutionDepthwise3x3CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
   if (conv_param_->input_channel_ != conv_param_->output_channel_) {
     MS_LOG(ERROR) << "Only support input channel equals output channel.";
     return RET_ERROR;
@@ -184,7 +193,7 @@ int ConvolutionDepthwise3x3CPUKernel::Run() {
     packed_output_ = output_addr;
   }
 
-  auto ret = LiteBackendParallelLaunch(ConvDw3x3Run, this, conv_param_->thread_num_);
+  ret = LiteBackendParallelLaunch(ConvDw3x3Run, this, conv_param_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]";
     return RET_ERROR;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h
index 63f3d35cd2..ee937456da 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
  public:
   ConvolutionDepthwise3x3CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                   const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                                   const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                                   const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ConvolutionDepthwise3x3CPUKernel() override {
     free(packed_weight_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc
index 3deb6a2017..20e224a129 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.cc
@@ -135,11 +135,12 @@ int ConvolutionGradFilterCPUKernel::Run() {
 kernel::LiteKernel *CpuConvGradFilterFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                        const std::vector<lite::tensor::Tensor *> &outputs,
                                                        OpParameter *opParameter, const lite::Context *ctx,
-                                                       const kernel::KernelKey &desc) {
+                                                       const kernel::KernelKey &desc,
+                                                       const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradFilter);
-  auto *kernel = new (std::nothrow) ConvolutionGradFilterCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) ConvolutionGradFilterCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   MS_ASSERT(kernel != nullptr);
   auto ret = kernel->Init();
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h
index c32a798eaf..20ce826c02 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_filter.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class ConvolutionGradFilterCPUKernel : public LiteKernel {
  public:
   explicit ConvolutionGradFilterCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                          const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                                          const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                                          const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ConvolutionGradFilterCPUKernel() override { delete workspace; }
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc
index 6e0683b301..bd9248137b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.cc
@@ -23,9 +23,9 @@
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
-using mindspore::schema::PrimitiveType_Conv2DGradInput;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Conv2DGradInput;
 
 namespace mindspore::kernel {
 int ConvolutionGradInputCPUKernel::Init() {
@@ -115,11 +115,11 @@ int ConvolutionGradInputCPUKernel::Run() {
 kernel::LiteKernel *CpuConvGradInputFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                       const std::vector<lite::tensor::Tensor *> &outputs,
                                                       OpParameter *opParameter, const lite::Context *ctx,
-                                                      const kernel::KernelKey &desc) {
+                                                      const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Conv2DGradInput);
-  auto *kernel = new (std::nothrow) ConvolutionGradInputCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) ConvolutionGradInputCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   MS_ASSERT(kernel != nullptr);
   auto ret = kernel->Init();
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h
index 86901b37ba..c1297fef77 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_grad_input.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class ConvolutionGradInputCPUKernel : public LiteKernel {
  public:
   explicit ConvolutionGradInputCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                         const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                                         const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                                         const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ConvolutionGradInputCPUKernel() override { delete workspace; }
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc
index 3204ee8e05..95bbb960e1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc
@@ -247,6 +247,10 @@ int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
 }
 
 int ConvolutionWinogradCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto ret = ConvolutionBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ConvolutionBase init failed.";
@@ -339,6 +343,11 @@ int ConvolutionWinogradImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ConvolutionWinogradCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input_tensor = inputs_.at(kInputIndex);
   auto ori_input_data = input_tensor->Data();
   int in_batch = conv_param_->input_batch_;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h
index d11cc8ae4b..04261d1acb 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h
@@ -28,8 +28,9 @@ namespace mindspore::kernel {
 class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
  public:
   ConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                               const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx, int output_unit)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx), output_unit_(output_unit) {}
+                               const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                               const lite::Primitive *primitive, int output_unit)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), output_unit_(output_unit) {}
   ~ConvolutionWinogradCPUKernel() override {
     if (tmp_data_ != nullptr) {
       free(tmp_data_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc
index 5282c6cada..6b529637e2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc
@@ -40,15 +40,7 @@ int CropLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 }  // namespace
 
-int CropCPUKernel::Init() {
-  schema::Format input0_format = inputs_[0]->GetFormat();
-  if (input0_format != schema::Format_NCHW && input0_format != schema::Format_NHWC) {
-    MS_LOG(ERROR) << "Unsupport format " << input0_format;
-    return RET_FORMAT_ERR;
-  }
-  outputs_[0]->SetFormat(input0_format);
-  return RET_OK;
-}
+int CropCPUKernel::Init() { return RET_OK; }
 
 int CropCPUKernel::CropParallelRun(int thread_id) {
   auto input = inputs_[0];
@@ -61,6 +53,11 @@ int CropCPUKernel::CropParallelRun(int thread_id) {
 }
 
 int CropCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input = inputs_[0];
   auto output = outputs_[0];
   auto param = reinterpret_cast<CropParameter *>(opParameter);
@@ -71,7 +68,7 @@ int CropCPUKernel::Run() {
     return RET_OK;
   }
 
-  int ret = LiteBackendParallelLaunch(CropLaunch, this, param->op_parameter_.thread_num_);
+  auto ret = LiteBackendParallelLaunch(CropLaunch, this, param->op_parameter_.thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Crop launch fail!ret: " << ret;
     return RET_ERROR;
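crop.cc can drop its NCHW/NHWC check because format validation now lives in the shared base kernel (crop_base.cc in this patch's diffstat), where the fp32 and int8 variants both inherit it. A sketch of that division of labor, with an illustrative format enum:

#include <iostream>

enum Format { kNCHW, kNHWC, kOther };

// Base-class init: validate once, propagate the input format to the output.
struct CropBaseSketch {
  Format in_format = kNHWC;
  Format out_format = kOther;
  int Init() {
    if (in_format != kNCHW && in_format != kNHWC) {
      std::cerr << "Unsupport format " << in_format << '\n';
      return -1;  // RET_FORMAT_ERR in the real code
    }
    out_format = in_format;  // output inherits the input layout
    return 0;
  }
};

// Derived fp32 kernel no longer repeats the check.
struct CropFp32Sketch : CropBaseSketch {
  int Init() { return 0; }  // matches the emptied CropCPUKernel::Init()
};

int main() {
  CropFp32Sketch crop;
  if (crop.CropBaseSketch::Init() != 0) return 1;  // base runs the check
  return crop.Init();
}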
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h
index f9656b2355..8165c84551 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.h
@@ -24,8 +24,9 @@ namespace mindspore::kernel {
 class CropCPUKernel : public CropBaseCPUKernel {
  public:
   CropCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : CropBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                const lite::Primitive *primitive)
+      : CropBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~CropCPUKernel() = default;
   int Init() override;
   int ReSize() override { return 0; }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
index be571478fe..a9e6b15d55 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc
@@ -173,6 +173,10 @@ int DeConvolutionCPUKernel::DoPostFunc(int task_id) {
 }
 
 int DeConvolutionCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   ConvolutionBaseCPUKernel::Init();
 
   int error_code = InitParam();
@@ -190,6 +194,11 @@ int DeConvolutionCPUKernel::Init() {
 }
 
 int DeConvolutionCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   float *src_in = reinterpret_cast<float *>(inputs_[0]->Data());
   float *src_out = reinterpret_cast<float *>(outputs_[0]->Data());
 
@@ -214,14 +223,13 @@ int DeConvolutionCPUKernel::Run() {
   return RET_OK;
 }
 
-
 kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
-  auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
+  auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h
index 165ebf5d28..fb29ff9e5e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.h
@@ -31,8 +31,9 @@ namespace mindspore::kernel {
 class DeConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
  public:
   DeConvolutionCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                         const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {
+                         const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                         const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {
     matmul_param_ = new MatMulParameter();
   }
   ~DeConvolutionCPUKernel() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc
index 07e5747847..479d2b6659 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.cc
@@ -102,6 +102,10 @@ int DeconvolutionDepthwiseCPUKernel::InitBuffer() {
 }
 
 int DeconvolutionDepthwiseCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   InitSlideParam();
   // conv base init
   ConvolutionBaseCPUKernel::Init();
@@ -155,6 +159,11 @@ int DeconvDwRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int DeconvolutionDepthwiseCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   if (conv_param_->input_channel_ != conv_param_->output_channel_) {
     MS_LOG(ERROR) << "Only support input channel equals output channel.";
     return RET_ERROR;
@@ -192,10 +201,11 @@ int DeconvolutionDepthwiseCPUKernel::Run() {
 kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                 const std::vector<lite::tensor::Tensor *> &outputs,
                                                 OpParameter *opParameter, const lite::Context *ctx,
-                                                const kernel::KernelKey &desc) {
+                                                const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
-  auto kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
+  auto kernel =
+    new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h
index f2ed4b5d95..0ad3c18d44 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_depthwise.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class DeconvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel {
  public:
   DeconvolutionDepthwiseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                  const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                                  const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~DeconvolutionDepthwiseCPUKernel() override {
     delete sliding_;
     free(packed_weight_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc
index c28a9cd4cd..1d29781a3c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc
@@ -22,10 +22,10 @@
 #include "include/errorcode.h"
 
 using mindspore::lite::KernelRegistrar;
-using mindspore::lite::RET_FORMAT_ERR;
 using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_PARAM_INVALID;
+using mindspore::lite::RET_FORMAT_ERR;
 using mindspore::lite::RET_OK;
+using mindspore::lite::RET_PARAM_INVALID;
 using mindspore::schema::PrimitiveType_DepthToSpace;
 
 namespace mindspore::kernel {
@@ -41,6 +41,11 @@ int DepthToSpaceCPUKernel::Init() {
 }
 
 int DepthToSpaceCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input = inputs_[0];
   auto output = outputs_[0];
   const float *input_data = reinterpret_cast<const float *>(input->Data());
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h
index e1676ccb8e..d4f273ad32 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h
@@ -23,17 +23,15 @@ namespace mindspore::kernel {
 class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel {
  public:
   DepthToSpaceCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                        const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                        const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                        const lite::Primitive *primitive)
+      : DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~DepthToSpaceCPUKernel() = default;
 
   int Init() override;
-  int ReSize() override {
-    return 0;
-  }
+  int ReSize() override { return 0; }
   int Run() override;
 };
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
index 0904f90b5d..fe60dda8a2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.cc
@@ -26,6 +26,10 @@ using mindspore::schema::PrimitiveType_EmbeddingLookup;
 
 namespace mindspore::kernel {
 int EmbeddingLookupCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   embedding_lookup_parameter_ = reinterpret_cast<EmbeddingLookupParameter *>(opParameter);
   embedding_lookup_parameter_->thread_num = thread_count_;
   embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum();
@@ -84,6 +88,11 @@ int EmbeddingLookupRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int EmbeddingLookupCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   int dest_loc = 0;
   for (int i = 0; i < inputs_.size() - 1; i++) {
     auto input_t = reinterpret_cast<float *>(inputs_.at(i)->Data());
@@ -104,13 +113,13 @@ int EmbeddingLookupCPUKernel::Run() {
 kernel::LiteKernel *CpuEmbeddingLookupFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                         const std::vector<lite::tensor::Tensor *> &outputs,
                                                         OpParameter *parameter, const lite::Context *ctx,
-                                                        const KernelKey &desc) {
+                                                        const KernelKey &desc, const lite::Primitive *primitive) {
   if (parameter == nullptr || ctx == nullptr) {
     MS_LOG(ERROR) << "parameter or ctx is nullptr";
     return nullptr;
   }
   MS_ASSERT(desc.type == PrimitiveType_EmbeddingLookup);
-  auto *kernel = new (std::nothrow) EmbeddingLookupCPUKernel(parameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) EmbeddingLookupCPUKernel(parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Create Kernel failed, name: " << parameter->name_;
     return nullptr;
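EmbeddingLookupCPUKernel::Run() first packs every weight tensor into one flat table (the dest_loc loop above), then gathers rows by the ids held in the last input tensor. The gather step, reduced to its core:

#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // One embedding table: 4 rows of width 3 (built from the weight inputs).
  std::vector<float> table = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3};
  const int width = 3;

  std::vector<int> ids = {2, 0, 3};  // the last input tensor holds the ids
  std::vector<float> out(ids.size() * width);

  // Gather: copy one row of the table per id.
  for (size_t i = 0; i < ids.size(); ++i) {
    std::memcpy(&out[i * width], &table[ids[i] * width], width * sizeof(float));
  }

  for (float v : out) std::cout << v << ' ';  // 2 2 2 0 0 0 3 3 3
  std::cout << '\n';
  return 0;
}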
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
index 6afa0d5620..fd9defd03b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/embedding_lookup.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class EmbeddingLookupCPUKernel : public LiteKernel {
  public:
   explicit EmbeddingLookupCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                    const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {}
+                                    const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                                    const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
   ~EmbeddingLookupCPUKernel() override{};
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc
index 8f37cb30bb..98bf98d2ca 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.cc
@@ -29,6 +29,10 @@ using mindspore::schema::PrimitiveType_ExpandDims;
 
 namespace mindspore::kernel {
 int ExpandDimsCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   int ret = ReSize();
   return ret;
 }
@@ -65,9 +69,14 @@ int ExpandDimsRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ExpandDimsCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   in_ptr_ = reinterpret_cast<float *>(inputs_.at(0)->Data());
   out_ptr_ = reinterpret_cast<float *>(outputs_.at(0)->Data());
-  int ret = LiteBackendParallelLaunch(ExpandDimsRun, this, thread_sz_count_);
+  auto ret = LiteBackendParallelLaunch(ExpandDimsRun, this, thread_sz_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ExpandDimsRun error error_code[" << ret << "]";
     return ret;
@@ -78,10 +87,10 @@ int ExpandDimsCPUKernel::Run() {
 kernel::LiteKernel *CpuExpandsDimsFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                     const std::vector<lite::tensor::Tensor *> &outputs,
                                                     OpParameter *opParameter, const lite::Context *ctx,
-                                                    const kernel::KernelKey &desc) {
+                                                    const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_ExpandDims);
-  auto *kernel = new (std::nothrow) ExpandDimsCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ExpandDimsCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ExpandDimsCPUKernel fail!";
     return nullptr;
@@ -98,4 +107,3 @@ kernel::LiteKernel *CpuExpandsDimsFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
 }  // namespace mindspore::kernel
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/expandDims.h
@@ -27,8 +27,9 @@ namespace mindspore::kernel {
 class ExpandDimsCPUKernel : public LiteKernel {
  public:
   ExpandDimsCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {}
+                      const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                      const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
   ~ExpandDimsCPUKernel() override = default;
 
   int Init() override;
@@ -51,4 +52,3 @@ class ExpandDimsCPUKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_EXPANDDIMS_H_
-
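ExpandDims, Fill and GatherNd all dispatch through LiteBackendParallelLaunch(fn, this, count): a C-style callback gets the task id plus the kernel as void *cdata and forwards into a member function. A sketch of that dispatch shape, with std::thread standing in for the runtime's thread pool:

#include <iostream>
#include <thread>
#include <vector>

struct ExpandDimsSketch {
  int DoWork(int task_id) {
    std::cout << "task " << task_id << " copies its slice\n";
    return 0;
  }
};

// The callback shape used by the runtime: (task_id, env, cdata) -> status.
int ExpandDimsRunSketch(int task_id, void * /*penv*/, void *cdata) {
  auto *kernel = static_cast<ExpandDimsSketch *>(cdata);
  return kernel->DoWork(task_id);
}

// Stand-in for LiteBackendParallelLaunch: run tasks 0..n-1 on threads.
int ParallelLaunch(int (*fn)(int, void *, void *), void *cdata, int n) {
  std::vector<std::thread> pool;
  for (int id = 0; id < n; ++id) {
    pool.emplace_back([=] { fn(id, nullptr, cdata); });
  }
  for (auto &t : pool) t.join();
  return 0;
}

int main() {
  ExpandDimsSketch kernel;
  return ParallelLaunch(ExpandDimsRunSketch, &kernel, 3);
}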
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc
index 96a583decc..32866a54f2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.cc
@@ -35,6 +35,10 @@ constexpr int kOutputNum = 1;
 }  // namespace
 
 int FillCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   data_size_ = outputs_.front()->ElementsNum();
   thread_sz_count_ = MSMIN(thread_count_, data_size_);
   thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
@@ -68,12 +72,17 @@ int FillRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int FillCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto fillData = inputs_.at(inputs_.size() - 1);
   auto output = outputs_.front();
   auto fill_data = reinterpret_cast<float *>(fillData->Data());
   src_data_ = fill_data[0];
   out_ptr_ = reinterpret_cast<float *>(output->Data());
-  int ret = LiteBackendParallelLaunch(FillRun, this, thread_sz_count_);
+  auto ret = LiteBackendParallelLaunch(FillRun, this, thread_sz_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]";
     return ret;
@@ -84,14 +93,14 @@ int FillCPUKernel::Run() {
 kernel::LiteKernel *CpuFillFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const lite::Context *ctx,
-                                             const kernel::KernelKey &desc) {
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_Fill. ";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Fill);
-  auto *kernel = new (std::nothrow) FillCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) FillCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new FillCPUKernel fail!";
     return nullptr;
@@ -108,4 +117,3 @@ kernel::LiteKernel *CpuFillFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
 }  // namespace mindspore::kernel
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class FillCPUKernel : public LiteKernel {
  public:
   FillCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {}
+                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
   ~FillCPUKernel() override = default;
 
   int Init() override;
@@ -49,4 +50,3 @@ class FillCPUKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FILL_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc
index 3adddbf91f..868705746b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.cc
@@ -28,6 +28,10 @@ using mindspore::schema::PrimitiveType_Flatten;
 
 namespace mindspore::kernel {
 int FlattenCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto output_shape = outputs_[0]->shape();
   flatten_param_->size = sizeof(float);
   for (int i = 0; i < output_shape.size(); i++) {
@@ -39,6 +43,11 @@ int FlattenCPUKernel::Init() {
 int FlattenCPUKernel::ReSize() { return RET_OK; }
 
 int FlattenCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input = reinterpret_cast<float *>(inputs_[0]->Data());
   auto output = reinterpret_cast<float *>(outputs_[0]->Data());
   Flatten(input, output, flatten_param_);
@@ -48,14 +57,14 @@ int FlattenCPUKernel::Run() {
 kernel::LiteKernel *CpuFlattenFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                 const std::vector<lite::tensor::Tensor *> &outputs,
                                                 OpParameter *opParameter, const lite::Context *ctx,
-                                                const kernel::KernelKey &desc) {
+                                                const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_Flatten. ";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Flatten);
-  auto *kernel = new (std::nothrow) FlattenCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) FlattenCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new FlattenCPUKernel fail!";
     return nullptr;
@@ -72,4 +81,3 @@ kernel::LiteKernel *CpuFlattenFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
 }  // namespace mindspore::kernel
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/flatten.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class FlattenCPUKernel : public LiteKernel {
  public:
   FlattenCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                   const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs) {
+                   const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                   const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
     flatten_param_ = reinterpret_cast<FlattenParameter *>(parameter);
   }
   ~FlattenCPUKernel() override { delete flatten_param_; }
@@ -44,4 +45,3 @@ class FlattenCPUKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FLATTEN_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc
index 366167b5f3..e3c120c6f9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc
@@ -44,6 +44,10 @@ FullconnectionCPUKernel::~FullconnectionCPUKernel() {
 int FullconnectionCPUKernel::ReSize() { return RET_OK; }
 
 int FullconnectionCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   fc_param_->row_ = (inputs_[0]->shape())[0];
   fc_param_->col_ = (inputs_[1]->shape())[0];
   fc_param_->deep_ = (inputs_[1]->shape())[1];
@@ -105,6 +109,11 @@ int FullconnectionCPUKernel::DoMatmul(int task_id) {
 }
 
 int FullconnectionCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto a_ptr = reinterpret_cast<float *>(inputs_.at(0)->Data());
   auto output_ptr = reinterpret_cast<float *>(outputs_.at(0)->Data());
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h
index be4c1f72b5..f10f163f95 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel {
  public:
   FullconnectionCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                          const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                          const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                          const lite::Primitive *primitive)
+      : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~FullconnectionCPUKernel() override;
 
   int Init() override;
@@ -48,4 +49,3 @@ class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_
-
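FullconnectionCPUKernel::Init() reads its matmul geometry straight off the tensors: row_ from the input's leading dimension, col_ and deep_ from the weight's [col, deep] layout, which implies y = x * W^T. A worked example on those dimensions:

#include <iostream>
#include <vector>

int main() {
  // x: [row, deep], W: [col, deep]  (weight rows are output channels)
  int row = 2, deep = 3, col = 2;
  std::vector<float> x = {1, 2, 3, 4, 5, 6};
  std::vector<float> w = {1, 0, 0, 0, 1, 0};  // picks columns 0 and 1 of x

  std::vector<float> y(row * col, 0.f);
  for (int r = 0; r < row; ++r)
    for (int c = 0; c < col; ++c)
      for (int d = 0; d < deep; ++d)
        y[r * col + c] += x[r * deep + d] * w[c * deep + d];  // y = x * W^T

  for (float v : y) std::cout << v << ' ';  // 1 2 4 5
  std::cout << '\n';
  return 0;
}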
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc
index 4452655bbc..00810f89aa 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc
@@ -28,6 +28,10 @@ using mindspore::schema::PrimitiveType_FusedBatchNorm;
 
 namespace mindspore::kernel {
 int FusedBatchnormCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   input_shape_ = reinterpret_cast<int *>(malloc(sizeof(int) * inputs_[0]->shape().size()));
   memcpy(input_shape_, inputs_[0]->shape().data(), inputs_[0]->shape().size() * sizeof(int));
   return RET_OK;
@@ -36,6 +40,11 @@ int FusedBatchnormCPUKernel::Init() {
 int FusedBatchnormCPUKernel::ReSize() { return RET_OK; }
 
 int FusedBatchnormCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input_addr = reinterpret_cast<float *>(inputs_.at(0)->Data());
   auto scale_addr = reinterpret_cast<float *>(inputs_.at(1)->Data());
   auto offest_addr = reinterpret_cast<float *>(inputs_.at(2)->Data());
@@ -51,10 +60,11 @@ int FusedBatchnormCPUKernel::Run() {
 kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                    const std::vector<lite::tensor::Tensor *> &outputs,
                                                    OpParameter *opParameter, const lite::Context *ctx,
-                                                   const kernel::KernelKey &desc) {
+                                                   const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_FusedBatchNorm);
-  auto *kernel = new (std::nothrow) FusedBatchnormCPUKernel(opParameter, inputs, outputs);
+  FusedBatchnormCPUKernel *kernel = new (std::nothrow) FusedBatchnormCPUKernel(opParameter, inputs, outputs, ctx,
+                                                                               primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new FusedBatchnormCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h
index bce0660be9..55c4ba2bb7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class FusedBatchnormCPUKernel : public LiteKernel {
  public:
   FusedBatchnormCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                          const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {
+                          const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                          const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
     fused_batchnorm_param_ = reinterpret_cast<FusedBatchNormParameter *>(parameter);
   }
   ~FusedBatchnormCPUKernel() override { delete fused_batchnorm_param_; }
@@ -42,4 +43,3 @@ class FusedBatchnormCPUKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FUSED_BATCHNORM_H_
-
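FusedBatchnormCPUKernel::Run() hands input, scale, offset, mean and variance to the fused kernel, which per channel computes y = scale * (x - mean) / sqrt(var + eps) + offset. A scalar sketch of the formula (the epsilon value is illustrative):

#include <cmath>
#include <iostream>

int main() {
  // Per-channel batchnorm parameters (a single channel shown).
  float scale = 2.f, offset = 1.f, mean = 4.f, var = 9.f;
  const float eps = 1e-5f;  // illustrative epsilon

  float x = 10.f;
  float y = scale * (x - mean) / std::sqrt(var + eps) + offset;

  std::cout << y << '\n';  // 2 * 6 / 3 + 1 = 5 (up to eps)
  return 0;
}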
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc
index fd073a9f00..d75c836d74 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc
@@ -92,6 +92,11 @@ int GatherRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int GatherCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   int error_code = LiteBackendParallelLaunch(GatherRun, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]";
@@ -103,11 +108,11 @@ int GatherCPUKernel::Run() {
 kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Gather);
-  auto *kernel = new (std::nothrow) GatherCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) GatherCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     return nullptr;
   }
@@ -123,4 +128,3 @@ kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
 }  // namespace mindspore::kernel
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class GatherCPUKernel : public LiteKernel {
  public:
   GatherCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {}
+                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                  const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
   ~GatherCPUKernel() override = default;
 
   int Init() override;
@@ -42,4 +43,3 @@ class GatherCPUKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHER_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc
index 109ce338f8..3d1c2a245e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.cc
@@ -38,6 +38,10 @@ GatherNdCPUKernel::~GatherNdCPUKernel() {
 }
 
 int GatherNdCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto indices_tensor = inputs_.at(1);
   auto indices_shape = indices_tensor->shape();
   int indices_rank = indices_shape.size();
@@ -112,9 +116,14 @@ int GatherNdRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int GatherNdCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   in_ptr_ = reinterpret_cast<float *>(inputs_.front()->Data());
   out_ptr_ = reinterpret_cast<float *>(outputs_.front()->Data());
-  int ret = LiteBackendParallelLaunch(GatherNdRun, this, thread_sz_count_);
+  auto ret = LiteBackendParallelLaunch(GatherNdRun, this, thread_sz_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "gatherNd error error_code[" << ret << "]";
     return ret;
@@ -125,11 +134,11 @@ int GatherNdCPUKernel::Run() {
 kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                  const std::vector<lite::tensor::Tensor *> &outputs,
                                                  OpParameter *opParameter, const lite::Context *ctx,
-                                                 const kernel::KernelKey &desc) {
+                                                 const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_GatherNd);
-  auto *kernel = new (std::nothrow) GatherNdCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) GatherNdCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     return nullptr;
   }
@@ -145,4 +154,3 @@ kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
 }  // namespace mindspore::kernel
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class GatherNdCPUKernel : public LiteKernel {
  public:
   GatherNdCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                    const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {}
+                    const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                    const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
   ~GatherNdCPUKernel() override;
 
   int Init() override;
@@ -53,4 +54,3 @@ class GatherNdCPUKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GATHERND_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc
index bb1286eebd..3a62427a29 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.cc
@@ -74,6 +74,11 @@ int LocalResponseNormRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int LocalResponseNormCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   int error_code = LiteBackendParallelLaunch(LocalResponseNormRun, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "LocalResponseNorm function error error_code[" << error_code << "]";
@@ -85,11 +90,12 @@ int LocalResponseNormCPUKernel::Run() {
 kernel::LiteKernel *CpuLocalResponseNormFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                           const std::vector<lite::tensor::Tensor *> &outputs,
                                                           OpParameter *opParameter, const lite::Context *ctx,
-                                                          const kernel::KernelKey &desc) {
+                                                          const kernel::KernelKey &desc,
+                                                          const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_LocalResponseNormalization);
-  auto *kernel = new (std::nothrow) LocalResponseNormCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) LocalResponseNormCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new LocalResponseNormCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h
index ea65a3e923..90cdc8a66e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/local_response_norm.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class LocalResponseNormCPUKernel : public LiteKernel {
  public:
   LocalResponseNormCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                             const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {}
+                             const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                             const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
   ~LocalResponseNormCPUKernel() override = default;
 
   int Init() override;
@@ -40,4 +41,3 @@ class LocalResponseNormCPUKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LOCAL_RESPONSE_NORM_H_
-
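LocalResponseNorm parallelizes the usual cross-channel formula, b[i] = a[i] / (bias + alpha / n * sum of a[j]^2 over a window around i) ^ beta. A sequential sketch with illustrative constants:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // LRN across channels; radius, bias, alpha, beta are illustrative.
  std::vector<float> a = {1, 2, 3, 4};
  const int radius = 1;
  const float bias = 1.f, alpha = 1.f, beta = 0.5f;
  const int n = 2 * radius + 1;  // window size

  std::vector<float> b(a.size());
  for (int i = 0; i < static_cast<int>(a.size()); ++i) {
    int lo = std::max(0, i - radius);
    int hi = std::min(static_cast<int>(a.size()) - 1, i + radius);
    float sum = 0.f;
    for (int j = lo; j <= hi; ++j) sum += a[j] * a[j];  // squared window sum
    b[i] = a[i] / std::pow(bias + alpha / n * sum, beta);
  }

  for (float v : b) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}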
 kernel::LiteKernel *CpuLstmKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                          const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter,
-                                         const lite::Context *ctx, const kernel::KernelKey &desc) {
+                                         const lite::Context *ctx, const kernel::KernelKey &desc,
+                                         const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Lstm);
-  auto *kernel = new (std::nothrow) LstmCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) LstmCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h
index 61488ca2c4..8fdd4dd8d9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class LstmCPUKernel : public LiteKernel {
  public:
   LstmCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs) {
+                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
     lstm_parm_ = reinterpret_cast<LstmParameter *>(opParameter);
   }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
index 938a6d16b9..1e955860e1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
@@ -33,6 +33,10 @@ MatmulCPUKernel::~MatmulCPUKernel() {
 int MatmulCPUKernel::ReSize() { return RET_OK; }
 
 int MatmulCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   int batch = 1;
   auto a_shape = inputs_[0]->shape();
   auto c_shape = outputs_[0]->shape();
@@ -88,6 +92,11 @@ int MatmulFloatRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int MatmulCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto a_ptr = reinterpret_cast<float *>(inputs_[0]->Data());
   auto b_ptr = reinterpret_cast<float *>(inputs_[1]->Data());
   auto c_ptr = reinterpret_cast<float *>(outputs_[0]->Data());
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h
index 0f617d0179..8bdc43caee 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class MatmulCPUKernel : public MatmulBaseCPUKernel {
  public:
   explicit MatmulCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                           const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : MatmulBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                           const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                           const lite::Primitive *primitive)
+      : MatmulBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~MatmulCPUKernel() override;
   int Init() override;
   int ReSize() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc
index 9ee4f8a577..571f06ef22 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc
@@ -28,6 +28,11 @@ int Nchw2NhwcCPUKernel::Init() { return RET_OK; }
 
 int Nchw2NhwcCPUKernel::ReSize() { return RET_OK; }
 
 int Nchw2NhwcCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input = inputs_[0];
   auto output = outputs_[0];
 
@@ -39,10 +44,10 @@ int Nchw2NhwcCPUKernel::Run() {
 kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *opParameter, const lite::Context *ctx,
-                                                  const kernel::KernelKey &desc) {
+                                                  const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Nchw2Nhwc);
-  auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new Nchw2NhwcCPUKernel fail!";
     return nullptr;
@@ -59,4 +64,3 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.h
@@ ... @@ class Nchw2NhwcCPUKernel : public LiteKernel {
   Nchw2NhwcCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                     const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~Nchw2NhwcCPUKernel() override = default;
 
   int Init() override;
@@ -39,4 +40,3 @@ class Nchw2NhwcCPUKernel : public LiteKernel {
 };
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc
index 480d58eaa2..f511940f9a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.cc
@@ -28,6 +28,11 @@ int Nhwc2NchwCPUKernel::Init() { return RET_OK; }
 
 int Nhwc2NchwCPUKernel::ReSize() { return RET_OK; }
 
 int Nhwc2NchwCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input = inputs_[0];
   auto output = outputs_[0];
 
@@ -39,10 +44,10 @@ int Nhwc2NchwCPUKernel::Run() {
 kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *opParameter, const lite::Context *ctx,
-                                                  const kernel::KernelKey &desc) {
+                                                  const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Nhwc2Nchw);
-  auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new Nhwc2NchwCPUKernel fail!";
     return nullptr;
@@ -59,4 +64,3 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw.h
@@ ... @@ class Nhwc2NchwCPUKernel : public LiteKernel {
   Nhwc2NchwCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                     const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~Nhwc2NchwCPUKernel() override = default;
 
   int Init() override;
@@ -39,4 +40,3 @@ class Nhwc2NchwCPUKernel : public LiteKernel {
 };
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_
-
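Every creator in this patch gets the same mechanical rewrite: a trailing const lite::Primitive *primitive parameter is appended and forwarded into the kernel constructor. For the Nchw2Nhwc/Nhwc2Nchw creators above, the rewrite also closes a latent gap: the old code received ctx but constructed the kernel without it. A hypothetical miniature of that shape, with stand-in types (not the real LiteKernel API):

#include <iostream>
#include <new>

struct Context {};
struct Primitive {};
struct OpParam {};

class Kernel {
 public:
  // Post-patch constructor shape: both ctx and primitive are stored.
  Kernel(OpParam *param, const Context *ctx, const Primitive *primitive)
      : param_(param), ctx_(ctx), primitive_(primitive) {}

 private:
  OpParam *param_;
  const Context *ctx_;
  const Primitive *primitive_;
};

// Post-patch creator shape; the real creators additionally check desc.type
// and call Init() before returning.
Kernel *CreateKernel(OpParam *param, const Context *ctx, const Primitive *primitive) {
  auto *kernel = new (std::nothrow) Kernel(param, ctx, primitive);
  if (kernel == nullptr) {
    std::cerr << "new Kernel fail!\n";
    return nullptr;
  }
  return kernel;
}

int main() {
  OpParam p;
  Context c;
  Primitive prim;
  delete CreateKernel(&p, &c, &prim);
  return 0;
}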
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc
index afff5d42de..5cb964a3ea 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc
@@ -35,6 +35,10 @@ constexpr size_t kOutputNum = 1;
 }  // namespace
 
 int OneHotCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   // indices depth on_value off_value
   if (inputs_.size() != kInputNum || outputs_.size() != kOutputNum) {
     MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << inputs_.size()
@@ -148,6 +152,11 @@ int OneHotCPUKernel::GetParams() {
 }
 
 int OneHotCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   int error_code = LiteBackendParallelLaunch(RunOneHot, this, context_->thread_num_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "OneHot function error error_code[" << error_code << "]";
@@ -159,7 +168,7 @@ int OneHotCPUKernel::Run() {
 kernel::LiteKernel *CpuOneHotFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter != nullptr) {
     MS_LOG(ERROR) << "OneHot opParameter nullptr.";
     return nullptr;
@@ -168,7 +177,7 @@ kernel::LiteKernel *CpuOneHotFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.h
@@ ... @@ class OneHotCPUKernel : public LiteKernel {
   OneHotCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), context_(ctx) {}
+                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                  const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), context_(ctx) {}
 
   ~OneHotCPUKernel() override = default;
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc
index 84c51509ba..ab70d88bbf 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.cc
@@ -22,15 +22,20 @@
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
-using mindspore::schema::PrimitiveType_OptMomentum;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_OptMomentum;
 
 namespace mindspore::kernel {
 int OptMomentumCPUKernel::ReSize() { return 0; }
 
 int OptMomentumCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   if (inputs_.size() != 5 || !outputs_.empty()) {
     MS_LOG(ERROR) << "OptMomentumCPUKernel error input output size!";
     return RET_ERROR;
@@ -59,9 +64,9 @@ int OptMomentumCPUKernel::Init() { return 0; }
 kernel::LiteKernel *CpuOptMomentumFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                     const std::vector<lite::tensor::Tensor *> &outputs,
                                                     OpParameter *opParameter, const lite::Context *ctx,
-                                                    const kernel::KernelKey &desc) {
+                                                    const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_OptMomentum);
-  auto *kernel = new (std::nothrow) OptMomentumCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) OptMomentumCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   MS_ASSERT(kernel != nullptr);
 
   auto ret = kernel->Init();
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h b/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h
index ccc2871779..f9e0395ea3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/opt_momentum.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class OptMomentumCPUKernel : public LiteKernel {
  public:
   explicit OptMomentumCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                                const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~OptMomentumCPUKernel() override {}
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc
index 511c7f9d87..28be57faaa 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc
@@ -36,6 +36,10 @@ constexpr int kOutputNum = 1;
 }  // namespace
 
 int PadCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   if (inputs_.size() != kInputNum || outputs_.size() != kOutputNum) {
     MS_LOG(ERROR) << "Pad input size should be " << kInputNum << ", got " << inputs_.size()
                   << ", output size should be" << kOutputNum << ", got " << outputs_.size();
@@ -85,6 +89,11 @@ int PadCPUKernel::RunImpl(int task_id) {
 }
 
 int PadCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto output = outputs_.at(0);
   int output_size = output->DataSize();
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h
index c48ddf581c..f2a598d339 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class PadCPUKernel : public LiteKernel {
  public:
   PadCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-               const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), context_(ctx) {
+               const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+               const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), context_(ctx) {
     pad_param_ = reinterpret_cast<PadParameter *>(parameter);
   }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc
index 960734f994..45255f7ce8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc
@@ -29,6 +29,10 @@ using mindspore::schema::PrimitiveType_Pooling;
 
 namespace mindspore::kernel {
 int PoolingCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto ret = PoolingBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "PoolingBase Init failed.";
@@ -68,6 +72,11 @@ int PoolingImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int PoolingCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   int error_code = LiteBackendParallelLaunch(PoolingImpl, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]";
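Note where the guard sits in PoolingCPUKernel::Init() above: before the PoolingBaseCPUKernel::Init() call, so no shape-dependent base-class setup runs while shapes are still unknown. The new constructor arguments likewise have to thread through two layers, derived kernel to base kernel to LiteKernel. A toy, compilable model of that forwarding chain; every type here is a hypothetical stand-in, not the real class hierarchy:

#include <iostream>

struct Context { int thread_num_ = 2; };
struct Primitive {};

class LiteKernelish {
 public:
  LiteKernelish(const Context *ctx, const Primitive *primitive)
      : context_(ctx), primitive_(primitive) {}

 protected:
  const Context *context_;
  const Primitive *primitive_;  // kept for deferred shape inference later
};

class PoolingBaseish : public LiteKernelish {
 public:
  PoolingBaseish(const Context *ctx, const Primitive *primitive)
      : LiteKernelish(ctx, primitive), thread_count_(ctx->thread_num_) {}

 protected:
  int thread_count_;
};

class Poolingish : public PoolingBaseish {
 public:
  // Mirrors PoolingCPUKernel: forwards both pointers, adds nothing of its own.
  using PoolingBaseish::PoolingBaseish;
};

int main() {
  Context ctx;
  Primitive prim;
  Poolingish kernel(&ctx, &prim);
  std::cout << "constructed with thread_num " << ctx.thread_num_ << "\n";
  return 0;
}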
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h
index 7edd82a537..1833764e2d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.h
@@ -33,8 +33,9 @@ using mindspore::schema::RoundMode;
 class PoolingCPUKernel : public PoolingBaseCPUKernel {
  public:
   PoolingCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                   const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                   const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                   const lite::Primitive *primitive)
+      : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~PoolingCPUKernel() override = default;
 
   int Init() override;
@@ -47,4 +48,3 @@ class PoolingCPUKernel : public PoolingBaseCPUKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POOLING_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc
index 23606e6fa7..a7dd2da9ba 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.cc
@@ -175,11 +175,11 @@ int PoolingGradCPUKernel::Run() {
 kernel::LiteKernel *CpuPoolingGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                     const std::vector<lite::tensor::Tensor *> &outputs,
                                                     OpParameter *opParameter, const lite::Context *ctx,
-                                                    const kernel::KernelKey &desc) {
+                                                    const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_PoolingGrad);
-  auto *kernel = new (std::nothrow) PoolingGradCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) PoolingGradCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   MS_ASSERT(kernel != nullptr);
   auto ret = kernel->Init();
   if (RET_OK != ret) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h
index eec333d860..f093062c1f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling_grad.h
@@ -30,8 +30,9 @@ using mindspore::schema::RoundMode;
 class PoolingGradCPUKernel : public LiteKernel {
  public:
   explicit PoolingGradCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                                const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~PoolingGradCPUKernel() override = default;
 
   // int TfPadding(int input_w, int input_h, int &output_w, int &output_h);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc
index df4e5974d6..467b92efb6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power.cc
@@ -41,7 +41,12 @@ int PowerImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int PowerCPUKernel::Run() {
-  int ret = LiteBackendParallelLaunch(PowerImpl, this, thread_count_);
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
+  auto ret = LiteBackendParallelLaunch(PowerImpl, this, thread_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "PowerCPUKernel error: " << ret;
     return RET_ERROR;
@@ -74,10 +79,11 @@ int PowerCPUKernel::RunImpl(int task_id) {
 kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const lite::Context *ctx,
-                                             const kernel::KernelKey &desc) {
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Power);
-  auto *kernel = new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
+  PowerCPUKernel *kernel =
+      new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PowerCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power.h b/mindspore/lite/src/runtime/kernel/arm/fp32/power.h
index 3005fcb518..89a6404baa 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/power.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class PowerCPUKernel : public LiteKernel {
  public:
   PowerCPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
-                 const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(param, inputs, outputs),
+                 const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                 const lite::Primitive *primitive)
+      : LiteKernel(param, inputs, outputs, ctx, primitive),
         ctx_(ctx),
         thread_count_(ctx->thread_num_),
         power_(reinterpret_cast<PowerParameter *>(opParameter)->power_),
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc
index f99bd3e743..2523e3c595 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.cc
@@ -50,10 +50,10 @@ int PowerGradCPUKernel::Run() {
 kernel::LiteKernel *CpuPowerGradFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *opParameter, const lite::Context *ctx,
-                                                  const kernel::KernelKey &desc) {
+                                                  const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_PowerGrad);
-  auto *kernel = new (std::nothrow) PowerGradCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) PowerGradCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   auto ret = kernel->Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h
index 00b2e882f7..5361d992c4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/power_grad.h
@@ -26,12 +26,13 @@ namespace mindspore::kernel {
 class PowerGradCPUKernel : public LiteKernel {
  public:
   PowerGradCPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(param, inputs, outputs) {
-    PowerParameter *power_param = reinterpret_cast<PowerParameter *>(param);
-    power_ = power_param->power_;
-    scale_ = power_param->scale_;
-    shift_ = power_param->shift_;
+                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                     const lite::Primitive *primitive)
+      : LiteKernel(param, inputs, outputs, ctx, primitive) {
+    PowerParameter *power_param = reinterpret_cast<PowerParameter *>(param);
+    power_ = power_param->power_;
+    scale_ = power_param->scale_;
+    shift_ = power_param->shift_;
   }
   ~PowerGradCPUKernel() override = default;
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc
index 952392761d..756680726f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.cc
@@ -49,6 +49,11 @@ int PReluRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int PReluCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto input = inputs_.at(0);
   prelu_param_->input_num_ = input->ElementsNum();
   input_data = reinterpret_cast<float *>(input->Data());
@@ -65,13 +70,13 @@ int PReluCPUKernel::Run() {
 kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const lite::Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Prelu);
-  auto *kernel = new (std::nothrow) PReluCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PReluCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PReluCPUKernel fail!";
     return nullptr;
@@ -88,4 +93,3 @@ kernel::LiteKernel *CpuPReluFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu.h
@@ ... @@ class PReluCPUKernel : public LiteKernel {
   PReluCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                 const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                 const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                 const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
     prelu_param_ = (reinterpret_cast<PReluParameter *>(opParameter));
+    primitive_ = primitive;
   }
   ~PReluCPUKernel() = default;
 
@@ -51,4 +53,3 @@ class PReluCPUKernel : public LiteKernel {
 };
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_
-
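One oddity worth flagging in the PRelu header above: the constructor forwards primitive into the LiteKernel base and then assigns primitive_ = primitive; again in the body. Assuming the base-class constructor already stores the pointer, which every other kernel in this patch relies on, the second store is a no-op and could be dropped. A tiny, compilable model of why (Base and PRelu here are hypothetical stand-ins):

#include <cassert>

struct Primitive {};

class Base {
 public:
  explicit Base(const Primitive *primitive) : primitive_(primitive) {}

 protected:
  const Primitive *primitive_;
};

class PRelu : public Base {
 public:
  explicit PRelu(const Primitive *primitive) : Base(primitive) {
    primitive_ = primitive;  // redundant: Base(primitive) already stored it
  }
  const Primitive *get() const { return primitive_; }
};

int main() {
  Primitive p;
  PRelu kernel(&p);
  assert(kernel.get() == &p);  // holds with or without the extra assignment
  return 0;
}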
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc
index b965731a35..880ccd5c99 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/range.cc
@@ -38,6 +38,11 @@ int RangeCPUKernel::Init() { return RET_OK; }
 
 int RangeCPUKernel::ReSize() { return RET_OK; }
 
 int RangeCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   size_t start = (reinterpret_cast<RangeParameter *>(opParameter))->start_;
   size_t limit = (reinterpret_cast<RangeParameter *>(opParameter))->limit_;
   size_t delta = (reinterpret_cast<RangeParameter *>(opParameter))->delta_;
@@ -49,11 +54,11 @@ int RangeCPUKernel::Run() {
 kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const lite::Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Range);
-  auto *kernel = new (std::nothrow) RangeCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) RangeCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new RangeCPUKernel fail!";
     return nullptr;
@@ -71,4 +76,3 @@ kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/range.h b/mindspore/lite/src/runtime/kernel/arm/fp32/range.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/range.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/range.h
@@ ... @@ class RangeCPUKernel : public LiteKernel {
   RangeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                 const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                 const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                 const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~RangeCPUKernel() override = default;
 
   int Init() override;
@@ -36,4 +37,3 @@ class RangeCPUKernel : public LiteKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANGE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc
index 7917f0928d..e70b862fd1 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.cc
@@ -38,6 +38,11 @@ int RankCPUKernel::Init() { return RET_OK; }
 
 int RankCPUKernel::ReSize() { return RET_OK; }
 
 int RankCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   auto output_ptr = reinterpret_cast<float *>(outputs_.at(0)->Data());
   auto in_shape = inputs_[0]->shape();
   auto rank = in_shape.size();
@@ -47,12 +52,12 @@ int RankCPUKernel::Run() {
 kernel::LiteKernel *CpuRankFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
-                                             OpParameter *opParameter,
-                                             const lite::Context *ctx, const kernel::KernelKey &desc) {
+                                             OpParameter *opParameter, const lite::Context *ctx,
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Rank);
-  auto *kernel = new (std::nothrow) RankCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) RankCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new RankCPUKernel fail!";
     return nullptr;
@@ -69,4 +74,3 @@ kernel::LiteKernel *CpuRankFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/rank.h
@@ ... @@ class RankCPUKernel : public LiteKernel {
   RankCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~RankCPUKernel() override = default;
 
   int Init() override;
@@ -36,4 +37,3 @@ class RankCPUKernel : public LiteKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RANK_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
index 43bdfe9351..d05b8a17c0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
@@ -91,6 +91,10 @@ int ReduceCPUKernel::CheckParameters() {
 }
 
 int ReduceCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto ret = CheckInputsOutputs();
   if (ret != RET_OK) {
     return ret;
@@ -153,6 +157,11 @@ int ReduceImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ReduceCPUKernel::Run() {
+  auto prepare_ret = Prepare();
+  if (prepare_ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
+    return prepare_ret;
+  }
   tmp_shape_ = inputs_.at(0)->shape();
   src_data_ = static_cast<float *>(inputs_.at(0)->Data());
   for (int i = 0; i < data_buffers_.size(); ++i) {
@@ -220,7 +229,7 @@ int ReduceCPUKernel::MallocTmpBuffer() {
 kernel::LiteKernel *CpuReduceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Reduce);
   if (opParameter == nullptr) {
@@ -231,8 +240,8 @@ kernel::LiteKernel *CpuReduceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-  auto *kernel = new (std::nothrow) ReduceCPUKernel(reinterpret_cast<ReduceParameter *>(opParameter), inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow)
+      ReduceCPUKernel(reinterpret_cast<ReduceParameter *>(opParameter), inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Reduce new ReduceCPUKernel failed.";
     return nullptr;
@@ -250,7 +259,7 @@ kernel::LiteKernel *CpuReduceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
 kernel::LiteKernel *CpuMeanFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const lite::Context *ctx,
-                                             const kernel::KernelKey &desc) {
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Mean);
   if (opParameter == nullptr) {
@@ -261,8 +270,8 @@ kernel::LiteKernel *CpuMeanFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-  auto *kernel = new (std::nothrow) ReduceCPUKernel(reinterpret_cast<ReduceParameter *>(opParameter), inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow)
+      ReduceCPUKernel(reinterpret_cast<ReduceParameter *>(opParameter), inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Reduce new ReduceCPUKernel failed.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h
index 2273465c27..2857ee9baf 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.h
@@ -31,8 +31,9 @@ class ReduceCPUKernel : public LiteKernel {
  public:
   ReduceCPUKernel(ReduceParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
-                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs),
+                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                  const lite::Primitive *primitive)
+      : LiteKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs, ctx, primitive),
         context_(ctx),
         keep_dims_(param->keep_dims_),
         num_axes_(param->num_axes_),
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc
index dc45381b84..f1602fdc8c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.cc
@@ -36,6 +36,11 @@ int ReshapeCPUKernel::Init() {
 
 int ReshapeCPUKernel::ReSize() { return RET_OK; }
 
 int ReshapeCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto input_ptr = inputs_.at(kInputIndex)->Data();
   auto output_ptr = outputs_.at(kOutputIndex)->Data();
   size_t data_size = inputs_.at(kInputIndex)->Size();
@@ -43,4 +48,3 @@ int ReshapeCPUKernel::Run() {
   return RET_OK;
 }
 } // namespace mindspore::kernel
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h
index f366e739d9..65b9de20dc 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reshape.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class ReshapeCPUKernel : public ReshapeBaseCPUKernel {
  public:
   ReshapeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                   const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ReshapeBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                   const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                   const lite::Primitive *primitive)
+      : ReshapeBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ReshapeCPUKernel() = default;
 
   int Init() override;
@@ -42,4 +43,3 @@ class ReshapeCPUKernel : public ReshapeBaseCPUKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESHAPE_H_
-
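Two variants of the Run() prologue coexist in this patch. Most kernels log the failing code and propagate it (return prepare_ret), while the Reshape family above logs a generic "Prepare failed." and returns RET_ERROR, discarding the underlying code. Both are functionally safe; the propagating form is easier to debug. A tiny, compilable illustration of the difference (all names here are stand-ins):

#include <iostream>

constexpr int RET_OK = 0;
constexpr int RET_ERROR = 1;

int Prepare() { return 42; }  // pretend Prepare() failed with a specific code

// Variant used by reshape/resize above: the caller only ever sees RET_ERROR.
int RunSwallowing() {
  auto ret = Prepare();
  if (ret != RET_OK) {
    return RET_ERROR;  // the real cause (42) is lost
  }
  return RET_OK;
}

// Variant used by most other kernels: the cause survives to the caller.
int RunPropagating() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    return prepare_ret;
  }
  return RET_OK;
}

int main() {
  std::cout << RunSwallowing() << " vs " << RunPropagating() << "\n";  // 1 vs 42
  return 0;
}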
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc
index ea89861d32..644c454c93 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.cc
@@ -88,6 +88,10 @@ int ResizeCPUKernel::CheckInputsOuputs() {
 }
 
 int ResizeCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto ret = CheckParameters();
   if (ret != RET_OK) {
     return ret;
@@ -205,6 +209,11 @@ int ResizeCPUKernel::RunImpl(int task_id) {
 }
 
 int ResizeCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   int error_code = LiteBackendParallelLaunch(ResizeImpl, this, context_->thread_num_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Resize run error, error_code[" << error_code << "]";
@@ -216,13 +225,13 @@ int ResizeCPUKernel::Run() {
 kernel::LiteKernel *CpuResizeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Resize);
-  auto *kernel = new (std::nothrow) ResizeCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ResizeCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ResizeCPUKernel fail!";
     return nullptr;
@@ -240,4 +249,3 @@ kernel::LiteKernel *CpuResizeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/resize.h
@@ ... @@ class ResizeCPUKernel : public LiteKernel {
   ResizeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), context_(ctx) {}
+                  const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                  const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), context_(ctx) {}
 
   ~ResizeCPUKernel() {
     if (exec_input_data_ != nullptr) {
@@ -62,4 +63,3 @@ class ResizeCPUKernel : public LiteKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_RESIZE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc
index 9e3b6bb557..80ddfcf320 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.cc
@@ -89,6 +89,10 @@ int ReverseCPUKernel::ReSize() {
 }
 
 int ReverseCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   data_size_ = inputs_.at(0)->ElementsNum();
   thread_sz_count_ = MSMIN(thread_count_, data_size_);
   thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
@@ -121,9 +125,14 @@ int ReverseCPUKernel::DoReverse(int task_id) {
 }
 
 int ReverseCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   in_ptr_ = reinterpret_cast<float *>(inputs_[0]->Data());
   out_ptr_ = reinterpret_cast<float *>(outputs_[0]->Data());
-  int ret = LiteBackendParallelLaunch(ReverseRun, this, thread_sz_count_);
+  ret = LiteBackendParallelLaunch(ReverseRun, this, thread_sz_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Reverse run error error_code[" << ret << "]";
     return ret;
@@ -134,13 +143,13 @@ int ReverseCPUKernel::Run() {
 kernel::LiteKernel *CpuReverseFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "opParameter is NULL! ";
     return nullptr;
   }
   MS_ASSERT(desc.type == schema::PrimitiveType_Reverse);
-  auto *kernel = new (std::nothrow) ReverseCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ReverseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Kernel is NULL! name: " << opParameter->name_ << ", type: "
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
@@ -159,4 +168,3 @@ kernel::LiteKernel *CpuReverseFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse.h
@@ ... @@ class ReverseCPUKernel : public LiteKernel {
   ReverseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                   const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {}
+                   const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                   const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {}
   ~ReverseCPUKernel() {
     if (tmp_ != nullptr) {
       free(tmp_);
@@ -60,4 +61,3 @@ class ReverseCPUKernel : public LiteKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc
index 8684adaeeb..3ed50ee3e5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc
@@ -24,6 +24,10 @@ using mindspore::schema::PrimitiveType_ReverseSequence;
 
 namespace mindspore::kernel {
 int ReverseSequenceCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto input0 = inputs_.at(0);
   auto input1 = inputs_.at(1);
   auto output = outputs_.at(0);
@@ -84,6 +88,11 @@ int ReverseSequenceCPUKernel::CalcCountAfterAxis(const std::vector<int> shape, int axis)
 
 int ReverseSequenceCPUKernel::ReSize() { return RET_OK; }
 
 int ReverseSequenceCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   float *input0 = reinterpret_cast<float *>(inputs_.at(0)->Data());
   int *input1 = reinterpret_cast<int *>(inputs_.at(1)->Data());
   float *output = reinterpret_cast<float *>(outputs_.at(0)->Data());
@@ -94,10 +103,10 @@ int ReverseSequenceCPUKernel::Run() {
 kernel::LiteKernel *CpuReverseSequenceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                         const std::vector<lite::tensor::Tensor *> &outputs,
                                                         OpParameter *parameter, const lite::Context *ctx,
-                                                        const KernelKey &desc) {
+                                                        const KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(parameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_ReverseSequence);
-  auto *kernel = new (std::nothrow) ReverseSequenceCPUKernel(parameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) ReverseSequenceCPUKernel(parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
     return nullptr;
@@ -114,4 +123,3 @@ kernel::LiteKernel *CpuReverseSequenceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.h
@@ ... @@ class ReverseSequenceCPUKernel : public LiteKernel {
   ReverseSequenceCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                           const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                           const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                           const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ReverseSequenceCPUKernel() = default;
 
   int Init() override;
@@ -40,4 +41,3 @@ class ReverseSequenceCPUKernel : public LiteKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_SEQUENCE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc
index 2323e4ce95..433310de10 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc
@@ -90,6 +90,10 @@ int ScaleCPUKernel::InitParameter() {
 }
 
 int ScaleCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   if (inputs_.size() < 2 || inputs_.size() > 3) {
     MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << inputs_.size() << " is given.";
     return RET_ERROR;
@@ -133,6 +137,11 @@ int ScaleRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ScaleCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto in_tensor = inputs_.front();
   input_ptr_ = reinterpret_cast<float *>(in_tensor->Data());
   if (scale_ == nullptr) {
@@ -142,7 +151,7 @@ int ScaleCPUKernel::Run() {
   auto out_tensor = outputs_.front();
   output_ptr_ = reinterpret_cast<float *>(out_tensor->Data());
 
-  int ret = LiteBackendParallelLaunch(ScaleRun, this, opParameter->thread_num_);
+  ret = LiteBackendParallelLaunch(ScaleRun, this, opParameter->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
     return RET_ERROR;
@@ -154,13 +163,13 @@ int ScaleCPUKernel::Run() {
 kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const lite::Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_Scale);
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "opParameter is nullptr";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) ScaleCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ScaleCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "New kernel fails.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h
index 32417bcc26..caf4d35376 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h
@@ -25,10 +25,11 @@ namespace mindspore::kernel {
 class ScaleCPUKernel : public LiteKernel {
  public:
   explicit ScaleCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                          const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs) {
-    opParameter->thread_num_ = ctx->thread_num_;
-  }
+                          const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                          const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    opParameter->thread_num_ = ctx->thread_num_;
+  }
   ~ScaleCPUKernel() override = default;
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc
index 1a9ab72866..e13d74697b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.cc
@@ -30,13 +30,17 @@ using mindspore::schema::PrimitiveType_ScatterND;
 
 namespace mindspore::kernel {
 namespace {
-  constexpr int kScatterNDInputNum = 3;
-  constexpr int kScatterNDOutputNum = 1;
-  constexpr int kScatterShapeIndex = 0;
-  constexpr int kScatterIndicesIndex = 1;
-  constexpr int kScatterUpdateIndex = 2;
+constexpr int kScatterNDInputNum = 3;
+constexpr int kScatterNDOutputNum = 1;
+constexpr int kScatterShapeIndex = 0;
+constexpr int kScatterIndicesIndex = 1;
+constexpr int kScatterUpdateIndex = 2;
 }  // namespace
 
 int ScatterNDCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto shape = inputs_.at(kScatterShapeIndex);
   auto indices = inputs_.at(kScatterIndicesIndex);
   auto update = inputs_.at(kScatterUpdateIndex);
@@ -146,7 +150,12 @@ int ScatterNDRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ScatterNDCPUKernel::Run() {
-  int ret = LiteBackendParallelLaunch(ScatterNDRun, this, thread_n_num_);
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
+  ret = LiteBackendParallelLaunch(ScatterNDRun, this, thread_n_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "ScatterND error error_code[" << ret << "]";
     return RET_ERROR;
@@ -158,13 +167,13 @@ int ScatterNDCPUKernel::Run() {
 kernel::LiteKernel *CpuScatterNDFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *opParameter, const lite::Context *ctx,
-                                                  const kernel::KernelKey &desc) {
+                                                  const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_ScatterND);
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "desc type is not scatterND";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) ScatterNDCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ScatterNDCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "New kernel fails.";
     return nullptr;
@@ -183,4 +192,3 @@ kernel::LiteKernel *CpuScatterNDFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scatter_nd.h
@@ ... @@ class ScatterNDCPUKernel : public LiteKernel {
   ScatterNDCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {}
+                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                     const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {}
   ~ScatterNDCPUKernel() override = default;
 
   int Init() override;
@@ -49,4 +50,3 @@ class ScatterNDCPUKernel : public LiteKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SCATTER_ND_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc
index 077e174fcf..aee8f2beaf 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.cc
@@ -27,14 +27,19 @@ using mindspore::schema::PrimitiveType_Shape;
 
 namespace mindspore::kernel {
 namespace {
-  constexpr int kShapeInputNum = 1;
-  constexpr int kShapeOutputNum = 1;
+constexpr int kShapeInputNum = 1;
+constexpr int kShapeOutputNum = 1;
 }  // namespace
 
 int ShapeCPUKernel::Init() { return RET_OK; }
 
 int ShapeCPUKernel::ReSize() { return RET_OK; }
 
 int ShapeCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto out_tensor = outputs_.front();
   auto in_tensor = inputs_.front();
   if (in_tensor == nullptr || out_tensor == nullptr) {
@@ -55,14 +60,14 @@ int ShapeCPUKernel::Run() {
 kernel::LiteKernel *CpuShapeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
-                                             OpParameter *opParameter,
-                                             const lite::Context *ctx, const kernel::KernelKey &desc) {
+                                             OpParameter *opParameter, const lite::Context *ctx,
+                                             const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_Shape);
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "desc type is not Shape";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) ShapeCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) ShapeCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "New kernel fails.";
     return nullptr;
@@ -81,4 +86,3 @@ kernel::LiteKernel *CpuShapeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/shape.h
@@ ... @@ class ShapeCPUKernel : public LiteKernel {
   ShapeCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                 const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                 const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                 const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ShapeCPUKernel() override = default;
 
   int Init() override;
@@ -39,4 +40,3 @@ class ShapeCPUKernel : public LiteKernel {
 } // namespace mindspore::kernel
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SHAPE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc
index 91d02b5623..97a320d67b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc
@@ -40,6 +40,10 @@ int SliceLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
 }  // namespace
 
 int SliceCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto *param = reinterpret_cast<SliceParameter *>(opParameter);
   auto input_shape = inputs_[0]->shape();
   if (input_shape.size() != param->param_length_) {
@@ -68,6 +72,11 @@ int SliceCPUKernel::SliceParallelRun(int thread_id) {
 }
 
 int SliceCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   SliceParameter *param = reinterpret_cast<SliceParameter *>(opParameter);
   for (int i = 0; i < param->param_length_; ++i) {
     if (param->size_[i] < 0) {
@@ -86,7 +95,7 @@ int SliceCPUKernel::Run() {
     DoSliceNoParallel(input_data, output_data, param);
     return RET_OK;
   }
-  int ret = LiteBackendParallelLaunch(SliceLaunch, this, param->op_parameter_.thread_num_);
+  ret = LiteBackendParallelLaunch(SliceLaunch, this, param->op_parameter_.thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "slice launch fail!ret: " << ret;
     return RET_ERROR;
@@ -97,7 +106,7 @@ int SliceCPUKernel::Run() {
 kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *op_parameter, const lite::Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (op_parameter == nullptr) {
     MS_LOG(ERROR) << "Input op_parameter is nullptr!";
     return nullptr;
@@ -108,7 +117,7 @@ kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
   op_parameter->thread_num_ = ctx->thread_num_;
-  auto *kernel = new (std::nothrow) SliceCPUKernel(op_parameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) SliceCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new SliceCPUKernel fail!";
     return nullptr;
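A mechanical detail repeated in several Run() bodies above (reverse, scale, scatter_nd, slice): the Prepare() prologue already declares ret, so the later int ret = LiteBackendParallelLaunch(...) would redeclare the same name in the same scope. The hunks therefore downgrade those declarations to plain assignments. The same situation in miniature, with stand-in functions:

#include <iostream>

int Prepare() { return 0; }     // stand-in for LiteKernel::Prepare()
int LaunchWork() { return 0; }  // stand-in for LiteBackendParallelLaunch(...)

int Run() {
  auto ret = Prepare();  // `ret` now lives for the whole function body
  if (ret != 0) {
    return ret;
  }
  // Before the patch this line read `int ret = LaunchWork();`, which would
  // now be a redeclaration error; assignment reuses the existing variable.
  ret = LaunchWork();
  return ret;
}

int main() {
  std::cout << Run() << "\n";
  return 0;
}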
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h
index a02baf4918..8ed727a503 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h
@@ -19,18 +19,17 @@
 #include <vector>
 #include "src/lite_kernel.h"
-
 namespace mindspore::kernel {
 class SliceCPUKernel : public LiteKernel {
  public:
   SliceCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                 const std::vector<lite::tensor::Tensor *> &outputs) : LiteKernel(parameter, inputs, outputs) {}
+                 const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                 const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~SliceCPUKernel() = default;
 
   int Init() override;
-  int ReSize() override {
-    return 0;
-  }
+  int ReSize() override { return 0; }
   int Run() override;
   int SliceParallelRun(int thread_id);
 };
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc
index 5dbd0acc43..cda47c6a83 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc
@@ -30,6 +30,10 @@ using mindspore::schema::PrimitiveType_SoftMax;
 
 namespace mindspore::kernel {
 int SoftmaxCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   SoftmaxBaseCPUKernel::Init();
 
   // malloc tmp buffer
@@ -56,6 +60,11 @@ int SoftmaxCPUKernel::Init() {
 int SoftmaxCPUKernel::ReSize() { return RET_OK; }
 
 int SoftmaxCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto input_ptr = reinterpret_cast<float *>(inputs_.at(kInputIndex)->Data());
   auto output_ptr = reinterpret_cast<float *>(outputs_.at(kOutputIndex)->Data());
   Softmax(input_ptr, output_ptr, sum_data_, softmax_param_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h
index 6c46045794..515535a328 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
  public:
   SoftmaxCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                   const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                   const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                   const lite::Primitive *primitive)
+      : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~SoftmaxCPUKernel() override {
     if (sum_data_ != nullptr) {
       free(sum_data_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc
index 5ac8d4db97..6849ee358f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc
@@ -31,6 +31,10 @@ using mindspore::schema::PrimitiveType_SpaceToBatch;
 
 namespace mindspore::kernel {
 int SpaceToBatchCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
     MS_LOG(ERROR) << "space_to_batch only support NHWC now!";
     return RET_FORMAT_ERR;
@@ -50,13 +54,17 @@ int SpaceToBatchCPUKernel::Init() {
 }
 
 int SpaceToBatchCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto input = inputs_[0];
   auto output = outputs_[0];
   input_ptr_ = reinterpret_cast<float *>(input->Data());
   output_ptr_ = reinterpret_cast<float *>(output->Data());
   SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->opParameter);
-  int ret;
   float *tmp_space[3] = {nullptr, nullptr, nullptr};
   if (param->need_paddings_) {
     tmp_space[0] = reinterpret_cast<float *>(malloc(param->num_elements_padded_ * sizeof(float)));
@@ -81,12 +89,12 @@ int SpaceToBatchCPUKernel::Run() {
 kernel::LiteKernel *CpuSpaceToBatchFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                      const std::vector<lite::tensor::Tensor *> &outputs,
                                                      OpParameter *opParameter, const lite::Context *ctx,
-                                                     const kernel::KernelKey &desc) {
+                                                     const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) SpaceToBatchCPUKernel(opParameter, inputs, outputs);
+  auto *kernel = new (std::nothrow) SpaceToBatchCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new SpaceToBatchCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h
index 510649f2c0..453364c6d9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h
@@ -23,8 +23,9 @@ namespace mindspore::kernel {
 class SpaceToBatchCPUKernel : public LiteKernel {
  public:
   SpaceToBatchCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                        const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                        const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                        const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
 
   ~SpaceToBatchCPUKernel() = default;
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.cc
index 13aa702567..cbb5cc426c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.cc
@@ -32,6 +32,10 @@ using mindspore::schema::PrimitiveType_SpaceToDepth;
 
 namespace mindspore::kernel {
 int SpaceToDepthCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
     MS_LOG(ERROR) << "space_to_depth only support NHWC now!";
     return RET_FORMAT_ERR;
@@ -77,10 +81,15 @@ int SpaceToDepthRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int SpaceToDepthCPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   input_ptr_ = reinterpret_cast<float *>(inputs_[0]->Data());
   output_ptr_ = reinterpret_cast<float *>(outputs_[0]->Data());
   if (inputs_[0]->GetFormat() == schema::Format_NHWC) {
-    int ret = LiteBackendParallelLaunch(SpaceToDepthRun, this, thread_h_num_);
+    ret = LiteBackendParallelLaunch(SpaceToDepthRun, this, thread_h_num_);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "SpaceToDepth error error_code[" << ret << "]";
       return ret;
@@ -90,16 +99,18 @@ int SpaceToDepthCPUKernel::Run() {
     MS_LOG(ERROR) << "Only support NHWC now!";
     return RET_ERROR;
   }
+  return RET_OK;
 }
+
 kernel::LiteKernel *CpuSpaceToDepthFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                      const std::vector<lite::tensor::Tensor *> &outputs,
                                                      OpParameter *opParameter, const lite::Context *ctx,
-                                                     const kernel::KernelKey &desc) {
+                                                     const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "Input opParameter is nullptr!";
     return nullptr;
   }
-  auto *kernel = new (std::nothrow) SpaceToDepthCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) SpaceToDepthCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new SpaceToDepthCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.h b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.h
index 749de65503..1756358927 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_depth.h
@@ -24,8 +24,9 @@ namespace mindspore::kernel {
 class SpaceToDepthCPUKernel : public LiteKernel {
  public:
   SpaceToDepthCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                        const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), thread_num_(ctx->thread_num_) {}
+                        const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                        const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {}
   ~SpaceToDepthCPUKernel() = default;
 
   int SpaceToDepth(int task_id);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc
index fec15c1e00..e343b026a5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.cc
@@ -105,6 +105,10 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
 }
 
 int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto dims = inputs_[0]->shape();
   param->n_dim_ = 2;
   param->number_of_classes_ = dims[1];
@@ -126,10 +130,12 @@ int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Init() {
 kernel::LiteKernel *CpuSoftmaxCrossEntropyFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                             const std::vector<lite::tensor::Tensor *> &outputs,
                                                             OpParameter *opParameter, const lite::Context *ctx,
-                                                            const kernel::KernelKey &desc) {
+                                                            const kernel::KernelKey &desc,
+                                                            const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_SoftmaxCrossEntropy);
-  auto *kernel = new (std::nothrow) SparseSoftmaxCrossEntropyWithLogitsCPUKernel(opParameter, inputs, outputs);
+  auto *kernel =
+      new (std::nothrow) SparseSoftmaxCrossEntropyWithLogitsCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   MS_ASSERT(kernel != nullptr);
   auto ret = kernel->Init();
   if (RET_OK != ret) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h
index a8dd3439cd..db9a0d1270 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_softmax_cross_entropy_with_logits.h
@@ -29,8 +29,10 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LiteKernel {
  public:
   explicit SparseSoftmaxCrossEntropyWithLogitsCPUKernel(OpParameter *parameter,
                                                         const std::vector<lite::tensor::Tensor *> &inputs,
-                                                        const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {
+                                                        const std::vector<lite::tensor::Tensor *> &outputs,
+                                                        const lite::Context *ctx,
+                                                        const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
reinterpret_cast(parameter); } ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc index 359a379985..b48f1c4f80 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/sparse_to_dense.cc @@ -49,6 +49,11 @@ int SparseToDenseRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { return RET_OK; } int SparseToDenseCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input = inputs_.at(0); auto input1 = inputs_.at(1); auto input2 = inputs_.at(2); @@ -65,7 +70,7 @@ int SparseToDenseCPUKernel::Run() { std::vector temp_shape = output0->shape(); output_shape_ = reinterpret_cast(temp_shape.data()); - auto ret = LiteBackendParallelLaunch(SparseToDenseRun, this, s2d_param_->thread_num_); + ret = LiteBackendParallelLaunch(SparseToDenseRun, this, s2d_param_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "SparseToDenseRun error: error_code[" << ret << "]"; return RET_ERROR; @@ -76,13 +81,13 @@ int SparseToDenseCPUKernel::Run() { kernel::LiteKernel *CpuSparseToDenseFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_SparseToDense); - auto *kernel = new (std::nothrow) SparseToDenseCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) SparseToDenseCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new SparseToDenseCPUKernel fail!"; return nullptr; @@ -99,4 +104,3 @@ kernel::LiteKernel *CpuSparseToDenseFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { s2d_param_ = (reinterpret_cast(opParameter)); } ~SparseToDenseCPUKernel() = default; @@ -56,4 +57,3 @@ class SparseToDenseCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPARSETODENSE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc index 81942f7a37..2d064989c4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/split.cc @@ -31,6 +31,10 @@ using mindspore::schema::PrimitiveType_Split; namespace mindspore::kernel { int SplitCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } SplitBaseCPUKernel::Init(); auto in_tensor = inputs_.front(); input_ptr_ = reinterpret_cast(in_tensor->Data()); @@ -68,7 +72,12 @@ int SplitRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int SplitCPUKernel::Run() { - int ret = LiteBackendParallelLaunch(SplitRun, this, thread_n_num_); + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + 
return RET_ERROR; + } + ret = LiteBackendParallelLaunch(SplitRun, this, thread_n_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/split.h b/mindspore/lite/src/runtime/kernel/arm/fp32/split.h index 5761367abb..0796d7b135 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/split.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/split.h @@ -25,8 +25,9 @@ namespace mindspore::kernel { class SplitCPUKernel : public SplitBaseCPUKernel { public: SplitCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : SplitBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : SplitBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~SplitCPUKernel() override = default; int Init() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc index fbc102eec0..0074f975ab 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/squeeze.cc @@ -28,8 +28,8 @@ using mindspore::schema::PrimitiveType_Squeeze; namespace mindspore::kernel { namespace { - constexpr int kSqueezeInputNum = 1; - constexpr int kSqueezeOutputNum = 1; +constexpr int kSqueezeInputNum = 1; +constexpr int kSqueezeOutputNum = 1; } // namespace int SqueezeCPUKernel::Init() { return RET_OK; } @@ -37,10 +37,15 @@ int SqueezeCPUKernel::Init() { return RET_OK; } int SqueezeCPUKernel::ReSize() { return RET_OK; } int SqueezeCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input_ptr = reinterpret_cast(inputs_.front()->Data()); auto output_ptr = reinterpret_cast(outputs_.front()->Data()); size_t data_size = inputs_.front()->Size(); - auto ret = DoSqueeze(input_ptr, output_ptr, data_size); + ret = DoSqueeze(input_ptr, output_ptr, data_size); if (ret != RET_OK) { MS_LOG(ERROR) << "Do squeeze failed."; return RET_ERROR; @@ -51,13 +56,13 @@ int SqueezeCPUKernel::Run() { kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(desc.type == schema::PrimitiveType_Squeeze); if (opParameter == nullptr) { MS_LOG(ERROR) << "desc type is not Squeeze"; return nullptr; } - auto *kernel = new (std::nothrow) SqueezeCPUKernel(opParameter, inputs, outputs); + auto *kernel = new (std::nothrow) SqueezeCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "New kernel fails."; return nullptr; @@ -76,4 +81,3 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector #include "src/lite_kernel.h" - namespace mindspore::kernel { class SqueezeCPUKernel : public LiteKernel { public: explicit SqueezeCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~SqueezeCPUKernel() override = default; int Init() override; @@ -40,4 +40,3 @@ class SqueezeCPUKernel : public LiteKernel { } // namespace 
mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SQUEEZE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc index 344e6762ca..e12079bcf0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc @@ -27,6 +27,10 @@ using mindspore::schema::PrimitiveType_Stack; namespace mindspore::kernel { int StackCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } StackParameter *param = reinterpret_cast(opParameter); auto input0_shape = inputs_[0]->shape(); axis_ = param->axis_ < 0 ? param->axis_ + input0_shape.size() : param->axis_; @@ -67,6 +71,11 @@ int StackCPUKernel::Init() { } int StackCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } size_t inputs_num = inputs_.size(); auto input0_shape = inputs_[0]->shape(); auto *output_data = reinterpret_cast(outputs_[0]->Data()); @@ -87,13 +96,13 @@ int StackCPUKernel::Run() { kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *op_parameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (op_parameter == nullptr) { MS_LOG(ERROR) << "Input op_parameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Stack); - auto *kernel = new (std::nothrow) StackCPUKernel(op_parameter, inputs, outputs); + auto *kernel = new (std::nothrow) StackCPUKernel(op_parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new StackCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h index c1d76ca193..179bc21392 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.h @@ -25,8 +25,9 @@ namespace mindspore::kernel { class StackCPUKernel : public LiteKernel { public: StackCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs), + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), convert_functions_(inputs_.size(), nullptr), packed_inputs_(inputs_.size(), nullptr) {} @@ -51,4 +52,3 @@ class StackCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_STACK_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc index 62a1502774..9b6b6336fb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc @@ -25,6 +25,10 @@ using mindspore::schema::PrimitiveType_Tile; namespace mindspore::kernel { int TileCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto tile_parameter_ = reinterpret_cast(opParameter); for (int i = 0; i < tile_parameter_->in_dim_; ++i) { tile_parameter_->in_shape_[i] = inputs_[0]->shape()[i]; @@ -46,6 +50,11 @@ void TileCPUKernel::ComputeStrides(int *shape, int *strides, int ndim) { int TileCPUKernel::ReSize() { return RET_OK; } int TileCPUKernel::Run() { + auto ret = 
Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input_addr = reinterpret_cast(inputs_.at(0)->Data()); auto output_addr = reinterpret_cast(outputs_.at(0)->Data()); @@ -55,13 +64,14 @@ int TileCPUKernel::Run() { kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, - const lite::Context *ctx, const KernelKey &desc) { + const lite::Context *ctx, const KernelKey &desc, + const lite::Primitive *primitive) { if (parameter == nullptr || ctx == nullptr) { MS_LOG(ERROR) << "parameter or ctx is nullptr"; return nullptr; } MS_ASSERT(desc.type == PrimitiveType_Tile); - auto *kernel = new (std::nothrow) TileCPUKernel(parameter, inputs, outputs); + auto *kernel = new (std::nothrow) TileCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; return nullptr; @@ -79,4 +89,3 @@ kernel::LiteKernel *CpuTileFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~TileCPUKernel() override {} int Init() override; @@ -38,4 +39,3 @@ class TileCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TILE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc index 954ec041bc..ece7442b49 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc @@ -25,6 +25,10 @@ using mindspore::schema::PrimitiveType_TopK; namespace mindspore::kernel { int TopKCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } TopkParameter *parameter = reinterpret_cast(opParameter); lite::tensor::Tensor *input = inputs_.at(0); parameter->last_dim_size_ = input->shape()[input->shape().size() - 1]; @@ -44,6 +48,11 @@ int TopKCPUKernel::Init() { int TopKCPUKernel::ReSize() { return RET_OK; } int TopKCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input_data = reinterpret_cast(inputs_.at(0)->Data()); auto output_data = reinterpret_cast(outputs_.at(0)->Data()); auto output_index = reinterpret_cast(outputs_.at(1)->Data()); @@ -54,9 +63,11 @@ int TopKCPUKernel::Run() { kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, - const lite::Context *ctx, const KernelKey &desc) { + const lite::Context *ctx, const KernelKey &desc, + const lite::Primitive *primitive) { MS_ASSERT(parameter != nullptr); - auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs); + MS_ASSERT(desc.type == PrimitiveType_Tile); + auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new TopKCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h index 9ed7af2f39..f07d2847fc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.h @@ -24,8 +24,9 @@ namespace mindspore::kernel { class 
TopKCPUKernel : public LiteKernel { public: explicit TopKCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~TopKCPUKernel() override { TopkParameter *parameter = reinterpret_cast(opParameter); free(parameter->topk_node_list_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc index a6473114fa..89485726be 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc @@ -28,10 +28,14 @@ using mindspore::schema::PrimitiveType_Transpose; namespace mindspore::kernel { namespace { - constexpr int kTransposeInputNum = 1; - constexpr int kTransposeOutputNum = 1; +constexpr int kTransposeInputNum = 1; +constexpr int kTransposeOutputNum = 1; } // namespace int TransposeCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto &inTensor = inputs_.front(); auto &outTensor = outputs_.front(); auto param = reinterpret_cast(opParameter); @@ -50,6 +54,11 @@ int TransposeCPUKernel::Init() { int TransposeCPUKernel::ReSize() { return RET_OK; } int TransposeCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } MS_ASSERT(inputs_.size() == TransposeInputNum); MS_ASSERT(outputs_.size() == TransposeOutputNum); auto &inTensor = inputs_.front(); @@ -65,21 +74,20 @@ int TransposeCPUKernel::Run() { auto *input_shape = &in_shape.front(); auto *output_shape = &out_shape.front(); - auto ret = - DoTranspose(in_data, out_data, input_shape, output_shape, reinterpret_cast(opParameter)); + ret = DoTranspose(in_data, out_data, input_shape, output_shape, reinterpret_cast(opParameter)); return ret; } kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(desc.type == schema::PrimitiveType_Transpose); if (opParameter == nullptr) { MS_LOG(ERROR) << "desc type is not Transpose"; return nullptr; } - auto *kernel = new (std::nothrow) TransposeCPUKernel(opParameter, inputs, outputs); + auto *kernel = new (std::nothrow) TransposeCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "New kernel fails."; return nullptr; @@ -97,4 +105,3 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(param, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(param, inputs, outputs, ctx, primitive) {} ~TransposeCPUKernel() override = default; int Init() override; @@ -41,4 +41,3 @@ class TransposeCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_TRANSPOSE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc index dc93554fa7..37b28f59d1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unique.cc @@ -28,6 +28,11 @@ int 
UniqueCPUKernel::Init() { return RET_OK; } int UniqueCPUKernel::ReSize() { return RET_OK; } int UniqueCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input = reinterpret_cast(inputs_.at(0)->Data()); auto output0 = reinterpret_cast(outputs_.at(0)->Data()); auto output1 = reinterpret_cast(outputs_.at(1)->Data()); @@ -43,11 +48,11 @@ int UniqueCPUKernel::Run() { kernel::LiteKernel *CpuUniqueFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *parameter, const lite::Context *ctx, - const KernelKey &desc) { + OpParameter *parameter, const lite::Context *ctx, const KernelKey &desc, + const lite::Primitive *primitive) { MS_ASSERT(parameter); MS_ASSERT(desc.type == PrimitiveType_Unique); - auto *kernel = new (std::nothrow) UniqueCPUKernel(parameter, inputs, outputs); + auto *kernel = new (std::nothrow) UniqueCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; return nullptr; @@ -64,4 +69,3 @@ kernel::LiteKernel *CpuUniqueFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~UniqueCPUKernel() = default; int Init() override; @@ -37,4 +38,3 @@ class UniqueCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNIQUE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc index 2555605caa..cd1034fcc0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unsqueeze.cc @@ -28,6 +28,10 @@ using mindspore::schema::PrimitiveType_Unsqueeze; namespace mindspore::kernel { int UnsqueezeCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } int ret = ReSize(); return ret; } @@ -64,9 +68,14 @@ int UnsqueezeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int UnsqueezeCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } in_ptr_ = reinterpret_cast(inputs_.at(0)->Data()); out_ptr_ = reinterpret_cast(outputs_.at(0)->Data()); - int ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_); + ret = LiteBackendParallelLaunch(UnsqueezeRun, this, thread_sz_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]"; return ret; @@ -75,12 +84,12 @@ int UnsqueezeCPUKernel::Run() { } kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze); - auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) UnsqueezeCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!"; return nullptr; 
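The hunks above and below repeat one control-flow pattern across every kernel this patch touches: Init() bails out early with SetNeedReInit() when shape inference was interrupted at graph-build time, and Run() now begins with Prepare(), which replays the deferred InferShape()/Init() once real input shapes are available. A minimal self-contained C++ sketch of that flow, with the framework stubbed out (only the context_ flag names, SetNeedReInit(), and Prepare() are taken from the diff; everything else is illustrative):

#include <iostream>

// Stand-ins for the real MindSpore Lite pieces; this is a sketch, not the
// actual LiteKernel implementation.
constexpr int RET_OK = 0;
constexpr int RET_ERROR = 1;

struct Context {
  bool infer_shape_interrupt_ = true;  // shape inference gave up at load time
  bool running_ = false;               // becomes true once the graph executes
};

class KernelSketch {
 public:
  explicit KernelSketch(const Context *ctx) : context_(ctx) {}

  int Init() {
    // The guard added at the top of every Init() in this patch: defer all
    // shape-dependent setup while shapes are still unknown.
    if (context_->infer_shape_interrupt_ && !context_->running_) {
      SetNeedReInit();
      return RET_OK;
    }
    initialized_ = true;  // normal shape-dependent initialization goes here
    return RET_OK;
  }

  int Run() {
    // The counterpart added at the top of every Run(): Prepare() re-runs
    // Init() for kernels that deferred it.
    auto ret = Prepare();
    if (ret != RET_OK) {
      std::cerr << "Prepare failed." << std::endl;
      return RET_ERROR;
    }
    return RET_OK;  // launch the actual kernel body here
  }

 private:
  void SetNeedReInit() { need_reinit_ = true; }

  int Prepare() {
    if (need_reinit_) {  // shapes are known now, so Init() can complete
      need_reinit_ = false;
      return Init();
    }
    return RET_OK;
  }

  const Context *context_;
  bool need_reinit_ = false;
  bool initialized_ = false;
};

int main() {
  Context ctx;
  KernelSketch kernel(&ctx);
  kernel.Init();        // defers: infer_shape_interrupt_ && !running_
  ctx.running_ = true;  // the graph starts executing with concrete shapes
  return kernel.Run();  // Prepare() re-runs Init() before the kernel body
}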
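The other mechanical change, visible in every *KernelCreator in this patch, is the widened factory signature: creators and kernel constructors gain a trailing const lite::Primitive *primitive, which is forwarded to the LiteKernel base so shape inference can be re-run per operator. A rough sketch of the shape of that change, with hypothetical stand-in types rather than the real registry:

#include <new>
#include <vector>

// Hypothetical stand-ins; the real types live in mindspore::lite and
// mindspore::kernel.
struct OpParameter {};
struct Tensor {};
struct Context {};
struct Primitive {};
struct KernelKey {};

class Kernel {
 public:
  // Constructors across the patch gained the ctx/primitive pair and forward
  // both to the LiteKernel base class (elided here).
  Kernel(OpParameter *, const std::vector<Tensor *> &, const std::vector<Tensor *> &,
         const Context *, const Primitive *) {}
};

// The old creator signature ended at `desc`; the patch appends `primitive`
// and passes it through to the kernel constructor.
Kernel *CreateKernel(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                     OpParameter *op_parameter, const Context *ctx, const KernelKey &desc,
                     const Primitive *primitive) {
  if (op_parameter == nullptr) {
    return nullptr;  // mirrors the nullptr checks in the creators below
  }
  return new (std::nothrow) Kernel(op_parameter, inputs, outputs, ctx, primitive);
}

int main() {
  std::vector<Tensor *> ins, outs;
  OpParameter param;
  Context ctx;
  Primitive prim;
  KernelKey desc;
  Kernel *kernel = CreateKernel(ins, outs, &param, &ctx, desc, &prim);
  delete kernel;
  return 0;
}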
@@ -97,4 +106,3 @@ kernel::LiteKernel *CpuUnsqueezeFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {} ~UnsqueezeCPUKernel() = default; int Init() override; @@ -48,4 +49,3 @@ class UnsqueezeCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSQUEEZE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc index 8facbc7a9d..1d18036953 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.cc @@ -24,6 +24,10 @@ using mindspore::schema::PrimitiveType_Unstack; namespace mindspore::kernel { int UnstackCPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto input = inputs_.at(0); MS_ASSERT(input != nullptr); size_t shape_size = input->shape().size(); @@ -56,6 +60,11 @@ int UnstackCPUKernel::Init() { int UnstackCPUKernel::ReSize() { return RET_OK; } int UnstackCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } float *input = reinterpret_cast(inputs_.at(0)->Data()); size_t out_num = outputs_.size(); for (size_t i = 0; i < out_num; i++) { @@ -67,11 +76,11 @@ int UnstackCPUKernel::Run() { kernel::LiteKernel *CpuUnstackFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *parameter, const lite::Context *ctx, - const KernelKey &desc) { + OpParameter *parameter, const lite::Context *ctx, const KernelKey &desc, + const lite::Primitive *primitive) { MS_ASSERT(parameter != nullptr); MS_ASSERT(desc.type == PrimitiveType_Unstack); - auto *kernel = new (std::nothrow) UnstackCPUKernel(parameter, inputs, outputs); + auto *kernel = new (std::nothrow) UnstackCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h index e652ad6adf..bc7181365f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack.h @@ -24,11 +24,10 @@ namespace mindspore::kernel { class UnstackCPUKernel : public LiteKernel { public: UnstackCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} - ~UnstackCPUKernel() { - free(output_addr_array_); - } + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~UnstackCPUKernel() { free(output_addr_array_); } int Init() override; int ReSize() override; @@ -40,4 +39,3 @@ class UnstackCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_UNSTACK_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc index f2bf03fc46..ad59ff334d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc 
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/where.cc @@ -48,6 +48,11 @@ int WhereRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { return RET_OK; } int WhereCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input = inputs_.at(0); auto input1 = inputs_.at(1); auto input2 = inputs_.at(2); @@ -74,7 +79,7 @@ int WhereCPUKernel::Run() { MS_LOG(ERROR) << "Error, inputs' length are zero !!!"; return RET_ERROR; } - auto ret = LiteBackendParallelLaunch(WhereRun, this, where_param_->thread_num_); + ret = LiteBackendParallelLaunch(WhereRun, this, where_param_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "WhereDwRun error: error_code[" << ret << "]"; return RET_ERROR; @@ -85,13 +90,13 @@ int WhereCPUKernel::Run() { kernel::LiteKernel *CpuWhereFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "input opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_Where); - auto *kernel = new (std::nothrow) WhereCPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) WhereCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new WhereCPUKernel fail!"; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/where.h b/mindspore/lite/src/runtime/kernel/arm/fp32/where.h index ad9c73a9fa..d8bb43de25 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/where.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/where.h @@ -29,8 +29,9 @@ namespace mindspore::kernel { class WhereCPUKernel : public LiteKernel { public: WhereCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { where_param_ = reinterpret_cast(opParameter); } ~WhereCPUKernel() = default; @@ -53,4 +54,3 @@ class WhereCPUKernel : public LiteKernel { }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_WHERE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc index 6fcd25dc48..fc897e43c5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/zeroslike.cc @@ -33,6 +33,11 @@ constexpr int kOutputNum = 1; int ZerosLikeCPUKernel::Init() { return RET_OK; } int ZerosLikeCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input = inputs_.at(0); auto input_data = reinterpret_cast(input->Data()); auto output_data = reinterpret_cast(outputs_.at(0)->Data()); @@ -43,13 +48,13 @@ int ZerosLikeCPUKernel::Run() { kernel::LiteKernel *CpuZerosLikeFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (opParameter == nullptr) { MS_LOG(ERROR) << "input 
opParameter is nullptr!"; return nullptr; } MS_ASSERT(desc.type == schema::PrimitiveType_ZerosLike); - auto *kernel = new (std::nothrow) ZerosLikeCPUKernel(opParameter, inputs, outputs); + auto *kernel = new (std::nothrow) ZerosLikeCPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new ZerosLikeCPUKernel fail!"; return nullptr; @@ -66,4 +71,3 @@ kernel::LiteKernel *CpuZerosLikeFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} ~ZerosLikeCPUKernel() = default; @@ -35,4 +36,3 @@ class ZerosLikeCPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ZEROSLIKE_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc b/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc index 3b0d7b94a2..3d9ba79303 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/activation.cc @@ -33,7 +33,7 @@ namespace mindspore::kernel { kernel::LiteKernel *CpuActivationInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, const lite::Context *ctx, - const KernelKey &desc) { + const KernelKey &desc, const lite::Primitive *primitive) { if (parameter == nullptr) { MS_LOG(ERROR) << "parameter is nullptr"; return nullptr; @@ -43,16 +43,16 @@ kernel::LiteKernel *CpuActivationInt8KernelCreator(const std::vector(type)) { case schema::ActivationType_RELU: - kernel = new (std::nothrow) ReluInt8CPUKernel(parameter, inputs, outputs, ctx); + kernel = new (std::nothrow) ReluInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); break; case schema::ActivationType_RELU6: - kernel = new (std::nothrow) Relu6Int8CPUKernel(parameter, inputs, outputs, ctx); + kernel = new (std::nothrow) Relu6Int8CPUKernel(parameter, inputs, outputs, ctx, primitive); break; case schema::ActivationType_HSWISH: - kernel = new (std::nothrow) HswishInt8CPUKernel(parameter, inputs, outputs, ctx); + kernel = new (std::nothrow) HswishInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); break; case schema::ActivationType_SIGMOID: - kernel = new (std::nothrow) SigmoidInt8CPUKernel(parameter, inputs, outputs, ctx); + kernel = new (std::nothrow) SigmoidInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); break; default: break; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc index 18bb5e80c6..3a811f9e57 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc @@ -75,6 +75,11 @@ int QuantizedAddCPUKernel::Init() { int QuantizedAddCPUKernel::ReSize() { return 0; } int QuantizedAddCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } input0_data_ = static_cast(inputs_.at(0)->Data()); input1_data_ = static_cast(inputs_.at(1)->Data()); output_data_ = static_cast(outputs_.at(0)->Data()); @@ -96,13 +101,13 @@ int QuantizedAddCPUKernel::Run() { TileDimensionsUint8(static_cast(inputs_.at(0)->Data()), static_cast(inputs_.at(1)->Data()), reinterpret_cast(input0_data_), reinterpret_cast(input1_data_), &tile_para); - auto ret = LiteBackendParallelLaunch(AddInt8Run, this, thread_count_); + 
ret = LiteBackendParallelLaunch(AddInt8Run, this, thread_count_); ctx_->allocator->Free(input0_data_); ctx_->allocator->Free(input1_data_); return ret; } - auto ret = LiteBackendParallelLaunch(AddInt8Run, this, thread_count_); + ret = LiteBackendParallelLaunch(AddInt8Run, this, thread_count_); return ret; } @@ -124,13 +129,14 @@ int QuantizedAddCPUKernel::DoExecute(int tId) { kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, - const lite::Context *ctx, const KernelKey &desc) { + const lite::Context *ctx, const KernelKey &desc, + const lite::Primitive *primitive) { if (parameter == nullptr || ctx == nullptr) { MS_LOG(ERROR) << "parameter or ctx is nullptr"; return nullptr; } MS_ASSERT(desc.type == PrimitiveType_Add); - auto *kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h index b57aab2279..5a1da51340 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h @@ -25,8 +25,9 @@ namespace mindspore::kernel { class QuantizedAddCPUKernel : public LiteKernel { public: explicit QuantizedAddCPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx_->thread_num_) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {} ~QuantizedAddCPUKernel() override {} int Init() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc index 03cae3aa86..36216531b5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc @@ -20,8 +20,8 @@ #include "src/runtime/kernel/arm/nnacl/int8/arg_min_max_int8.h" #include "include/errorcode.h" -using mindspore::lite::RET_OK; using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; namespace mindspore::kernel { int ArgMinMaxInt8CPUKernel::Init() { @@ -44,6 +44,11 @@ int ArgMinMaxInt8CPUKernel::Init() { } int ArgMinMaxInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input = inputs_.at(0); const int8_t *input_data = reinterpret_cast(inputs_.at(0)->Data()); @@ -70,6 +75,7 @@ int ArgMinMaxInt8CPUKernel::Run() { ArgMinMaxDim3(input_data, output_data, in_shape, param, &in_quant_arg_, &out_quant_arg_); break; } + FreeTmpMemory(); return RET_OK; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h index 919acd2037..1a7e331b5f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h @@ -24,8 +24,9 @@ namespace mindspore::kernel { class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel { public: ArgMinMaxInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const 
std::vector &outputs, const lite::Context *ctx) - : ArgMinMaxBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : ArgMinMaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~ArgMinMaxInt8CPUKernel() = default; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc index 670894964c..db2d911e84 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc @@ -167,12 +167,12 @@ int ArithmeticInt8CPUKernel::Run() { kernel::LiteKernel *CpuArithmeticInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *parameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { if (parameter == nullptr) { MS_LOG(ERROR) << "Input parameter is null!"; return nullptr; } - auto kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx); + auto kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "Create ArithmeticInt8CPUKernel failed, name: " << parameter->name_; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h index 56ebcd7e0b..dfc5a030a7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h @@ -27,8 +27,9 @@ class ArithmeticInt8CPUKernel : public LiteKernel { public: ArithmeticInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_), context_(ctx) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_), context_(ctx) {} ~ArithmeticInt8CPUKernel(); int Init() override; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc index 1688f27bb4..c94689722c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc @@ -28,6 +28,10 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { int ArithmeticSelfInt8CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } int ret = ReSize(); auto *input_tensor = inputs_.at(kInputIndex); auto in_quant_args = input_tensor->GetQuantParams(); @@ -93,11 +97,16 @@ int ArithmeticSelfInt8CPUKernel::DoArithmeticSelf(int task_id) { } int ArithmeticSelfInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input_tensor = inputs_.at(0); auto out_tensor = outputs_.at(0); in_ptr_ = reinterpret_cast(input_tensor->Data()); out_ptr_ = reinterpret_cast(out_tensor->Data()); - int ret = LiteBackendParallelLaunch(ArithmeticSelfInt8Runs, this, thread_sz_count_); + ret = LiteBackendParallelLaunch(ArithmeticSelfInt8Runs, this, thread_sz_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret 
<< "]"; return ret; @@ -108,13 +117,14 @@ int ArithmeticSelfInt8CPUKernel::Run() { kernel::LiteKernel *CpuArithmeticSelfInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, + const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); if (opParameter == nullptr) { MS_LOG(ERROR) << "Creator failed, opParameter is nullptr!"; return nullptr; } - auto *kernel = new (std::nothrow) ArithmeticSelfInt8CPUKernel(opParameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) ArithmeticSelfInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); MS_ASSERT(kernel != nullptr); auto ret = kernel->Init(); if (ret != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h index 7dfe7d14ad..a507313a40 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.h @@ -24,10 +24,7 @@ #include "schema/model_generated.h" #include "include/context.h" - using mindspore::lite::Context; -using mindspore::schema::PrimitiveType_Round; -using mindspore::schema::PrimitiveType_Floor; using mindspore::schema::PrimitiveType_Ceil; using mindspore::schema::PrimitiveType_Abs; using mindspore::schema::PrimitiveType_Sin; @@ -37,6 +34,8 @@ using mindspore::schema::PrimitiveType_Sqrt; using mindspore::schema::PrimitiveType_Rsqrt; using mindspore::schema::PrimitiveType_Square; using mindspore::schema::PrimitiveType_LogicalNot; +using mindspore::schema::PrimitiveType_Floor; +using mindspore::schema::PrimitiveType_Round; namespace mindspore::kernel { class ArithmeticSelfInt8CPUKernel : public LiteKernel { @@ -44,8 +43,9 @@ class ArithmeticSelfInt8CPUKernel : public LiteKernel { public: explicit ArithmeticSelfInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) { + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) { switch (parameter->type_) { case PrimitiveType_Round: arithmeticSelf_run_ = ElementRound; @@ -106,4 +106,3 @@ class ArithmeticSelfInt8CPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_SELF_INT8_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc index bdf49bf14e..db0b7df701 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc @@ -42,6 +42,11 @@ int BatchToSpaceInt8CPUKernel::Init() { } int BatchToSpaceInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input = inputs_[0]; auto output = outputs_[0]; const int8_t *input_data = reinterpret_cast(input->Data()); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h index 17f30f004f..7755cbd09c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h +++ 
b/mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h @@ -23,8 +23,9 @@ namespace mindspore::kernel { class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel { public: BatchToSpaceInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : BatchToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : BatchToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~BatchToSpaceInt8CPUKernel() = default; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc index f80daade18..1c6afa9d4b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.cc @@ -26,6 +26,10 @@ using mindspore::schema::PrimitiveType_BiasAdd; namespace mindspore::kernel { int BiasAddInt8CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto bias_param = reinterpret_cast(opParameter); auto dims = inputs_[0]->shape(); bias_param->ndim_ = dims.size(); @@ -41,6 +45,11 @@ int BiasAddInt8CPUKernel::Init() { int BiasAddInt8CPUKernel::ReSize() { return NNACL_OK; } int BiasAddInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto in = reinterpret_cast(inputs_.at(0)->Data()); auto bias = reinterpret_cast(inputs_.at(1)->Data()); auto out = reinterpret_cast(outputs_.at(0)->Data()); @@ -59,14 +68,14 @@ int BiasAddInt8CPUKernel::Run() { kernel::LiteKernel *CpuBiasAddInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *parameter, const lite::Context *ctx, - const KernelKey &desc) { + OpParameter *parameter, const lite::Context *ctx, const KernelKey &desc, + const lite::Primitive *primitive) { if (parameter == nullptr || ctx == nullptr) { MS_LOG(ERROR) << "parameter or context is nullptr"; return nullptr; } MS_ASSERT(desc.type == PrimitiveType_BiasAdd); - auto *kernel = new (std::nothrow) BiasAddInt8CPUKernel(parameter, inputs, outputs, ctx); + auto *kernel = new (std::nothrow) BiasAddInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h index 3ced965318..c8c7717be4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/bias_add_int8.h @@ -25,8 +25,9 @@ namespace mindspore::kernel { class BiasAddInt8CPUKernel : public LiteKernel { public: BiasAddInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::Context *ctx) - : LiteKernel(parameter, inputs, outputs), ctx_(ctx) {} + const std::vector &outputs, const lite::Context *ctx, + const lite::Primitive *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx) {} ~BiasAddInt8CPUKernel() = default; int Init() override; @@ -39,4 +40,3 @@ class BiasAddInt8CPUKernel : public LiteKernel { } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_BAIS_ADD_INT8_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc 
b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc index ac96e28c9c..5ccf9907d4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.cc @@ -27,6 +27,10 @@ using mindspore::lite::RET_OK; namespace mindspore::kernel { int ConcatInt8CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } ConcatBaseCPUKernel::Init(); auto input_num = inputs_.size(); concat_param_->input_num_ = input_num; @@ -75,6 +79,12 @@ int ConcatInt8CPUKernel::Init() { int ConcatInt8CPUKernel::ReSize() { return 0; } int ConcatInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } + auto input_num = concat_param_->input_num_; count_unit_ = thread_count_ > 1 ? UP_DIV(before_axis_size, thread_count_) : before_axis_size; concat_param_->count_unit_ = count_unit_; @@ -88,7 +98,7 @@ int ConcatInt8CPUKernel::Run() { } output_data_ = reinterpret_cast(outputs_.at(0)->Data()); - auto ret = LiteBackendParallelLaunch(ConcatInt8Run, this, thread_count_); + ret = LiteBackendParallelLaunch(ConcatInt8Run, this, thread_count_); ctx_->allocator->Free(input_data_); ctx_->allocator->Free(concat_param_->input_shapes_); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h index 01846f6da9..1d09049ffe 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/concat_int8.h @@ -29,8 +29,9 @@ namespace mindspore::kernel { class ConcatInt8CPUKernel : public ConcatBaseCPUKernel { public: ConcatInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConcatBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~ConcatInt8CPUKernel() override {} int Init() override; @@ -49,4 +50,3 @@ int ConcatInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONCAT_INT8_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc index e3c9471514..530e1275c3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.cc @@ -161,6 +161,10 @@ void Convolution3x3Int8CPUKernel::ConfigInputOutput() { } int Convolution3x3Int8CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init failed."; @@ -232,6 +236,11 @@ int Convolution3x3Int8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) } int Convolution3x3Int8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input_addr = reinterpret_cast(inputs_.at(kInputIndex)->Data()); PackInputToC8Int8(input_addr, input_data_, conv_param_); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h index ef0a8a5560..b5dbc12449 100644 --- 
a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_3x3_int8.h @@ -27,8 +27,9 @@ namespace mindspore::kernel { class Convolution3x3Int8CPUKernel : public ConvolutionBaseCPUKernel { public: Convolution3x3Int8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~Convolution3x3Int8CPUKernel() override; int Init() override; @@ -51,4 +52,3 @@ void ProcessFilterUint8(int8_t *origin_weight, int16_t *dst_weight, ConvParamete } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_CONVOLUTION_3X3_INT8_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc index 87fa4e1690..a9208508d3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc @@ -81,6 +81,10 @@ int ConvolutionDepthwiseInt8CPUKernel::InitBuffer() { } int ConvolutionDepthwiseInt8CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } // conv base init ConvolutionBaseCPUKernel::Init(); @@ -145,6 +149,11 @@ int ConvDwInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int ConvolutionDepthwiseInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } if (conv_param_->input_channel_ != conv_param_->output_channel_) { MS_LOG(ERROR) << "Only support input channel equals output channel."; return RET_ERROR; @@ -160,7 +169,7 @@ int ConvolutionDepthwiseInt8CPUKernel::Run() { packed_output_ = output_addr; } - auto ret = LiteBackendParallelLaunch(ConvDwInt8Run, this, conv_param_->thread_num_); + ret = LiteBackendParallelLaunch(ConvDwInt8Run, this, conv_param_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvDwInt8Run error: error_code[" << ret << "]"; return RET_ERROR; @@ -176,10 +185,11 @@ int ConvolutionDepthwiseInt8CPUKernel::Run() { kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); - auto kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx); + auto kernel = + new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h index e0ce121dad..a6e068d90d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.h @@ -26,8 +26,9 @@ namespace mindspore::kernel { class ConvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel { public: 
ConvolutionDepthwiseInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~ConvolutionDepthwiseInt8CPUKernel() override { delete sliding; free(packed_weight_); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index 00415cc542..8aac86c288 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -269,6 +269,10 @@ void ConvolutionInt8CPUKernel::ConfigInputOutput() { } int ConvolutionInt8CPUKernel::Init() { + if (context_->infer_shape_interrupt_ && !context_->running_) { + SetNeedReInit(); + return RET_OK; + } auto ret = ConvolutionBaseCPUKernel::Init(); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvolutionBase init failed."; @@ -379,6 +383,11 @@ int ConvolutionInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int ConvolutionInt8CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare failed."; + return RET_ERROR; + } auto input_tensor = inputs_.at(kInputIndex); auto ori_input_data = input_tensor->Data(); int in_batch = conv_param_->input_batch_; @@ -398,7 +407,7 @@ int ConvolutionInt8CPUKernel::Run() { kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, - const kernel::KernelKey &desc) { + const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); auto conv_param = reinterpret_cast(opParameter); @@ -410,9 +419,9 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vectordilation_w_; kernel::LiteKernel *kernel; if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { - kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx); + kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); } else { - kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx); + kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); } if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h index 3250cfa112..bc3da9ab01 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.h @@ -27,8 +27,9 @@ namespace mindspore::kernel { class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel { public: ConvolutionInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const Context *ctx) - : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {} + const std::vector &outputs, const Context *ctx, + const lite::Primitive *primitive) + : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~ConvolutionInt8CPUKernel() override { if (packed_weight_ != nullptr) { free(packed_weight_); diff --git 
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc
index 4e73001954..4023309b8d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc
@@ -54,7 +54,12 @@ int CropInt8CPUKernel::Init() {
 int CropInt8CPUKernel::ReSize() { return 0; }
 
 int CropInt8CPUKernel::Run() {
-  auto ret = LiteBackendParallelLaunch(CropInt8Run, this, thread_count_);
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
+  ret = LiteBackendParallelLaunch(CropInt8Run, this, thread_count_);
   return ret;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h
index b0ff4af359..598dde3490 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class CropInt8CPUKernel : public CropBaseCPUKernel {
  public:
   CropInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                    const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : CropBaseCPUKernel(parameter, inputs, outputs, ctx) {
+                    const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                    const lite::Primitive *primitive)
+      : CropBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {
     crop_para_ = reinterpret_cast<CropParameter *>(opParameter);
     crop_para_->thread_count_ = opParameter->thread_num_;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc
index 6de3c47211..7280a05478 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc
@@ -115,6 +115,10 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitBuffer() {
 }
 
 int DeconvolutionDepthwiseInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   sliding = new SlidingWindowParam;
   InitSlideParam();
 
@@ -174,6 +178,11 @@ int DeconvDwInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int DeconvolutionDepthwiseInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   if (conv_param_->input_channel_ != conv_param_->output_channel_) {
     MS_LOG(ERROR) << "Only support input channel equals output channel.";
     return RET_ERROR;
@@ -190,7 +199,7 @@ int DeconvolutionDepthwiseInt8CPUKernel::Run() {
     packed_output_ = output_addr;
   }
 
-  auto ret = LiteBackendParallelLaunch(DeconvDwInt8Run, this, conv_param_->thread_num_);
+  ret = LiteBackendParallelLaunch(DeconvDwInt8Run, this, conv_param_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "DeconvDwInt8Run error: error_code[" << ret << "]";
     return RET_ERROR;
@@ -206,10 +215,11 @@ int DeconvolutionDepthwiseInt8CPUKernel::Run() {
 kernel::LiteKernel *CpuDeconvDwInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                  const std::vector<lite::tensor::Tensor *> &outputs,
                                                  OpParameter *opParameter, const lite::Context *ctx,
-                                                 const kernel::KernelKey &desc) {
+                                                 const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
-  auto kernel = new (std::nothrow) kernel::DeconvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto kernel =
+    new (std::nothrow) kernel::DeconvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h
index 3e56264aae..d7b76438f3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class DeconvolutionDepthwiseInt8CPUKernel : public ConvolutionBaseCPUKernel {
  public:
   DeconvolutionDepthwiseInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                                      const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~DeconvolutionDepthwiseInt8CPUKernel() override {
     delete sliding;
     free(packed_weight_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc
index f21c5bc1c4..13d078d5e0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc
@@ -115,6 +115,10 @@ int DeConvInt8CPUKernel::InitData() {
 }
 
 int DeConvInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   ConvolutionBaseCPUKernel::Init();
   int error_code = ConvolutionBaseCPUKernel::SetQuantParam();
   if (error_code != RET_OK) {
@@ -196,6 +200,11 @@ int DeConvInt8CPUKernel::DoPostFunc(int task_id) {
 }
 
 int DeConvInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   int8_t *src_in = reinterpret_cast<int8_t *>(inputs_[0]->Data());
   int8_t *src_out = reinterpret_cast<int8_t *>(outputs_[0]->Data());
 
@@ -222,10 +231,10 @@ int DeConvInt8CPUKernel::Run() {
 kernel::LiteKernel *CpuDeConvInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
-  auto kernel = new (std::nothrow) kernel::DeConvInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto kernel = new (std::nothrow) kernel::DeConvInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h
index b63633f9a5..ce19d61e9c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h
@@ -31,8 +31,9 @@ namespace mindspore::kernel {
 class DeConvInt8CPUKernel : public ConvolutionBaseCPUKernel {
  public:
   DeConvInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                      const lite::Primitive *primitive)
+      : ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~DeConvInt8CPUKernel() override;
 
   int ReSize() override;
@@ -60,4 +61,3 @@ class DeConvInt8CPUKernel : public ConvolutionBaseCPUKernel {
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc
index b70624aa18..75af0bf7d7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc
@@ -21,8 +21,8 @@
 #include "src/runtime/kernel/arm/nnacl/int8/depth_to_space_int8.h"
 #include "include/errorcode.h"
 
-using mindspore::lite::RET_OK;
 using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 int DepthToSpaceInt8CPUKernel::Init() {
@@ -46,6 +46,11 @@ int DepthToSpaceInt8CPUKernel::Init() {
 }
 
 int DepthToSpaceInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto input = inputs_[0];
   auto output = outputs_[0];
   const int8_t *input_data = reinterpret_cast<const int8_t *>(input->Data());
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h
index 427b6d5eb0..4b3520950e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h
@@ -23,8 +23,9 @@ namespace mindspore::kernel {
 class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel {
  public:
   DepthToSpaceInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                            const lite::Primitive *primitive)
+      : DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
 
   ~DepthToSpaceInt8CPUKernel() = default;
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
index a14438c815..aa6ebd5edf 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
@@ -25,6 +25,10 @@ using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 int FullconnectionInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   fc_param_->row_ = (inputs_[0]->shape())[0];
   fc_param_->col_ = (inputs_[1]->shape())[0];
   fc_param_->deep_ = (inputs_[1]->shape())[1];
@@ -113,6 +117,11 @@ int FcInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int FullconnectionInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto a_ptr = reinterpret_cast<int8_t *>(inputs_[0]->Data());
   auto output_ptr = reinterpret_cast<int8_t *>(outputs_[0]->Data());
   auto &p = quant_params_;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
index 30aef2d45f..66010c584d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
@@ -28,8 +28,9 @@ namespace mindspore::kernel {
 class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel {
  public:
   FullconnectionInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                              const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                              const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                              const lite::Primitive *primitive)
+      : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~FullconnectionInt8CPUKernel() override {
     ctx_->allocator->Free(a_c8_ptr_);
     ctx_->allocator->Free(b_r8_ptr_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.cc
index 7b5f9795eb..3c2b688238 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.cc
@@ -89,6 +89,11 @@ int HswishInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int HswishInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   int error_code = LiteBackendParallelLaunch(HswishInt8Run, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "HswishInt8Run function error error_code[" << error_code << "]";
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h
index b907f2c338..7a480f98e7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/hswish_int8.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class HswishInt8CPUKernel : public LiteKernel {
  public:
   HswishInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                      const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {}
+                      const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                      const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
   ~HswishInt8CPUKernel() override = default;
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
index 5015fa0cea..63b23f3fc2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
@@ -31,6 +31,10 @@ MatmulInt8CPUKernel::~MatmulInt8CPUKernel() {
 }
 
 int MatmulInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   int batch = 1;
   auto x_shape = inputs_[0]->shape();
   auto o_shape = outputs_[0]->shape();
@@ -109,6 +113,11 @@ int MatmulInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int MatmulInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto a_ptr = reinterpret_cast<int8_t *>(inputs_[0]->Data());
   auto b_ptr = reinterpret_cast<int8_t *>(inputs_[1]->Data());
   auto c_ptr = reinterpret_cast<int8_t *>(outputs_[0]->Data());
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
index dc8f5ec0b6..d05aeb7c84 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
@@ -28,8 +28,9 @@ namespace mindspore::kernel {
 class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
  public:
   MatmulInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : MatmulBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                      const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                      const lite::Primitive *primitive)
+      : MatmulBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~MatmulInt8CPUKernel() override;
   int Init() override;
   int ReSize() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc
index 8c43954f33..eab7550f60 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.cc
@@ -62,6 +62,11 @@ int MulInt8CPUKernel::Init() {
 int MulInt8CPUKernel::ReSize() { return RET_OK; }
 
 int MulInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   input0_data_ = static_cast<int8_t *>(inputs_.at(0)->Data());
   input1_data_ = static_cast<int8_t *>(inputs_.at(1)->Data());
   output_data_ = static_cast<int8_t *>(outputs_.at(0)->Data());
@@ -81,13 +86,13 @@ int MulInt8CPUKernel::Run() {
     }
     TileDimensionsInt8(static_cast<int8_t *>(inputs_.at(0)->Data()), static_cast<int8_t *>(inputs_.at(1)->Data()),
                        input0_data_, input1_data_, &tile_para);
-    auto ret = LiteBackendParallelLaunch(MulInt8Run, this, thread_count_);
+    ret = LiteBackendParallelLaunch(MulInt8Run, this, thread_count_);
     ctx_->allocator->Free(input0_data_);
     ctx_->allocator->Free(input1_data_);
     return ret;
   }
 
-  auto ret = LiteBackendParallelLaunch(MulInt8Run, this, thread_count_);
+  ret = LiteBackendParallelLaunch(MulInt8Run, this, thread_count_);
   return ret;
 }
 
@@ -112,10 +117,11 @@ int MulInt8CPUKernel::DoExecute(int task_id) {
 
 kernel::LiteKernel *CpuMulInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                             const std::vector<lite::tensor::Tensor *> &outputs,
-                                            OpParameter *opParameter, const lite::Context *ctx, const KernelKey &desc) {
+                                            OpParameter *opParameter, const lite::Context *ctx, const KernelKey &desc,
+                                            const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Mul);
-  auto *kernel = new (std::nothrow) MulInt8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) MulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
 
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel is nullptr.";
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h
index 1acb1b15bc..79f84c4987 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/mul_int8.h
@@ -25,9 +25,10 @@ namespace mindspore::kernel {
 class MulInt8CPUKernel : public LiteKernel {
  public:
   explicit MulInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx_->thread_num_) {}
-  ~MulInt8CPUKernel() override {};
+                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                            const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {}
+  ~MulInt8CPUKernel() override{};
 
   int Init() override;
   int ReSize() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc
index 42b3f7c16a..a5307f71bf 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc
@@ -93,6 +93,10 @@ int PadInt8CPUKernel::ReSize() {
 }
 
 int PadInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   int error_code = InitPadParam();
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "InitPadParam failed. errorcode: " << error_code;
@@ -108,6 +112,11 @@ int PadInt8CPUKernel::Init() {
 }
 
 int PadInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   int8_t *in_data = reinterpret_cast<int8_t *>(inputs_[0]->Data());
   int8_t *out_data = reinterpret_cast<int8_t *>(outputs_[0]->Data());
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h
index 3d9dda740b..4bc7b5f302 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h
@@ -27,8 +27,9 @@ namespace mindspore::kernel {
 class PadInt8CPUKernel : public LiteKernel {
  public:
   explicit PadInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs) {
+                            const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                            const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
     opParameter->thread_num_ = ctx->thread_num_;
     pad_param_ = reinterpret_cast<PadParameter *>(opParameter);
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc
index 9a810d27f8..36b10c7344 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc
@@ -26,6 +26,10 @@ using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 int PoolingInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto ret = PoolingBaseCPUKernel::Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "PoolingBase Init failed.";
@@ -77,6 +81,11 @@ int PoolingInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int PoolingInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   int error_code = LiteBackendParallelLaunch(PoolingInt8Impl, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "poolingInt8 error error_code[" << error_code << "]";
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h
index 367c10f59c..fafdf09c6d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class PoolingInt8CPUKernel : public PoolingBaseCPUKernel {
  public:
   PoolingInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                       const lite::Primitive *primitive)
+      : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~PoolingInt8CPUKernel() { FreeQuantParam(); }
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.cc
index edeba1d79e..2f4d36a6fc 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.cc
@@ -29,6 +29,10 @@ using mindspore::schema::PrimitiveType_Prelu;
 
 namespace mindspore::kernel {
 int PreluInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   PreluBaseCPUKernel::Init();
   auto *input_tensor = inputs_.at(kInputIndex);
   auto in_quant_args = input_tensor->GetQuantParams();
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.h
index c93fb1256b..c50378e63d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class PreluInt8CPUKernel : public PreluBaseCPUKernel {
  public:
   PreluInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : PreluBaseCPUKernel(parameter, inputs, outputs, ctx) {
+                     const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                     const lite::Primitive *primitive)
+      : PreluBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {
     quant_prelu_parm_ = reinterpret_cast<PreluParameter *>(opParameter);
   }
   ~PreluInt8CPUKernel() override {}
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc
index 01161b9f06..fcbffcac79 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc
@@ -28,6 +28,11 @@ using mindspore::schema::ActivationType_RELU;
 
 namespace mindspore::kernel {
 int ReluXInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
+
   lite::tensor::Tensor *input = inputs_.at(0);
   lite::tensor::Tensor *output = outputs_.at(0);
   MS_ASSERT(input);
@@ -69,6 +74,11 @@ int ReluXInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int ReluXInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
   int error_code = LiteBackendParallelLaunch(ReluXInt8Run, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "ReluXInt8Run function error error_code[" << error_code << "]";
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h
index 44cffe4bc0..90a31ed0c8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.h
@@ -26,8 +26,9 @@ namespace mindspore::kernel {
 class ReluXInt8CPUKernel : public LiteKernel {
  public:
   ReluXInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {
+                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                     const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
     type_ = (reinterpret_cast<ActivationParameter *>(parameter))->type_;
   }
   ~ReluXInt8CPUKernel() override = default;
@@ -47,8 +48,9 @@ class ReluXInt8CPUKernel : public LiteKernel {
 class ReluInt8CPUKernel : public ReluXInt8CPUKernel {
  public:
   ReluInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                    const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : ReluXInt8CPUKernel(parameter, inputs, outputs, ctx) {}
+                    const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                    const lite::Primitive *primitive)
+      : ReluXInt8CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ReluInt8CPUKernel() override = default;
 
@@ -63,8 +65,9 @@ class ReluInt8CPUKernel : public ReluXInt8CPUKernel {
 class Relu6Int8CPUKernel : public ReluXInt8CPUKernel {
  public:
   Relu6Int8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : ReluXInt8CPUKernel(parameter, inputs, outputs, ctx) {}
+                     const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                     const lite::Primitive *primitive)
+      : ReluXInt8CPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~Relu6Int8CPUKernel() override = default;
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc
index ad34ff8894..a6b8402253 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.cc
@@ -27,6 +27,10 @@ using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 int ReshapeInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   ReshapeBaseCPUKernel::Init();
   auto *input_tensor = inputs_.at(kInputIndex);
   auto in_quant_args = input_tensor->GetQuantParams();
@@ -47,6 +51,11 @@ int ReshapeInt8CPUKernel::Init() {
 int ReshapeInt8CPUKernel::ReSize() { return 0; }
 
 int ReshapeInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
   MS_ASSERT(inputs_.size() == 1);
   MS_ASSERT(outputs_.size() == 1);
   input_data_ = static_cast<int8_t *>(inputs_.at(kInputIndex)->Data());
@@ -55,7 +64,7 @@ int ReshapeInt8CPUKernel::Run() {
   elements_num_ = inputs_.at(kInputIndex)->ElementsNum();
   count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
-  auto ret = LiteBackendParallelLaunch(ReshapeInt8Run, this, thread_count_);
+  ret = LiteBackendParallelLaunch(ReshapeInt8Run, this, thread_count_);
   return ret;
 }
 
@@ -77,4 +86,3 @@ int ReshapeInt8CPUKernel::DoExecute(int task_id) {
   return lite::RET_OK;
 }
 }  // namespace mindspore::kernel
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h
index c533c42bf2..1112520734 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/reshape_int8.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class ReshapeInt8CPUKernel : public ReshapeBaseCPUKernel {
  public:
   ReshapeInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : ReshapeBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                       const lite::Primitive *primitive)
+      : ReshapeBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~ReshapeInt8CPUKernel() = default;
 
   int Init() override;
@@ -50,4 +51,3 @@ int ReshapeInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata);
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_RESHAPE_INT8_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.cc
index 9adfa5165d..5abf0773b5 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.cc
@@ -89,6 +89,11 @@ int SigmoidInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int SigmoidInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
   int error_code = LiteBackendParallelLaunch(SigmoidInt8Run, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "SigmoidInt8Run function error error_code[" << error_code << "]";
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.h
index da0aa408f5..a2d9db97ad 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/sigmoid_int8.h
@@ -25,8 +25,9 @@ namespace mindspore::kernel {
 class SigmoidInt8CPUKernel : public LiteKernel {
  public:
   SigmoidInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {}
+                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                       const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
   ~SigmoidInt8CPUKernel() override = default;
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc
index 2a3ef9d6ca..2c6591bcb7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.cc
@@ -26,6 +26,10 @@ using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 int SoftmaxInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   SoftmaxBaseCPUKernel::Init();
 
   auto *input_tensor = inputs_.at(kInputIndex);
@@ -95,6 +99,11 @@ int SoftmaxRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int SoftmaxInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return RET_ERROR;
+  }
   auto input_ptr = reinterpret_cast<int8_t *>(inputs_.at(0)->Data());
   int ele_size = softmax_param_->element_size_;
   for (int i = 0; i < ele_size; i++) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h
index d29da4094a..f3cbe6eacf 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/softmax_int8.h
@@ -24,8 +24,9 @@ namespace mindspore::kernel {
 class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel {
  public:
   SoftmaxInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
-      : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                       const lite::Primitive *primitive)
+      : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~SoftmaxInt8CPUKernel() = default;
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.cc
index 80f9cdf61c..2c6f0762ef 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.cc
@@ -28,6 +28,10 @@ using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 int SplitInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   SplitBaseCPUKernel::Init();
   auto in_tensor = inputs_.at(kInputIndex);
   input_ptr_ = reinterpret_cast<int8_t *>(in_tensor->Data());
@@ -81,7 +85,12 @@ int SplitInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int SplitInt8CPUKernel::Run() {
-  int ret = LiteBackendParallelLaunch(SplitInt8Run, this, thread_n_num_);
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
+  ret = LiteBackendParallelLaunch(SplitInt8Run, this, thread_n_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
     return RET_ERROR;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.h
index 4e1ad3f213..ad20fd4adc 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/split_int8.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class SplitInt8CPUKernel : public SplitBaseCPUKernel {
  public:
   SplitInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                     const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : SplitBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                     const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                     const lite::Primitive *primitive)
+      : SplitBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~SplitInt8CPUKernel() = default;
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc
index 7cd4649022..30755bdfc0 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc
@@ -29,6 +29,10 @@ using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
 int SqueezeInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   SqueezeBaseCPUKernel::Init();
   quant_Squeeze_parm_ = new (std::nothrow) SqueezeQuantArg;
   auto input_num = inputs_.size();
@@ -108,6 +112,11 @@ int SqueezeInt8CPUKernel::Init() {
 int SqueezeInt8CPUKernel::ReSize() { return 0; }
 
 int SqueezeInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
   auto input_dim = quant_Squeeze_parm_->input_num_;
   int8_t **inputs_array = reinterpret_cast<int8_t **>(malloc(sizeof(int8_t *) * input_dim));
   for (size_t i = 0; i < input_dim; i++) {
@@ -140,7 +149,7 @@ int SqueezeInt8CPUKernel::Run() {
     free(*(inputs_array + i));
   }
 
-  auto ret = LiteBackendParallelLaunch(SqueezeInt8Run, this, thread_count_);
+  ret = LiteBackendParallelLaunch(SqueezeInt8Run, this, thread_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "RunSqueezeParam failed. errorcode: ";
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.h
index 9ff467ffc7..d4ae65561f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.h
@@ -29,8 +29,9 @@ namespace mindspore::kernel {
 class SqueezeInt8CPUKernel : public SqueezeBaseCPUKernel {
  public:
   SqueezeInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : SqueezeBaseCPUKernel(parameter, inputs, outputs, ctx) {}
+                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                       const lite::Primitive *primitive)
+      : SqueezeBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~SqueezeInt8CPUKernel() override { delete quant_Squeeze_parm_; }
 
   int Init() override;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
index 2280d9e078..0e96c8c9be 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
@@ -25,6 +25,10 @@ using mindspore::schema::PrimitiveType_TopK;
 
 namespace mindspore::kernel {
 int TopKInt8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   TopkParameter *parameter = reinterpret_cast<TopkParameter *>(opParameter);
   lite::tensor::Tensor *input = inputs_.at(0);
   parameter->last_dim_size_ = input->shape()[input->shape().size() - 1];
@@ -44,6 +48,11 @@ int TopKInt8CPUKernel::Init() {
 int TopKInt8CPUKernel::ReSize() { return RET_OK; }
 
 int TopKInt8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
   int8_t *input_data = reinterpret_cast<int8_t *>(inputs_.at(0)->Data());
   int8_t *output_data = reinterpret_cast<int8_t *>(outputs_.at(0)->Data());
   int32_t *output_index = reinterpret_cast<int32_t *>(outputs_.at(1)->Data());
@@ -54,9 +63,10 @@ int TopKInt8CPUKernel::Run() {
 
 kernel::LiteKernel *CpuTopKInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *parameter,
-                                             const lite::Context *ctx, const KernelKey &desc) {
+                                             const lite::Context *ctx, const KernelKey &desc,
+                                             const lite::Primitive *primitive) {
   MS_ASSERT(parameter != nullptr);
-  auto *kernel = new (std::nothrow) TopKInt8CPUKernel(parameter, inputs, outputs);
+  TopKInt8CPUKernel *kernel = new (std::nothrow) TopKInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new TopKInt8CPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
index 1216033e9c..513455c1e7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
@@ -24,8 +24,9 @@ namespace mindspore::kernel {
 class TopKInt8CPUKernel : public LiteKernel {
  public:
   explicit TopKInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                             const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+                             const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                             const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~TopKInt8CPUKernel() override {
     TopkParameter *parameter = reinterpret_cast<TopkParameter *>(opParameter);
     free(parameter->topk_node_list_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.cc
index 42936af713..450034e93f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.cc
@@ -29,6 +29,10 @@ using mindspore::schema::PrimitiveType_Unsqueeze;
 
 namespace mindspore::kernel {
 int Unsqueezeint8CPUKernel::Init() {
+  if (context_->infer_shape_interrupt_ && !context_->running_) {
+    SetNeedReInit();
+    return RET_OK;
+  }
   auto *input_tensor = inputs_.at(0);
   auto quant_args = input_tensor->GetQuantParams();
   MS_ASSERT(quant_args.size() == 1);
@@ -80,9 +84,14 @@ int UnsqueezeIn8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 
 int Unsqueezeint8CPUKernel::Run() {
+  auto ret = Prepare();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare failed.";
+    return ret;
+  }
   in_ptr_ = reinterpret_cast(inputs_.at(0)->Data());
   out_ptr_ = reinterpret_cast(outputs_.at(0)->Data());
-  int ret = LiteBackendParallelLaunch(UnsqueezeIn8Run, this, thread_sz_count_);
+  ret = LiteBackendParallelLaunch(UnsqueezeIn8Run, this, thread_sz_count_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "UnsqueezeRun error error_code[" << ret << "]";
     return ret;
@@ -93,10 +102,10 @@ int Unsqueezeint8CPUKernel::Run() {
 
 kernel::LiteKernel *CpuUnsqueezeInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *opParameter, const lite::Context *ctx,
-                                                  const kernel::KernelKey &desc) {
+                                                  const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Unsqueeze);
-  auto *kernel = new (std::nothrow) Unsqueezeint8CPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) Unsqueezeint8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new UnsqueezeCPUKernel fail!";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.h
index 070cf1fdd1..4c32cd05ef 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.h
@@ -28,8 +28,9 @@ namespace mindspore::kernel {
 class Unsqueezeint8CPUKernel : public LiteKernel {
  public:
   Unsqueezeint8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                         const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
-      : LiteKernel(parameter, inputs, outputs), ctx_(ctx), thread_count_(ctx->thread_num_) {
+                         const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
+                         const lite::Primitive *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx->thread_num_) {
     Unsq_para_ = reinterpret_cast<UnSqueezeParameter *>(opParameter);
     Unsq_para_->thread_count_ = opParameter->thread_num_;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.cc
index 7e83fded24..5fa758f610 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.cc
@@ -151,7 +151,7 @@ int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }
 
 int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); }
 
-void ReluFp32(float *data, int ele_num) {
+void ReluFp32(float *data, float *dst, int ele_num) {
   int four_block = UP_DIV(ele_num, C4NUM);
   for (int i = 0; i < four_block - 1; i++) {
     int index = i * C4NUM;
@@ -159,7 +159,7 @@
     float32x4_t relu_data = vld1q_f32(data + index);
     float32x4_t zero_data = vdupq_n_f32(0);
     relu_data = vmaxq_f32(relu_data, zero_data);
-    vst1q_f32(data + index, relu_data);
+    vst1q_f32(dst + index, relu_data);
 #else
     data[index] = data[index] < 0 ? 0 : data[index];
     data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
@@ -172,7 +172,7 @@
   }
 }
 
-void Relu6Fp32(float *data, int ele_num) {
+void Relu6Fp32(float *data, float *dst, int ele_num) {
   int four_block = UP_DIV(ele_num, C4NUM);
   for (int i = 0; i < four_block - 1; i++) {
     int index = i * C4NUM;
@@ -182,7 +182,7 @@
     float32x4_t six_data = vdupq_n_f32(6);
     relu6_data = vmaxq_f32(relu6_data, zero_data);
     relu6_data = vminq_f32(relu6_data, six_data);
-    vst1q_f32(data + index, relu6_data);
+    vst1q_f32(dst + index, relu6_data);
 #else
     data[index] = data[index] < 0 ? 0 : data[index];
     data[index] = data[index] > 6 ? 6 : data[index];
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h
index 82daa49f25..1d59aaf902 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h
@@ -29,8 +29,8 @@ extern "C" {
 int8_t MinInt8(int8_t a, int8_t b);
 int8_t MaxInt8(int8_t a, int8_t b);
-void ReluFp32(float *data, int ele_num);
-void Relu6Fp32(float *data, int ele_num);
+void ReluFp32(float *data, float *dst, int ele_num);
+void Relu6Fp32(float *data, float *dst, int ele_num);
 void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
                   int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi);
 void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.cc
index 624ffd7565..ab97e7ee55 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/conv.cc
@@ -130,9 +130,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
                       out_unit);
   int output_num = out_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
   if (is_relu) {
-    ReluFp32(output_data, output_num);
+    ReluFp32(output_data, output_data, output_num);
   } else if (is_relu6) {
-    Relu6Fp32(output_data, output_num);
+    Relu6Fp32(output_data, output_data, output_num);
   } else {
     // do nothing
   }
@@ -219,9 +219,9 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
   }
   int output_num = output_channel * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_batch_;
   if (is_relu) {
-    ReluFp32(output_data, output_num);
+    ReluFp32(output_data, output_data, output_num);
   } else if (is_relu6) {
-    Relu6Fp32(output_data, output_num);
+    Relu6Fp32(output_data, output_data, output_num);
   } else {
     // do nothing
   }
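ReluFp32 and Relu6Fp32 grow an explicit destination pointer so a fused activation can write somewhere other than its input; the existing callers in conv.cc keep the old in-place behavior by passing the same buffer twice. A portable scalar sketch of the new signatures (the patched functions do the same work four lanes at a time through vld1q_f32/vst1q_f32 on NEON):

#include <algorithm>

// Scalar sketch of the out-of-place ReLU/ReLU6 signatures introduced above.
void ReluFp32Sketch(float *data, float *dst, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    dst[i] = std::max(data[i], 0.0f);  // clamp below at 0
  }
}

void Relu6Fp32Sketch(float *data, float *dst, int ele_num) {
  for (int i = 0; i < ele_num; ++i) {
    dst[i] = std::min(std::max(data[i], 0.0f), 6.0f);  // clamp to [0, 6]
  }
}

// In-place use, exactly like the Conv3x3Fp32 call sites:
//   ReluFp32Sketch(output_data, output_data, output_num);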
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
index 1deeaf51b4..c43c410eb0 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
@@ -163,7 +163,7 @@ int ArithmeticOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                   const std::vector<lite::tensor::Tensor *> &outputs,
                                                   OpParameter *opParameter, const lite::Context *ctx,
-                                                  const kernel::KernelKey &desc) {
+                                                  const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   auto *kernel = new ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Create OpenCL Arithmetic kernel failed!";
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
index 82540b49f6..3226d1bc44 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
@@ -196,7 +196,7 @@ int ConcatOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const lite::Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs);
   auto ret = kernel->Init();
   if (0 != ret) {
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
index 6aefa36a69..b4ea3fb933 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
@@ -171,7 +171,8 @@ int Conv2dTransposeOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                        const std::vector<lite::tensor::Tensor *> &outputs,
                                                        OpParameter *opParameter, const lite::Context *ctx,
-                                                       const kernel::KernelKey &desc) {
+                                                       const kernel::KernelKey &desc,
+                                                       const lite::Primitive *primitive) {
   auto *kernel = new Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
   auto ret = kernel->Init();
   if (0 != ret) {
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
index b3299d2a53..cf97920102 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h
@@ -35,8 +35,8 @@ class Conv2dTransposeOpenCLKernel : public LiteKernel {
  public:
   explicit Conv2dTransposeOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                                        const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
-  ~Conv2dTransposeOpenCLKernel() override {};
+      : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {}
+  ~Conv2dTransposeOpenCLKernel() override{};
 
   int Init() override;
   int ReSize() override;
@@ -52,4 +52,3 @@ class Conv2dTransposeOpenCLKernel : public LiteKernel {
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONV2D_TRANSPOSE_H_
-
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc
index c7e6cf8efa..ec9715f1ff 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc
@@ -306,7 +306,7 @@ int ConvolutionOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                    const std::vector<lite::tensor::Tensor *> &outputs,
                                                    OpParameter *opParameter, const lite::Context *ctx,
-                                                   const kernel::KernelKey &desc) {
+                                                   const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   auto *kernel = new ConvolutionOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Create OpenCL Convolution kernel failed!";
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
index e7a89821c0..58c343ac90 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
@@ -191,7 +191,8 @@ int DepthwiseConv2dOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                        const std::vector<lite::tensor::Tensor *> &outputs,
                                                        OpParameter *opParameter, const lite::Context *ctx,
-                                                       const kernel::KernelKey &desc) {
+                                                       const kernel::KernelKey &desc,
+                                                       const lite::Primitive *primitive) {
   auto *kernel = new DepthwiseConv2dOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
   auto ret = kernel->Init();
   if (0 != ret) {
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
index 9a73ef48d7..cef0926aed 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
@@ -154,7 +154,7 @@ int MatMulOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const lite::Context *ctx,
-                                              const kernel::KernelKey &desc) {
+                                              const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   bool hasBias = false;
   if (opParameter->type_ == PrimitiveType_FullConnection) {
     hasBias = (reinterpret_cast<MatMulParameter *>(opParameter))->has_bias_;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h
index beb874cc67..a1e3e6cbfa 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h
@@ -36,7 +36,7 @@ class MatMulOpenCLKernel : public LiteKernel {
  public:
   explicit MatMulOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                               const std::vector<lite::tensor::Tensor *> &outputs, bool hasBias)
-      : LiteKernel(parameter, inputs, outputs) {
+      : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {
     hasBias_ = hasBias;
   }
   ~MatMulOpenCLKernel() override{};
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc
index 476ff23dbc..8a5b05e0a2 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc
@@ -149,7 +149,7 @@ int PoolingOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLPooling2dKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                  const std::vector<lite::tensor::Tensor *> &outputs,
                                                  OpParameter *opParameter, const lite::Context *ctx,
-                                                 const kernel::KernelKey &desc) {
+                                                 const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   auto *kernel = new PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "Create OpenCL Pooling kernel failed!";
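The OpenCL wrappers here and below do not participate in deferred shape inference yet: their creators accept the new primitive argument for signature compatibility but ignore it, and their constructors satisfy the widened LiteKernel base by passing nullptr for both the context and the primitive. A compile-checked sketch with stand-in types:

#include <vector>

struct Tensor {};
struct OpParameter {};
struct Context {};
struct Primitive {};

struct LiteKernel {
  LiteKernel(OpParameter *, const std::vector<Tensor *> &, const std::vector<Tensor *> &,
             const Context *, const Primitive *) {}
};

// Mirrors OpenCLKernel / MatMulOpenCLKernel: no ctx, no primitive, so the
// base class gets nullptr for both until the GPU path opts into re-inference.
struct OpenCLKernelSketch : LiteKernel {
  OpenCLKernelSketch(OpParameter *parameter, const std::vector<Tensor *> &inputs,
                     const std::vector<Tensor *> &outputs)
      : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {}
};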
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc
index 4bdf5db2c4..10e6176ff5 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc
@@ -84,7 +84,7 @@ int SoftmaxOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                const std::vector<lite::tensor::Tensor *> &outputs,
                                                OpParameter *opParameter, const lite::Context *ctx,
-                                               const kernel::KernelKey &desc) {
+                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   auto *kernel = new SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
   if (inputs[0]->shape()[0] > 1) {
     MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported multi-batch.";
@@ -101,4 +101,3 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
-
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h
   explicit SoftmaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                                const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {
+      : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {
     parameter_ = reinterpret_cast(parameter);
   }
   ~SoftmaxOpenCLKernel() override{};
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
index 428c071800..82bf48d58e 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
@@ -93,7 +93,7 @@ int TransposeOpenCLKernel::Run() {
 kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                  const std::vector<lite::tensor::Tensor *> &outputs,
                                                  OpParameter *opParameter, const lite::Context *ctx,
-                                                 const kernel::KernelKey &desc) {
+                                                 const kernel::KernelKey &desc, const lite::Primitive *primitive) {
   auto *kernel = new TransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
   auto ret = kernel->Init();
   if (0 != ret) {
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
index cf76286e35..10dabfea71 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
@@ -25,7 +25,7 @@ class OpenCLKernel : public LiteKernel {
  public:
   explicit OpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                         const std::vector<lite::tensor::Tensor *> &outputs)
-      : LiteKernel(parameter, inputs, outputs) {}
+      : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {}
 
   virtual int Init() { return -1; }
   virtual int Prepare() { return -1; }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h
index 7f7d5a343e..d965c2e7bd 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h
@@ -36,7 +36,7 @@ class SubGraphOpenCLKernel : public SubGraphKernel {
                        const std::vector<kernel::LiteKernel *> inKernels,
                        const std::vector<kernel::LiteKernel *> outKernels,
                        const std::vector<kernel::LiteKernel *> nodes)
-      : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes) {}
+      : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, nullptr, nullptr) {}
   ~SubGraphOpenCLKernel() override;
 
   int Init() override;
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index d3a7ca7d69..80c05c2720 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -15,8 +15,8 @@
  */
 
 #include "src/scheduler.h"
-#include
 #include
+#include "include/errorcode.h"
 #include "src/kernel_factory.h"
 #if SUPPORT_GPU
@@ -69,11 +69,21 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector
                   << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
     return RET_ERROR;
   }
-  auto ret = primitive->InferShape(inputs, outputs);
-  if (0 != ret) {
-    MS_LOG(ERROR) << "InferShape failed, name: " << cNode->name()->str()
-                  << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
-    return ret;
+  if (!context_->infer_shape_interrupt_) {
+    auto ret = primitive->InferShape(inputs, outputs);
+    if (ret == RET_INFER_INVALID) {
+      MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << cNode->name()->str()
+                   << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type())
+                   << "flag set to false.";
+      primitive->SetInferFlag(false);
+      context_->InferShapeInterrupt();
+    } else if (ret != RET_OK) {
+      MS_LOG(ERROR) << "InferShape failed, name: " << cNode->name()->str()
+                    << ", type: " << schema::EnumNamePrimitiveType(cNode->primitive()->value_type());
+      return RET_INFER_ERR;
+    }
+  } else {
+    primitive->SetInferFlag(false);
   }
   auto *kernel = this->ScheduleNode(inputs, outputs, primitive);
diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h
index f86ca2b035..38a8cca466 100644
--- a/mindspore/lite/src/scheduler.h
+++ b/mindspore/lite/src/scheduler.h
@@ -25,7 +25,9 @@ namespace mindspore::lite {
 class Scheduler {
  public:
-  explicit Scheduler(const Context *ctx) : context_(ctx) {}
+  explicit Scheduler(const Context *ctx) {
+    context_ = const_cast<Context *>(ctx);
+  }
 
   int Schedule(const lite::Model *model, std::vector<tensor::Tensor *> *tensors,
                std::vector<kernel::LiteKernel *> *kernels);
@@ -48,7 +50,7 @@ class Scheduler {
  protected:
   std::vector<std::vector<kernel::LiteKernel *>> markedKernelGroup;
-  const Context *context_ = nullptr;
+  Context *context_ = nullptr;
 };
 }  // namespace mindspore::lite
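The scheduler is where the feature actually pivots: a RET_INFER_INVALID from InferShape is no longer fatal at build time, it just marks the primitive and flips the context into interrupted mode so every downstream kernel defers its Init() until Run(). A condensed, self-contained sketch of that control flow (RET_* values are stand-ins; the real codes live in include/errorcode.h):

#include <iostream>

constexpr int RET_OK = 0;
constexpr int RET_INFER_ERR = -500;      // assumed placeholder value
constexpr int RET_INFER_INVALID = -501;  // assumed placeholder value

struct Context {
  bool infer_shape_interrupt_ = false;
  void InferShapeInterrupt() { infer_shape_interrupt_ = true; }
};

struct Primitive {
  bool infer_flag_ = true;
  void SetInferFlag(bool flag) { infer_flag_ = flag; }
  // e.g. an op whose output shape depends on runtime tensor contents
  int InferShape() { return RET_INFER_INVALID; }
};

int ScheduleNodeShapes(Context *context, Primitive *primitive) {
  if (!context->infer_shape_interrupt_) {
    int ret = primitive->InferShape();
    if (ret == RET_INFER_INVALID) {
      // Shape is only knowable at run time: remember that and stop inferring.
      primitive->SetInferFlag(false);
      context->InferShapeInterrupt();
    } else if (ret != RET_OK) {
      std::cerr << "InferShape failed\n";
      return RET_INFER_ERR;
    }
  } else {
    // Everything downstream of the first interrupt skips build-time inference.
    primitive->SetInferFlag(false);
  }
  return RET_OK;
}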
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc
index 16ea9e1f75..60e2acff00 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/common/strided_slice_tests.cc
@@ -63,7 +63,7 @@ TEST_F(TestStridedSlice, StridedSlice) {
   ASSERT_NE(creator, nullptr);
 
   auto ctx = std::make_shared<lite::Context>();
-  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc);
+  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
   ASSERT_NE(kernel, nullptr);
 
   auto ret = kernel->Run();
@@ -109,7 +109,7 @@ TEST_F(TestStridedSlice, StridedSliceInt8) {
   ASSERT_NE(creator, nullptr);
 
   auto ctx = std::make_shared<lite::Context>();
-  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc);
+  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
   ASSERT_NE(kernel, nullptr);
 
   auto ret = kernel->Run();
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc
index 0fdfe5f29e..f09caab0b5 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc
@@ -114,7 +114,7 @@ TEST_F(TestActivationFp32, HSwishFp32) {
   lite::Context ctx;
   ctx.thread_num_ = 7;
   kernel::LiteKernel *kernel =
-    creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), &ctx, desc);
+    creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&op_param), &ctx, desc, nullptr);
   ASSERT_NE(kernel, nullptr);
   auto output_tensor_shape = output0_tensor.shape();
   kernel->Run();
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc
index fe4a3b7cf2..5003673285 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/arithmetic_grad_fp32_tests.cc
@@ -91,7 +91,7 @@ TEST_F(TestArithmeticGradFp32, TestAddGradFp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data());
@@ -122,7 +122,7 @@ TEST_F(TestArithmeticGradFp32, TestAddGrad2Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data());
@@ -153,7 +153,7 @@ TEST_F(TestArithmeticGradFp32, TestAddGrad3Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_AddGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data());
@@ -184,7 +184,7 @@ TEST_F(TestArithmeticGradFp32, TestSubGradFp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SubGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data());
@@ -215,7 +215,7 @@ TEST_F(TestArithmeticGradFp32, TestSubGrad2Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_SubGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data());
@@ -246,7 +246,7 @@ TEST_F(TestArithmeticGradFp32, TestMulGradFp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
 
   int loop_count = 1000;
   auto time_start = mindspore::lite::GetTimeUs();
@@ -287,7 +287,7 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad2Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data());
@@ -318,7 +318,7 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad3Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data());
@@ -349,7 +349,7 @@ TEST_F(TestArithmeticGradFp32, TestMulGrad4Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_MulGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data());
@@ -380,7 +380,7 @@ TEST_F(TestArithmeticGradFp32, TestDivGradFp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data());
@@ -411,7 +411,7 @@ TEST_F(TestArithmeticGradFp32, TestDivGrad2Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[4], all_tensors[3]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[0]->Data());
@@ -442,7 +442,7 @@ TEST_F(TestArithmeticGradFp32, TestDivGrad3Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data());
@@ -473,7 +473,7 @@ TEST_F(TestArithmeticGradFp32, Test3DDivGrad2Fp32) {
   std::vector<lite::tensor::Tensor *> outputs = {all_tensors[3], all_tensors[4]};
   kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_DivGrad};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
-  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc);
+  auto kernel_obj = creator(inputs, outputs, reinterpret_cast<OpParameter *>(param), NULL, desc, nullptr);
   kernel_obj->Run();
 
   float *output_ptr = reinterpret_cast<float *>(outputs[1]->Data());
a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc index 8e8c191866..b65e6ecaab 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/batchnorm_fp32_tests.cc @@ -77,7 +77,7 @@ TEST_F(TestBatchnormFp32, BNTest) { lite::Context ctx; ctx.thread_num_ = 7; kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), &ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), &ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor.shape(); kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc index 7c26e95022..0a68d654de 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/bias_grad_fp32_tests.cc @@ -50,7 +50,7 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_BiasGrad}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - auto kernel_obj = creator(inputs, outputs, reinterpret_cast(bias_param), NULL, desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(bias_param), NULL, desc, nullptr); kernel_obj->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc index da3c661697..9c8901b845 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/conv1x1_fp32_tests.cc @@ -296,7 +296,7 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test1) { float *correct; int total_size = Conv1x1TestInit1(&inputs_, &outputs_, conv_param, &correct); kernel::Convolution1x1CPUKernel *conv1x1 = - new kernel::Convolution1x1CPUKernel(reinterpret_cast(conv_param), inputs_, outputs_, ctx); + new kernel::Convolution1x1CPUKernel(reinterpret_cast(conv_param), inputs_, outputs_, ctx, nullptr); conv1x1->Init(); conv1x1->Run(); @@ -364,7 +364,7 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test2) { float *correct; int total_size = Conv1x1TestInit2(&inputs_, &outputs_, conv_param, &correct); kernel::Convolution1x1CPUKernel *conv1x1 = - new kernel::Convolution1x1CPUKernel(reinterpret_cast(conv_param), inputs_, outputs_, ctx); + new kernel::Convolution1x1CPUKernel(reinterpret_cast(conv_param), inputs_, outputs_, ctx, nullptr); conv1x1->Init(); conv1x1->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc index 3394ecb5af..86acfe7b57 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32_tests.cc @@ -116,7 +116,8 @@ TEST_F(TestConvolutionDwFp32, ConvDwFp32Accuracy) { kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_DepthwiseConv2D}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); - kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(conv_param), ctx, desc); + 
kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(conv_param), ctx, desc, + nullptr); ASSERT_NE(kernel, nullptr); // op run kernel->Run(); @@ -166,7 +167,8 @@ TEST_F(TestConvolutionDwFp32, ConvDwFp32Performance) { kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_DepthwiseConv2D}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); - kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(conv_param), ctx, desc); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(conv_param), ctx, desc, + nullptr); ASSERT_NE(kernel, nullptr); /* running warm up */ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc index 7847e013a3..721486e6a6 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/convolution_grad_fp32_tests.cc @@ -109,7 +109,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32FilterGrad) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc, nullptr); // warm up loop for (int i = 0; i < 3; i++) { @@ -174,7 +174,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32InputGrad) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc, nullptr); // warm up loop for (int i = 0; i < 3; i++) { @@ -234,7 +234,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupFilterGrad) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc, nullptr); // warm up loop for (int i = 0; i < 3; i++) { @@ -298,7 +298,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupInputGrad) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc, nullptr); // warm up loop for (int i = 0; i < 3; i++) { @@ -359,7 +359,7 @@ TEST_F(TestConvolutionGradFp32, ConvFp32GroupDilationFilterGrad) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradFilter}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc, nullptr); // warm up loop for (int i = 0; i < 3; i++) { @@ -422,7 +422,7 @@ TEST_F(TestConvolutionGradFp32, 
ConvFp32GroupDilationInputGrad) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_Conv2DGradInput}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(conv_param), NULL, desc, nullptr); // warm up loop for (int i = 0; i < 3; i++) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc index 727f86f3b9..df598547e2 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/deconvolution_fp32_tests.cc @@ -297,7 +297,7 @@ TEST_F(TestDeConvolutionFp32, DeConvTest1) { float *correct; int total_size = DeConvTestInit1(&inputs_, &outputs_, deconv_param, &correct); kernel::DeConvolutionCPUKernel *deconv = - new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx, nullptr); deconv->Init(); deconv->Run(); @@ -366,7 +366,7 @@ TEST_F(TestDeConvolutionFp32, DeConvTest2) { lite::Context *ctx = new lite::Context; ctx->thread_num_ = 4; kernel::DeConvolutionCPUKernel *deconv = - new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx, nullptr); deconv->Init(); deconv->Run(); @@ -445,7 +445,7 @@ TEST_F(TestDeConvolutionFp32, DeConvTest3) { lite::Context *ctx = new lite::Context; ctx->thread_num_ = 2; kernel::DeConvolutionCPUKernel *deconv = - new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx, nullptr); deconv->Init(); deconv->Run(); @@ -517,7 +517,7 @@ TEST_F(TestDeConvolutionFp32, DeConvTest4) { lite::Context *ctx = new lite::Context; ctx->thread_num_ = 2; kernel::DeConvolutionCPUKernel *deconv = - new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + new kernel::DeConvolutionCPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx, nullptr); deconv->Init(); deconv->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc index 1f6d2997c4..02aff1d460 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/embedding_lookup_fp32_test.cc @@ -65,7 +65,7 @@ TEST_F(TestEmbeddingLookupFp32, ElTest) { lite::Context *ctx = new lite::Context; ctx->thread_num_ = 2; kernel::EmbeddingLookupCPUKernel *el = new kernel::EmbeddingLookupCPUKernel( - reinterpret_cast(embedding_lookup_param_), inputs_, outputs_, ctx); + reinterpret_cast(embedding_lookup_param_), inputs_, outputs_, ctx, nullptr); el->Init(); el->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc index 9a226fef65..fd4ed81a14 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc +++ 
b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/fullconnection_fp32_tests.cc @@ -79,7 +79,7 @@ TEST_F(TestFcFp32, FcTest1) { lite::Context *ctx = new lite::Context; ctx->thread_num_ = 2; kernel::FullconnectionCPUKernel *fc = - new kernel::FullconnectionCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + new kernel::FullconnectionCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); fc->Init(); fc->Run(); @@ -136,7 +136,7 @@ TEST_F(TestFcFp32, FcTest2) { lite::Context *ctx = new lite::Context; ctx->thread_num_ = 1; kernel::FullconnectionCPUKernel *fc = - new kernel::FullconnectionCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + new kernel::FullconnectionCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); fc->Init(); fc->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc index f591548c00..41312db0b1 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/lstm_fp32_tests.cc @@ -152,7 +152,8 @@ TEST_F(LstmFp32, LstmForwardFp32Accuracy) { kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Lstm}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); - kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(lstm_param), ctx, desc); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(lstm_param), ctx, desc, + nullptr); ASSERT_NE(kernel, nullptr); // op run kernel->Run(); @@ -299,7 +300,8 @@ TEST_F(LstmFp32, LstmBackwardFp32Accuracy) { kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Lstm}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); - kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(lstm_param), ctx, desc); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(lstm_param), ctx, desc, + nullptr); ASSERT_NE(kernel, nullptr); // op run kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc index 2d0fe33be1..98e28c1c80 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc @@ -207,7 +207,7 @@ TEST_F(TestMatMulFp32, simple) { int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); auto ctx = new lite::Context; ctx->thread_num_ = 1; - auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); mm->Init(); mm->Run(); float correct[] = {-0.1256939023733139, -0.07744802534580231, 0.07410638779401779, @@ -258,7 +258,7 @@ TEST_F(TestMatMulFp32, simple2) { int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); auto ctx = new lite::Context; ctx->thread_num_ = 2; - auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); mm->Init(); mm->Run(); float correct[] = { @@ -327,7 +327,7 @@ 
TEST_F(TestMatMulFp32, simple_transb) { int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); auto ctx = new lite::Context; ctx->thread_num_ = 2; - auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); mm->Init(); mm->Run(); float correct[] = {0.00533547, 0.002545945, 0.062974121, -0.445441471, -0.246223617, -0.142070031}; @@ -376,7 +376,7 @@ TEST_F(TestMatMulFp32, batch) { int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); auto ctx = new lite::Context; ctx->thread_num_ = 1; - auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + auto mm = new kernel::MatmulCPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); mm->Init(); mm->Run(); float correct[] = {21.38518524169922, -14.514888763427734, -11.040614128112793, 16.91403579711914, diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc index fb9e262e28..682f34c6a4 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/pooling_grad_fp32_tests.cc @@ -138,7 +138,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { kernel::KernelKey desc = {kernel::kCPU, TypeId::kNumberTypeFloat32, schema::PrimitiveType_PoolingGrad}; auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); - auto kernel_obj = creator(inputs, outputs, reinterpret_cast(pooling_param), NULL, desc); + auto kernel_obj = creator(inputs, outputs, reinterpret_cast(pooling_param), NULL, desc, nullptr); kernel_obj->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc index 46423d89a4..4b3538316f 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc @@ -62,7 +62,8 @@ TEST_F(TestPowerFp32, Simple) { int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); auto ctx = new lite::Context; ctx->thread_num_ = 1; - auto op = new kernel::PowerCPUKernel(reinterpret_cast(param), inputs_, outputs_, ctx); + kernel::PowerCPUKernel *op = new kernel::PowerCPUKernel(reinterpret_cast(param), inputs_, outputs_, + ctx, nullptr); op->Init(); op->Run(); float correct[] = {1, 64, 2187, 65536}; @@ -88,7 +89,8 @@ TEST_F(TestPowerFp32, Broadcast) { int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape); auto ctx = new lite::Context; ctx->thread_num_ = 2; - auto op = new kernel::PowerCPUKernel(reinterpret_cast(param), inputs_, outputs_, ctx); + kernel::PowerCPUKernel *op = new kernel::PowerCPUKernel(reinterpret_cast(param), inputs_, outputs_, + ctx, nullptr); op->Init(); op->Run(); float correct[] = {1, 4, 9, 16}; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc index c6b1e5c7e3..696647687d 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc @@ -149,7 +149,7 @@ TEST_F(SpaceToBatchTestFp32, 
SpaceToBatchTest3) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc index a758336be2..ac0258b1fa 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_depth_fp32_tests.cc @@ -80,7 +80,7 @@ TEST_F(SpaceToDepthTestFp32, SpaceToDepthTest2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), &ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), &ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc index b45037d36c..fd005cff64 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc @@ -45,7 +45,7 @@ TEST_F(TestTopKFp32, TopK) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); - auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), nullptr, desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), nullptr, desc, nullptr); ASSERT_NE(kernel, nullptr); auto ret = kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc index a8d738ddaa..40f108d4a3 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/add_int8_tests.cc @@ -56,7 +56,7 @@ TEST_F(TestQuantizedAdd, Add) { ASSERT_NE(creator, nullptr); auto ctx = std::make_shared(); - auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); ASSERT_NE(kernel, nullptr); auto ret = kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc index 3ab4e5f26b..ecc5bcb3f2 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/arithmetic_self_int8_tests.cc @@ -70,7 +70,7 @@ TEST_F(TestArithmeticSelfInt8, floor_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -129,7 +129,7 @@ TEST_F(TestArithmeticSelfInt8, floor_quant1_thread2) { auto creator = 
lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -188,7 +188,7 @@ TEST_F(TestArithmeticSelfInt8, round_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -247,7 +247,7 @@ TEST_F(TestArithmeticSelfInt8, round_quant1_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -306,7 +306,7 @@ TEST_F(TestArithmeticSelfInt8, ceil_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -365,7 +365,7 @@ TEST_F(TestArithmeticSelfInt8, ceil_quant1_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -424,7 +424,7 @@ TEST_F(TestArithmeticSelfInt8, abs_quant0_thread0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -483,7 +483,7 @@ TEST_F(TestArithmeticSelfInt8, abs_quant1_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -542,7 +542,7 @@ TEST_F(TestArithmeticSelfInt8, sin_quant0_thread2) { auto creator = 
lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -601,7 +601,7 @@ TEST_F(TestArithmeticSelfInt8, cos_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -660,7 +660,7 @@ TEST_F(TestArithmeticSelfInt8, log_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -719,7 +719,7 @@ TEST_F(TestArithmeticSelfInt8, sqrt_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -778,7 +778,7 @@ TEST_F(TestArithmeticSelfInt8, rsqrt_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -837,7 +837,7 @@ TEST_F(TestArithmeticSelfInt8, square_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -896,7 +896,7 @@ TEST_F(TestArithmeticSelfInt8, square_quant1_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -955,7 +955,7 @@ TEST_F(TestArithmeticSelfInt8, logical_not_quant0_thread2) { auto creator = 
lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc index 5d9814e268..cc7084b25b 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/concat_int8_tests.cc @@ -82,7 +82,7 @@ TEST_F(TestConcatInt8, Concat1_axis0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -153,7 +153,7 @@ TEST_F(TestConcatInt8, Concat1_axis1_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -225,7 +225,7 @@ TEST_F(TestConcatInt8, Concat1_axis1_thread2_quant1) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc index d94ae6eefe..6ee9e96976 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/crop_int8_tests.cc @@ -74,7 +74,7 @@ TEST_F(TestCropInt8, crop_1d_axis0_offset0_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -137,7 +137,7 @@ TEST_F(TestCropInt8, crop_2d_axis1_offset0_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = 
output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -200,7 +200,7 @@ TEST_F(TestCropInt8, crop_3d_axis1_offset0_quant0_thread0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -264,7 +264,7 @@ TEST_F(TestCropInt8, crop_3d_axis1_offset0_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -327,7 +327,7 @@ TEST_F(TestCropInt8, crop_4d_axis0_offset0_quant0_thread0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -390,7 +390,7 @@ TEST_F(TestCropInt8, crop_4d_axis1_offset0_quant0_thread0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -456,7 +456,7 @@ TEST_F(TestCropInt8, crop_4d_axis1_offset1_quant0_thread0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -522,7 +522,7 @@ TEST_F(TestCropInt8, crop_4d_axis1_offset1_quant1_thread0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -587,7 +587,7 @@ TEST_F(TestCropInt8, crop_4d_axis0_offset0_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape 
= output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -652,7 +652,7 @@ TEST_F(TestCropInt8, crop_4d_axis0_offset0_quant0_thread3) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc index 632abad686..f6724eb579 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc @@ -250,8 +250,8 @@ TEST_F(TestDeconvInt8, DeConvInt8Test1) { ctx->thread_num_ = 2; int8_t *correct; int total_size = DeConvInt8TestInit1(&inputs_, &outputs_, deconv_param, &correct); - mindspore::kernel::DeConvInt8CPUKernel *deconv = - new mindspore::kernel::DeConvInt8CPUKernel(reinterpret_cast(deconv_param), inputs_, outputs_, ctx); + mindspore::kernel::DeConvInt8CPUKernel *deconv = new mindspore::kernel::DeConvInt8CPUKernel( + reinterpret_cast(deconv_param), inputs_, outputs_, ctx, nullptr); deconv->Init(); deconv->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc index d5ca51733f..51a3796271 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc @@ -107,8 +107,8 @@ TEST_F(TestFcInt8, fcint8) { int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp); lite::Context *ctx = new lite::Context; ctx->thread_num_ = 2; - kernel::FullconnectionInt8CPUKernel *fc = - new kernel::FullconnectionInt8CPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + kernel::FullconnectionInt8CPUKernel *fc = new kernel::FullconnectionInt8CPUKernel( + reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); fc->Init(); fc->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc index ef2f1b0e77..e9caaec997 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/hswish_int8_tests.cc @@ -56,7 +56,7 @@ TEST_F(TestHSwishInt8, HSwish) { ASSERT_NE(creator, nullptr); auto ctx = std::make_shared(); - auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); ASSERT_NE(kernel, nullptr); auto ret = kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc index 36d7ecd23e..fdce152ef2 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc @@ -108,7 +108,7 @@ TEST_F(TestMatmulInt8, mmint8) { auto ctx = new lite::Context; ctx->thread_num_ = 2; 
kernel::MatmulInt8CPUKernel *mm = - new kernel::MatmulInt8CPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx); + new kernel::MatmulInt8CPUKernel(reinterpret_cast(matmul_param), inputs_, outputs_, ctx, nullptr); mm->Init(); mm->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc index 3aed147816..7185c13bfa 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/mul_int8_tests.cc @@ -81,7 +81,7 @@ TEST_F(TestMulInt8, Mul_quant0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -151,7 +151,7 @@ TEST_F(TestMulInt8, Mul_quant0_thread0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -221,7 +221,7 @@ TEST_F(TestMulInt8, Mul_quant1) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -291,7 +291,7 @@ TEST_F(TestMulInt8, Mul_quant1_thread1) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc index 4230ec4ae4..3bd462cc8a 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc @@ -68,7 +68,7 @@ TEST_F(TestPadInt8, PadInt8Test1) { int8_t *correct; int total_size = PadInt8TestInit1(&inputs_, &outputs_, pad_param, &correct); kernel::PadInt8CPUKernel *pad = - new kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx); + new kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx, nullptr); pad->Init(); pad->Run(); @@ -120,7 +120,7 @@ TEST_F(TestPadInt8, PadInt8Test2) { int8_t *correct; int total_size = PadInt8TestInit2(&inputs_, &outputs_, pad_param, &correct); kernel::PadInt8CPUKernel *pad = - new kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx); + new 
kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx, nullptr); pad->Init(); pad->Run(); @@ -186,7 +186,7 @@ TEST_F(TestPadInt8, PadInt8TestInit4) { int8_t *correct; int total_size = PadInt8TestInit2(&inputs_, &outputs_, pad_param, &correct); kernel::PadInt8CPUKernel *pad = - new kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx); + new kernel::PadInt8CPUKernel(reinterpret_cast(pad_param), inputs_, outputs_, ctx, nullptr); pad->Init(); pad->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc index d326022761..297316a7b5 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/prelu_int8_tests.cc @@ -75,7 +75,7 @@ TEST_F(TestPreluInt8, prelu_1) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc index 89b2385b3f..665d18a217 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/quant_dtype_cast_tests.cc @@ -67,7 +67,7 @@ TEST_F(QuantDTypeCastTestFp32, QuantDTypeCastTest1) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); kernel->Run(); @@ -113,7 +113,7 @@ TEST_F(QuantDTypeCastTestFp32, QuantDTypeCastTest2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(¶m), &ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc index 910e746cf4..70de924b24 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/relux_int8_tests.cc @@ -54,7 +54,7 @@ TEST_F(TestReluXInt8, Relu) { ASSERT_NE(creator, nullptr); auto ctx = std::make_shared(); - auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); ASSERT_NE(kernel, nullptr); auto ret = kernel->Run(); @@ -97,7 +97,7 @@ TEST_F(TestReluXInt8, Relu6) { ASSERT_NE(creator, nullptr); auto ctx = std::make_shared(); - auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); ASSERT_NE(kernel, nullptr); auto ret = kernel->Run(); diff --git 
a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc index e6921779b8..33dcc3050d 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/reshape_int8_tests.cc @@ -70,7 +70,7 @@ TEST_F(TestReshapeInt8, reshape_quant0) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); @@ -129,7 +129,7 @@ TEST_F(TestReshapeInt8, reshape_quant1_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc index 86ceb4c294..361c7be45f 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/sigmoid_int8_tests.cc @@ -53,7 +53,7 @@ TEST_F(TestSigmoidInt8, Sigmoid) { ASSERT_NE(creator, nullptr); auto ctx = std::make_shared(); - auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), ctx.get(), desc, nullptr); ASSERT_NE(kernel, nullptr); auto ret = kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc index 1c22a52dee..20a5ff88a4 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/softmax_int8_tests.cc @@ -75,7 +75,7 @@ TEST_F(TestSoftmaxInt8, SoftmaxInt8) { ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx.get(), desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx.get(), desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor.shape(); kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc index 8437edaa76..70d7e5c313 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/split_int8_tests.cc @@ -84,7 +84,7 @@ TEST_F(TestSplitInt8, Split_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output1_tensor_shape = 
output1_tensor->shape(); auto output2_tensor_shape = output2_tensor->shape(); @@ -172,7 +172,7 @@ TEST_F(TestSplitInt8, Split_quant0_thread2_num) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output1_tensor_shape = output1_tensor->shape(); auto output2_tensor_shape = output2_tensor->shape(); @@ -268,7 +268,7 @@ TEST_F(TestSplitInt8, Split_quant1_thread2_num) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output1_tensor_shape = output1_tensor->shape(); auto output2_tensor_shape = output2_tensor->shape(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc index d77bc2eb79..a2c2a76481 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/squeeze_int8_tests.cc @@ -74,7 +74,7 @@ TEST_F(TestSqueezeInt8, Squeeze_1d_axis0_offset0_quant0_thread2) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc index 1334f72b35..ea9bf54dbc 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc @@ -45,7 +45,7 @@ TEST_F(TestTopKInt8, TopK) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); - auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), nullptr, desc); + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), nullptr, desc, nullptr); ASSERT_NE(kernel, nullptr); auto ret = kernel->Run(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc index 59edaffa5b..b8fb3423c1 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/unsqueeze_int8_tests.cc @@ -74,7 +74,7 @@ TEST_F(TestUnsqueezeInt8, Unsqueeze_1) { auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); ASSERT_NE(creator, nullptr); kernel::LiteKernel *kernel = - creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc); + creator(inputs_tensor, outputs_tensor, reinterpret_cast(&op_param), ctx, desc, nullptr); ASSERT_NE(kernel, nullptr); auto output_tensor_shape = output0_tensor->shape(); ASSERT_EQ(output_tensor_shape, output_shape);