|
|
@@ -129,9 +129,9 @@ int ConcatFp16CPUKernel::Run() { |
|
|
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat16) { |
|
|
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat16) { |
|
|
fp16_output_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData()); |
|
|
fp16_output_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int dtype_len = in_tensors_.at(0)->data_type() == kNumberTypeInt32 ? sizeof(int32_t) : sizeof(float16_t); |
|
|
ConcatFp16(reinterpret_cast<void **>(fp16_inputs_.data()), input_num, axis_, inputs_output_shape.data(), |
|
|
ConcatFp16(reinterpret_cast<void **>(fp16_inputs_.data()), input_num, axis_, inputs_output_shape.data(), |
|
|
output_shape.size(), reinterpret_cast<void *>(fp16_output_)); |
|
|
|
|
|
|
|
|
output_shape.size(), reinterpret_cast<void *>(fp16_output_), dtype_len); |
|
|
|
|
|
|
|
|
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32 || out_tensors_.at(0)->data_type() == kNumberTypeFloat) { |
|
|
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32 || out_tensors_.at(0)->data_type() == kNumberTypeFloat) { |
|
|
Float16ToFloat32(fp16_output_, reinterpret_cast<float *>(output_addr), out_tensors_.at(0)->ElementsNum()); |
|
|
Float16ToFloat32(fp16_output_, reinterpret_cast<float *>(output_addr), out_tensors_.at(0)->ElementsNum()); |
|
|
@@ -148,12 +148,7 @@ kernel::LiteKernel *CpuConcatFp16KernelCreator(const std::vector<lite::Tensor *> |
|
|
MS_LOG(ERROR) << "Input parameter is nullptr!"; |
|
|
MS_LOG(ERROR) << "Input parameter is nullptr!"; |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
kernel::LiteKernel *kernel = nullptr; |
|
|
|
|
|
if (IsExistFp16Tensor(inputs, outputs)) { |
|
|
|
|
|
kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive); |
|
|
|
|
|
} else { |
|
|
|
|
|
kernel = new (std::nothrow) ConcatCPUKernel(parameter, inputs, outputs, ctx, primitive); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
kernel::LiteKernel *kernel = new (std::nothrow) ConcatFp16CPUKernel(parameter, inputs, outputs, ctx, primitive); |
|
|
if (kernel == nullptr) { |
|
|
if (kernel == nullptr) { |
|
|
MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; |
|
|
MS_LOG(ERROR) << "new ConcatCPUKernel fail!"; |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
|