
!6935 [MSLITE][Develop] fix concat fp16

Merge pull request !6935 from sunsuodong/fix_concat
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 5b122cd1eb
3 changed files with 9 additions and 12 deletions
  1. mindspore/lite/nnacl/fp16/concat_fp16.c (+4, -3)
  2. mindspore/lite/nnacl/fp16/concat_fp16.h (+2, -1)
  3. mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc (+3, -8)

mindspore/lite/nnacl/fp16/concat_fp16.c (+4, -3)

@@ -17,13 +17,14 @@
 #include "nnacl/fp16/concat_fp16.h"
 #include <string.h>
 
-void ConcatFp16(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output) {
+void ConcatFp16(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output,
+                int dtype_len) {
   int before_axis_size = 1;
   for (int i = 0; i < axis; ++i) {
     before_axis_size *= inputs_output_shape[0][i];
   }
-  // sizeof float16 / byte
-  int after_axis_size = 2;
+  // sizeof float16, int32
+  int after_axis_size = dtype_len;
   for (size_t i = axis + 1; i < shape_size; ++i) {
     after_axis_size *= inputs_output_shape[0][i];
   }

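The hunk above only shows the size computation; the copy loop that actually consumes dtype_len lies outside the diff context. Below is a minimal sketch of how such a loop could use after_axis_size as a byte count. The loop structure, the sketch's function name, and the assumption that inputs_output_shape holds each input's shape followed by the output shape are illustrative guesses, not the actual nnacl implementation.

/* Illustrative sketch only, not the nnacl source: shows how a concat copy
 * loop can treat after_axis_size as a byte count derived from dtype_len,
 * so one routine handles float16 (2-byte) and int32 (4-byte) elements. */
#include <stddef.h>
#include <string.h>

void ConcatByAxisSketch(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size,
                        void *output, int dtype_len) {
  int before_axis_size = 1;
  for (int i = 0; i < axis; ++i) {
    before_axis_size *= inputs_output_shape[0][i];
  }
  /* after_axis_size is in bytes: element size times all dims after the axis. */
  int after_axis_size = dtype_len;
  for (size_t i = axis + 1; i < shape_size; ++i) {
    after_axis_size *= inputs_output_shape[0][i];
  }
  int dst_offset = 0;
  for (int k = 0; k < before_axis_size; ++k) {
    for (int n = 0; n < input_num; ++n) {
      /* Bytes contributed by input n along the concat axis for one outer slice. */
      int copy_size = inputs_output_shape[n][axis] * after_axis_size;
      memcpy((char *)output + dst_offset, (const char *)input[n] + k * copy_size, copy_size);
      dst_offset += copy_size;
    }
  }
}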

mindspore/lite/nnacl/fp16/concat_fp16.h (+2, -1)

@@ -22,7 +22,8 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void ConcatFp16(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output);
+void ConcatFp16(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output,
+                int dtype_len);
 #ifdef __cplusplus
 }
 #endif


mindspore/lite/src/runtime/kernel/arm/fp16/concat_fp16.cc (+3, -8)

@@ -129,9 +129,9 @@ int ConcatFp16CPUKernel::Run() {
   if (out_tensors_.at(0)->data_type() == kNumberTypeFloat16) {
     fp16_output_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData());
   }
+  int dtype_len = in_tensors_.at(0)->data_type() == kNumberTypeInt32 ? sizeof(int32_t) : sizeof(float16_t);
   ConcatFp16(reinterpret_cast<void **>(fp16_inputs_.data()), input_num, axis_, inputs_output_shape.data(),
-             output_shape.size(), reinterpret_cast<void *>(fp16_output_));
+             output_shape.size(), reinterpret_cast<void *>(fp16_output_), dtype_len);
 
   if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32 || out_tensors_.at(0)->data_type() == kNumberTypeFloat) {
     Float16ToFloat32(fp16_output_, reinterpret_cast<float *>(output_addr), out_tensors_.at(0)->ElementsNum());

@@ -148,12 +148,7 @@ kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector<lite::Tensor *>
     MS_LOG(ERROR) << "Input parameter is nullptr!";
     return nullptr;
   }
-  kernel::LiteKernel *kernel = nullptr;
-  if (IsExistFp16Tensor(inputs, outputs)) {
-    kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive);
-  } else {
-    kernel = new (std::nothrow) ConcatCPUKernel(parameter, inputs, outputs, ctx, primitive);
-  }
+  kernel::LiteKernel *kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new ConcatCPUKernel fail!";
     return nullptr;

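For illustration, a hedged usage sketch of the updated interface: concatenating two int32 tensors of shape [2, 3] along axis 0. The example function name and the layout of inputs_output_shape (each input's shape followed by the output shape) are assumptions based on how the kernel appears to build its arguments; with the old hardcoded 2-byte element size, only half of each int32 row would have been copied.

/* Usage sketch only: assumes nnacl/fp16/concat_fp16.h is on the include path
 * and that inputs_output_shape lists each input shape and then the output shape. */
#include <stdint.h>
#include "nnacl/fp16/concat_fp16.h"

void ConcatInt32Example(void) {
  int32_t a[6] = {0, 1, 2, 3, 4, 5};
  int32_t b[6] = {6, 7, 8, 9, 10, 11};
  int32_t out[12] = {0};

  int shape_a[2] = {2, 3};
  int shape_b[2] = {2, 3};
  int shape_out[2] = {4, 3};
  int *shapes[3] = {shape_a, shape_b, shape_out};
  void *inputs[2] = {a, b};

  /* dtype_len = sizeof(int32_t) keeps the byte arithmetic correct for int32 data. */
  ConcatFp16(inputs, 2, 0, shapes, 2, out, (int)sizeof(int32_t));
}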
