From ef330cdffe2d4a5c1bf7cff6c970c743e001145f Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Sun, 27 Sep 2020 16:12:04 +0800 Subject: [PATCH] Refactor slice and add fp16 kernel --- mindspore/lite/nnacl/fp16/slice_fp16.c | 70 +++++++++++ mindspore/lite/nnacl/fp16/slice_fp16.h | 34 ++++++ .../src/runtime/kernel/arm/base/slice_base.cc | 114 ------------------ .../src/runtime/kernel/arm/fp16/slice_fp16.cc | 91 ++++++++++++++ .../{base/slice_base.h => fp16/slice_fp16.h} | 29 ++--- .../lite/src/runtime/kernel/arm/fp32/slice.cc | 80 ++++++------ .../lite/src/runtime/kernel/arm/fp32/slice.h | 14 ++- .../src/runtime/kernel/arm/int8/slice_int8.cc | 33 +++-- .../src/runtime/kernel/arm/int8/slice_int8.h | 7 +- 9 files changed, 282 insertions(+), 190 deletions(-) create mode 100644 mindspore/lite/nnacl/fp16/slice_fp16.c create mode 100644 mindspore/lite/nnacl/fp16/slice_fp16.h delete mode 100644 mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.cc rename mindspore/lite/src/runtime/kernel/arm/{base/slice_base.h => fp16/slice_fp16.h} (58%) diff --git a/mindspore/lite/nnacl/fp16/slice_fp16.c b/mindspore/lite/nnacl/fp16/slice_fp16.c new file mode 100644 index 0000000000..89d3a5ab60 --- /dev/null +++ b/mindspore/lite/nnacl/fp16/slice_fp16.c @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/slice_fp16.h" +#include +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +void DoSliceFp16(const float16_t *input, float16_t *output, SliceParameter *param, int thread_id) { + int32_t out_dim1 = param->size_[1]; + int32_t out_dim2 = param->size_[2]; + int32_t out_dim3 = param->size_[3]; + size_t out_stride2 = out_dim3; + size_t out_stride1 = out_stride2 * out_dim2; + size_t out_stride0 = out_stride1 * out_dim1; + size_t count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_); + size_t thread_stride = thread_id * count_per_thread; + size_t copy_size = param->size_[3] * sizeof(float16_t); + size_t in_stride2 = param->shape_[3]; + size_t in_stride1 = param->shape_[2] * in_stride2; + size_t in_stride0 = param->shape_[1] * in_stride1; + for (int i = 0; i < param->size_[0]; ++i) { + size_t out_offset0 = i * out_stride0; + size_t in_offset0 = (i + param->begin_[0]) * in_stride0 + param->begin_[3]; + for (size_t j = 0; j < count_per_thread; ++j) { + size_t k = j + thread_stride; + if (k >= out_dim1) { + break; + } + size_t out_offset1 = k * out_stride1 + out_offset0; + size_t in_offset1 = (k + param->begin_[1]) * in_stride1 + in_offset0; + for (int l = 0; l < out_dim2; ++l) { + size_t out_offset = out_offset1 + l * out_stride2; + size_t in_offset = in_offset1 + (l + param->begin_[2]) * in_stride2; + memcpy(output + out_offset, input + in_offset, copy_size); + } + } + } +} + +void DoSliceFp16NoParallel(const float16_t *input, float16_t *output, SliceParameter *param) { + size_t copy_size = param->size_[3] * sizeof(float16_t); + size_t in_stride2 = param->shape_[3]; + size_t in_stride1 = param->shape_[2] * in_stride2; + size_t in_stride0 = param->shape_[1] * in_stride1; + size_t out_offset = 0; + for (int32_t dim0 = param->begin_[0]; dim0 < param->end_[0]; ++dim0) { + size_t in_offset0 = dim0 * in_stride0 + param->begin_[3]; + for (size_t dim1 = param->begin_[1]; dim1 < param->end_[1]; ++dim1) { + size_t in_offset1 = dim1 * in_stride1 + in_offset0; + for (int32_t dim2 = param->begin_[2]; dim2 < param->end_[2]; ++dim2) { + size_t in_offset = in_offset1 + dim2 * in_stride2; + memcpy(output + out_offset, input + in_offset, copy_size); + out_offset += param->size_[3]; + } + } + } +} diff --git a/mindspore/lite/nnacl/fp16/slice_fp16.h b/mindspore/lite/nnacl/fp16/slice_fp16.h new file mode 100644 index 0000000000..b3dc63d5a0 --- /dev/null +++ b/mindspore/lite/nnacl/fp16/slice_fp16.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_FP16_SLICE_FP16_H_ +#define MINDSPORE_LITE_NNACL_FP16_SLICE_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/slice_parameter.h" +#ifdef ENABLE_NEON +#include +#endif + +#ifdef __cplusplus +extern "C" { +#endif +void DoSliceFp16(const float16_t *input, float16_t *output, SliceParameter *param, int thread_id); +void DoSliceFp16NoParallel(const float16_t *input, float16_t *output, SliceParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP16_SLICE_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc deleted file mode 100644 index cbfc4bfa27..0000000000 --- a/mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/runtime/kernel/arm/base/slice_base.h" -#include -#include "src/runtime/kernel/arm/int8/slice_int8.h" -#include "src/runtime/kernel/arm/fp32/slice.h" -#include "schema/model_generated.h" -#include "src/kernel_registry.h" -#include "include/errorcode.h" - -using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_Slice; - -namespace mindspore::kernel { -int SliceBaseCPUKernel::Init() { return RET_OK; } - -int SliceBaseCPUKernel::ReSize() { - auto input_shape = in_tensors_[0]->shape(); - if (input_shape.size() > DIMENSION_4D) { - MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_4D; - return RET_ERROR; - } - - for (size_t i = 0; i < input_shape.size(); ++i) { - param_->shape_[i] = input_shape[i]; - } - - if (param_->param_length_ < DIMENSION_4D) { - for (int i = param_->param_length_ - 1, j = 1; i >= 0; --i, ++j) { - param_->begin_[DIMENSION_4D - j] = param_->begin_[i]; - param_->size_[DIMENSION_4D - j] = param_->size_[i]; - } - for (int i = 0; i < DIMENSION_4D - param_->param_length_; i++) { - param_->begin_[i] = 0; - param_->size_[i] = 1; - } - } - param_->param_length_ = DIMENSION_4D; - for (int i = 0; i < DIMENSION_4D; ++i) { - if (param_->size_[i] < 0) { - param_->size_[i] = param_->shape_[i] - param_->begin_[i]; - } - param_->end_[i] = param_->begin_[i] + param_->size_[i]; - } - - return RET_OK; -} - -kernel::LiteKernel *CpuSliceInt8KernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, - const lite::InnerContext *ctx, const kernel::KernelKey &desc, - const mindspore::lite::PrimitiveC *primitive) { - if (opParameter == nullptr) { - MS_LOG(ERROR) << "Input opParameter is nullptr!"; - return nullptr; - } - MS_ASSERT(desc.type == schema::PrimitiveType_Slice); - auto *kernel = new (std::nothrow) SliceInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new SliceInt8CPUKernel fail!"; - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - delete kernel; - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - return nullptr; - } - return kernel; -} - -kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, - const lite::InnerContext *ctx, const kernel::KernelKey &desc, - const mindspore::lite::PrimitiveC *primitive) { - if (opParameter == nullptr) { - MS_LOG(ERROR) << "Input opParameter is nullptr!"; - return nullptr; - } - MS_ASSERT(desc.type == schema::PrimitiveType_Slice); - auto *kernel = new (std::nothrow) SliceCPUKernel(opParameter, inputs, outputs, ctx, primitive); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new SliceCPUKernel fail!"; - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - delete kernel; - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - return nullptr; - } - return kernel; -} - -REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Slice, CpuSliceInt8KernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Slice, CpuSliceFp32KernelCreator) -} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.cc new file mode 100644 index 0000000000..e482e33ded --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/fp16/slice_fp16.h" +#include "src/runtime/kernel/arm/fp16/common_fp16.h" +#include "src/kernel_registry.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/fp16/slice_fp16.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::schema::PrimitiveType_Slice; + +namespace mindspore::kernel { +int SliceFp16CPUKernel::SliceParallelRun(int thread_id) { + DoSliceFp16(input_fp16_, output_fp16_, param_, thread_id); + return RET_OK; +} + +int SliceFp16CPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << ret; + return ret; + } + input_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_); + output_fp16_ = MallocOutputFp16(out_tensors_.at(0), context_); + if (input_fp16_ == nullptr || output_fp16_ == nullptr) { + FreeInputAndOutput(); + MS_LOG(ERROR) << "input or output is nullptr"; + return RET_ERROR; + } + if (param_->size_[1] < op_parameter_->thread_num_) { + DoSliceFp16NoParallel(input_fp16_, output_fp16_, param_); + return RET_OK; + } + ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "slice launch fail!ret: " << ret; + } + if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) { + Float16ToFloat32(output_fp16_, reinterpret_cast(out_tensors_.at(0)->MutableData()), + out_tensors_.at(0)->ElementsNum()); + } + FreeInputAndOutput(); + return ret; +} + +void SliceFp16CPUKernel::FreeInputAndOutput() { + if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) { + context_->allocator->Free(input_fp16_); + input_fp16_ = nullptr; + } + if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) { + context_->allocator->Free(output_fp16_); + output_fp16_ = nullptr; + } +} + +kernel::LiteKernel *CpuSliceFp16KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) SliceFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SliceFp16CPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Slice, CpuSliceFp16KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/slice_base.h b/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.h similarity index 58% rename from mindspore/lite/src/runtime/kernel/arm/base/slice_base.h rename to mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.h index be1410f6a9..48bed73e45 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/slice_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.h @@ -13,32 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SLICE_BASE_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SLICE_BASE_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_ #include -#include "src/lite_kernel.h" -#include "nnacl/slice_parameter.h" +#include "src/runtime/kernel/arm/fp32/slice.h" namespace mindspore::kernel { -class SliceBaseCPUKernel : public LiteKernel { +class SliceFp16CPUKernel : public SliceCPUKernel { public: - SliceBaseCPUKernel(OpParameter *parameter, const std::vector &inputs, + SliceFp16CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) { - param_ = reinterpret_cast(op_parameter_); - } - ~SliceBaseCPUKernel() = default; + : SliceCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + ~SliceFp16CPUKernel() = default; - int Init() override; - int ReSize() override; - int Run() override { return 0; } + int Run() override; + int SliceParallelRun(int thread_id) override; protected: - SliceParameter *param_; + void FreeInputAndOutput(); + float16_t *input_fp16_ = nullptr; + float16_t *output_fp16_ = nullptr; }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SLICE_BASE_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc index f7ebd7d2f2..3cc1e71ef9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc @@ -14,55 +14,41 @@ * limitations under the License. */ #include "src/runtime/kernel/arm/fp32/slice.h" -#include -#include "schema/model_generated.h" #include "src/kernel_registry.h" #include "nnacl/fp32/slice.h" -#include "include/errorcode.h" -#include "src/runtime/runtime_api.h" #include "src/ops/slice.h" using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_NULL_PTR; -using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Slice; namespace mindspore::kernel { -namespace { int SliceLaunch(void *cdata, int task_id) { if (cdata == nullptr) { MS_LOG(ERROR) << "Input cdata is nullptr!"; - return RET_NULL_PTR; + return RET_ERROR; } auto kernel = reinterpret_cast(cdata); return kernel->SliceParallelRun(task_id); } -} // namespace int SliceCPUKernel::ReSize() { auto primitive_slice = reinterpret_cast(primitive_); auto begin = primitive_slice->GetPostProcessBegin(); auto size = primitive_slice->GetPostProcessSize(); - auto param = reinterpret_cast(op_parameter_); - param->param_length_ = in_tensors_[0]->shape().size(); - for (int i = 0; i < param->param_length_; ++i) { - param->begin_[i] = begin[i]; - param->size_[i] = size[i]; - } - auto input_shape = in_tensors_[0]->shape(); - if (static_cast(input_shape.size()) != param->param_length_) { - MS_LOG(ERROR) << "Input begin's lenth " << param->param_length_ << "is not equal to input shape size " - << input_shape.size(); - return RET_ERROR; - } - if (input_shape.size() > DIMENSION_4D) { + + param_->param_length_ = in_tensors_.at(0)->shape().size(); + if (param_->param_length_ > DIMENSION_4D) { MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_4D; return RET_ERROR; } - - for (size_t i = 0; i < input_shape.size(); ++i) { - param->shape_[i] = input_shape[i]; + for (int i = 0; i < param_->param_length_; ++i) { + param_->shape_[i] = in_tensors_.at(0)->DimensionSize(i); + param_->begin_[i] = begin[i]; + param_->size_[i] = size[i] < 0 ? param_->shape_[i] - param_->begin_[i] : size[i]; + param_->end_[i] = param_->begin_[i] + param_->size_[i]; + } + if (param_->param_length_ < DIMENSION_4D) { + PadSliceParameterTo4D(param_); } return RET_OK; } @@ -77,8 +63,7 @@ int SliceCPUKernel::Init() { int SliceCPUKernel::SliceParallelRun(int thread_id) { const float *input_data = reinterpret_cast(in_tensors_[0]->MutableData()); float *output_data = reinterpret_cast(out_tensors_[0]->MutableData()); - SliceParameter *param = reinterpret_cast(op_parameter_); - DoSlice(input_data, output_data, param, thread_id); + DoSlice(input_data, output_data, param_, thread_id); return RET_OK; } @@ -88,29 +73,38 @@ int SliceCPUKernel::Run() { MS_LOG(ERROR) << "Prepare fail!ret: " << ret; return ret; } - SliceParameter *param = reinterpret_cast(op_parameter_); - for (int i = 0; i < param->param_length_; ++i) { - if (param->size_[i] < 0) { - param->size_[i] = param->shape_[i] - param->begin_[i]; - } - param->end_[i] = param->begin_[i] + param->size_[i]; - } - - if (param->param_length_ < DIMENSION_4D) { - PadSliceParameterTo4D(param); - } - const float *input_data = reinterpret_cast(in_tensors_[0]->MutableData()); float *output_data = reinterpret_cast(out_tensors_[0]->MutableData()); - if (param->size_[1] < param->op_parameter_.thread_num_) { - DoSliceNoParallel(input_data, output_data, param); + if (param_->size_[1] < op_parameter_->thread_num_) { + DoSliceNoParallel(input_data, output_data, param_); return RET_OK; } - ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, param->op_parameter_.thread_num_); + ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "slice launch fail!ret: " << ret; return RET_ERROR; } return RET_OK; } + +kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) SliceCPUKernel(opParameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SliceCPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Slice, CpuSliceFp32KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h index d9d85cfe93..5d71edc608 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/slice.h @@ -18,22 +18,28 @@ #include #include "src/lite_kernel.h" -#include "src/runtime/kernel/arm/base/slice_base.h" +#include "nnacl/slice_parameter.h" namespace mindspore::kernel { -class SliceCPUKernel : public SliceBaseCPUKernel { +class SliceCPUKernel : public LiteKernel { public: SliceCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : SliceBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param_ = reinterpret_cast(op_parameter_); + } ~SliceCPUKernel() = default; int Init() override; int ReSize() override; int Run() override; - int SliceParallelRun(int thread_id); + virtual int SliceParallelRun(int thread_id); + + protected: + SliceParameter *param_; }; +int SliceLaunch(void *cdata, int task_id); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SLICE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.cc index 0cd60b4952..134cbec714 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.cc @@ -16,23 +16,19 @@ #include "src/runtime/kernel/arm/int8/slice_int8.h" #include -#include "nnacl/slice_parameter.h" +#include "src/kernel_registry.h" #include "nnacl/int8/slice_int8.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" -using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Slice; namespace mindspore::kernel { int SliceInt8CPUKernel::Init() { - auto ret = SliceBaseCPUKernel::Init(); - if (ret != RET_OK) { - return ret; - } - auto input = in_tensors_.at(0); auto output = out_tensors_.at(0); MS_ASSERT(input); @@ -54,8 +50,6 @@ int SliceInt8CPUKernel::Init() { return ReSize(); } -int SliceInt8CPUKernel::ReSize() { return SliceBaseCPUKernel::ReSize(); } - int SliceInt8CPUKernel::DoSlice(int task_id) { const int8_t *input_data = reinterpret_cast(in_tensors_[0]->MutableData()); int8_t *output_data = reinterpret_cast(out_tensors_[0]->MutableData()); @@ -97,4 +91,25 @@ int SliceInt8CPUKernel::Run() { } return ret; } + +kernel::LiteKernel *CpuSliceInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) SliceInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new SliceInt8CPUKernel fail!"; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Slice, CpuSliceInt8KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h index 90e70bbd96..8dd7c8dc7b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h @@ -18,20 +18,19 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SLICE_INT8_H_ #include -#include "src/runtime/kernel/arm/base/slice_base.h" +#include "src/runtime/kernel/arm/fp32/slice.h" #include "nnacl/quantization/quantize.h" namespace mindspore::kernel { -class SliceInt8CPUKernel : public SliceBaseCPUKernel { +class SliceInt8CPUKernel : public SliceCPUKernel { public: SliceInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : SliceBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + : SliceCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~SliceInt8CPUKernel() {} int Init() override; - int ReSize() override; int Run() override; int DoSlice(int task_id); };