From 2bdd18421256e641a91166bafb249a0cebd961a6 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Wed, 16 Sep 2020 14:27:23 +0800 Subject: [PATCH] fix concat fp16 when tensor is int32 --- .../runtime/kernel/arm/fp16/common_fp16.cc | 20 +++++++++++++++ .../src/runtime/kernel/arm/fp16/common_fp16.h | 2 ++ .../runtime/kernel/arm/fp16/concat_fp16.cc | 25 +++++++++++-------- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc index 6e733b77df..cd984734f8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.cc @@ -43,4 +43,24 @@ float16_t *MallocOutputFp16(lite::Tensor *output, const lite::Context *ctx) { } return fp16_data; } + +bool IsExistFp16Tensor(const std::vector &inputs, const std::vector &outputs) { + bool result = false; + for (auto &input : inputs) { + if (input->data_type() == kNumberTypeFloat16) { + result = true; + break; + } + } + if (result) { + return true; + } + for (auto &output : outputs) { + if (output->data_type() == kNumberTypeFloat16) { + result = true; + break; + } + } + return result; +} } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h index 1d056f0517..79309ebc6b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/common_fp16.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_ +#include #include "src/lite_kernel.h" namespace mindspore::kernel { @@ -23,6 +24,7 @@ float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::Context *ctx) float16_t *MallocOutputFp16(lite::Tensor *output, const lite::Context *ctx); +bool IsExistFp16Tensor(const std::vector &inputs, const std::vector &outputs); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_COMMON_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc index d288df203f..fea1bff604 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc @@ -13,12 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#include -#include "nnacl/fp16/concat_fp16.h" #include "src/runtime/kernel/arm/fp16/concat_fp16.h" +#include "src/runtime/kernel/arm/fp16/common_fp16.h" +#include "src/runtime/kernel/arm/fp32/concat.h" +#include "nnacl/fp16/concat_fp16.h" #include "src/kernel_registry.h" -#include "schema/model_generated.h" #include "include/errorcode.h" #include "nnacl/fp16/cast_fp16.h" @@ -142,24 +141,28 @@ int ConcatFp16CPUKernel::Run() { } kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, + const std::vector &outputs, OpParameter *parameter, const Context *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { - if (opParameter == nullptr) { - MS_LOG(ERROR) << "Input opParameter is nullptr!"; + if (parameter == nullptr) { + MS_LOG(ERROR) << "Input parameter is nullptr!"; return nullptr; } - MS_ASSERT(desc.type == schema::PrimitiveType_Concat); - auto *kernel = new (std::nothrow) ConcatFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); + kernel::LiteKernel *kernel = nullptr; + if (IsExistFp16Tensor(inputs, outputs)) { + kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive); + } else { + kernel = new (std::nothrow) ConcatCPUKernel(parameter, inputs, outputs, ctx, primitive); + } if (kernel == nullptr) { MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; return nullptr; } auto ret = kernel->Init(); if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); delete kernel; - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); return nullptr; } return kernel;