diff --git a/mindspore/lite/nnacl/fp16/concat_fp16.c b/mindspore/lite/nnacl/fp16/concat_fp16.c
index 25984f82fe..de47da48d0 100644
--- a/mindspore/lite/nnacl/fp16/concat_fp16.c
+++ b/mindspore/lite/nnacl/fp16/concat_fp16.c
@@ -17,13 +17,14 @@
 #include "nnacl/fp16/concat_fp16.h"
 #include <string.h>
 
-void ConcatFp16(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output) {
+void ConcatFp16(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output,
+                int dtype_len) {
   int before_axis_size = 1;
   for (int i = 0; i < axis; ++i) {
     before_axis_size *= inputs_output_shape[0][i];
   }
-  // sizeof float16 / byte
-  int after_axis_size = 2;
+  // sizeof float16, int32
+  int after_axis_size = dtype_len;
   for (size_t i = axis + 1; i < shape_size; ++i) {
     after_axis_size *= inputs_output_shape[0][i];
   }
diff --git a/mindspore/lite/nnacl/fp16/concat_fp16.h b/mindspore/lite/nnacl/fp16/concat_fp16.h
index 786471cbbf..ae9e1bf618 100644
--- a/mindspore/lite/nnacl/fp16/concat_fp16.h
+++ b/mindspore/lite/nnacl/fp16/concat_fp16.h
@@ -22,7 +22,8 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void ConcatFp16(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output);
+void ConcatFp16(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output,
+                int dtype_len);
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc
index b7009b378c..70eb457562 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc
@@ -129,9 +129,9 @@ int ConcatFp16CPUKernel::Run() {
   if (out_tensors_.at(0)->data_type() == kNumberTypeFloat16) {
     fp16_output_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData());
   }
-
+  int dtype_len = in_tensors_.at(0)->data_type() == kNumberTypeInt32 ? sizeof(int32_t) : sizeof(float16_t);
   ConcatFp16(reinterpret_cast<void **>(fp16_inputs_.data()), input_num, axis_, inputs_output_shape.data(),
-             output_shape.size(), reinterpret_cast<void *>(fp16_output_));
+             output_shape.size(), reinterpret_cast<void *>(fp16_output_), dtype_len);
   if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32 || out_tensors_.at(0)->data_type() == kNumberTypeFloat) {
     Float16ToFloat32(fp16_output_, reinterpret_cast<float *>(output_addr), out_tensors_.at(0)->ElementsNum());
@@ -148,12 +148,7 @@ kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector<lite::Tensor *
     MS_LOG(ERROR) << "Input parameter is nullptr!";
     return nullptr;
   }
-  kernel::LiteKernel *kernel = nullptr;
-  if (IsExistFp16Tensor(inputs, outputs)) {
-    kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive);
-  } else {
-    kernel = new (std::nothrow) ConcatCPUKernel(parameter, inputs, outputs, ctx, primitive);
-  }
+  kernel::LiteKernel *kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
     return nullptr;
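
For reference, below is a minimal sketch (not part of the patch) of how a byte-copy concat loop can consume the new dtype_len parameter so the same routine handles both float16 (2-byte) and int32 (4-byte) tensors. Only the before_axis_size / after_axis_size computation is taken from the diff above; the memcpy loop, the function name ConcatByteCopySketch, and the assumption that inputs_output_shape[k] holds the k-th input's shape are illustrative guesses, not the actual body of ConcatFp16.

#include <stddef.h>
#include <string.h>

/* Illustrative sketch: concatenate input_num tensors along `axis` by raw byte copies.
 * after_axis_size starts from dtype_len, so it measures bytes per "after-axis" slice
 * and the loop works for any element size (float16, int32, ...). */
void ConcatByteCopySketch(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size,
                          void *output, int dtype_len) {
  int before_axis_size = 1;
  for (int i = 0; i < axis; ++i) {
    before_axis_size *= inputs_output_shape[0][i];
  }
  int after_axis_size = dtype_len;  /* bytes, per the patch */
  for (size_t i = axis + 1; i < shape_size; ++i) {
    after_axis_size *= inputs_output_shape[0][i];
  }
  char *dst = (char *)output;
  for (int b = 0; b < before_axis_size; ++b) {
    for (int k = 0; k < input_num; ++k) {
      /* Assumption: entry k of inputs_output_shape is the k-th input's shape,
       * so this is the byte count tensor k contributes to one before-axis slice. */
      int copy_bytes = inputs_output_shape[k][axis] * after_axis_size;
      const char *src = (const char *)input[k] + (size_t)b * (size_t)copy_bytes;
      memcpy(dst, src, (size_t)copy_bytes);
      dst += copy_bytes;
    }
  }
}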