From 437f54ed4ecef9a55ce826766ce13b6d0d2c0752 Mon Sep 17 00:00:00 2001 From: zhangxuetong Date: Tue, 18 Aug 2020 10:12:36 +0800 Subject: [PATCH] add fp16 reshape and fix register kernel bug --- mindspore/lite/src/kernel_registry.cc | 16 +++- .../runtime/kernel/arm/fp16/reshape_fp16.cc | 76 +++++++++++++++++++ .../runtime/kernel/arm/fp16/reshape_fp16.h | 43 +++++++++++ 3 files changed, 132 insertions(+), 3 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/reshape_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/fp16/reshape_fp16.h diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index 351ac83752..e30c91544d 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -82,6 +82,11 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { return nullptr; } int index = GetCreatorFuncIndex(desc); + if (index >= array_size_) { + MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " + << desc.type; + return nullptr; + } auto it = creator_arrays_[index]; if (it != nullptr) { return it; @@ -91,9 +96,9 @@ kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) { int KernelRegistry::GetCreatorFuncIndex(const kernel::KernelKey desc) { int index; - int device_index = static_cast(desc.arch); - int dType_index = static_cast(desc.data_type); - int op_index = static_cast(desc.type); + int device_index = static_cast(desc.arch) - kKernelArch_MIN; + int dType_index = static_cast(desc.data_type) - kNumberTypeBegin; + int op_index = static_cast(desc.type) - PrimitiveType_MIN; index = device_index * data_type_length_ * op_type_length_ + dType_index * op_type_length_ + op_index; return index; } @@ -115,6 +120,11 @@ void KernelRegistry::RegKernel(const KERNEL_ARCH arch, const TypeId data_type, c } KernelKey desc = {arch, data_type, op_type}; int index = GetCreatorFuncIndex(desc); + if (index >= array_size_) { + MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type " + << desc.type; + return; + } creator_arrays_[index] = creator; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/reshape_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/reshape_fp16.cc new file mode 100644 index 0000000000..037f4d53d1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/reshape_fp16.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp16/reshape_fp16.h" +#include +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/reshape.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Reshape; + +namespace mindspore::kernel { + +int ReshapeCPUKernel::Run() { + auto ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail!ret: " << ret; + return ret; + } + auto in_tensor = in_tensors_.at(kInputIndex); + auto out_tensor = out_tensors_.at(kOutputIndex); + auto input_ptr = in_tensor->Data(); + auto output_ptr = out_tensor->Data(); + size_t data_size = out_tensor->Size(); + + auto in_datatype = in_tensor->data_type(); + auto out_datatype = out_tensor->data_type(); + if (in_datatype != out_datatype) { + if (in_datatype == kNumberTypeFloat32 && out_datatype == kNumberTypeFloat16) { + input_ptr = context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float16_t)); + if (input_ptr == nullptr) { + MS_LOG(ERROR) << "malloc in tensor fail!"; + return mindspore::lite::RET_MEMORY_FAILED; + } + Float32ToFloat16(reinterpret_cast(in_tensor->Data()), reinterpret_cast(input_ptr), + in_tensor->ElementsNum()); + } else if ((in_datatype == kNumberTypeFloat16 && out_datatype == kNumberTypeFloat32)) { + input_ptr = context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float)); + if (input_ptr == nullptr) { + MS_LOG(ERROR) << "malloc in tensor fail!"; + return mindspore::lite::RET_MEMORY_FAILED; + } + Float16ToFloat32(reinterpret_cast(in_tensor->Data()), reinterpret_cast(input_ptr), + in_tensor->ElementsNum()); + } else { + MS_LOG(ERROR) << "unsupported data type, in_datatype: " << in_datatype << ",out_datatype: " << out_datatype; + return RET_ERROR; + } + } + + Reshape(input_ptr, output_ptr, data_size); + if (in_datatype != out_datatype) { + context_->allocator->Free(input_ptr); + } + return RET_OK; +} // namespace mindspore::kernel +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/reshape_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/reshape_fp16.h new file mode 100644 index 0000000000..b501f554e8 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/reshape_fp16.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_ + +#include +#include "src/lite_kernel.h" + +#include "include/context.h" +#include "src/runtime/kernel/arm/fp32/reshape.h" + +using mindspore::lite::Context; + +namespace mindspore::kernel { +class ReshapeFp16CPUKernel : public ReshapeCPUKernel { + public: + ReshapeFp16CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const Context *ctx, + const mindspore::lite::PrimitiveC *primitive) + : ReshapeCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + ~ReshapeFp16CPUKernel() = default; + + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_RESHAPE_H_