| @@ -41,15 +41,7 @@ int ConcatFp16CPUKernel::Init() { | |||||
| return ReSize(); | return ReSize(); | ||||
| } | } | ||||
| int ConcatFp16CPUKernel::ReSize() { | |||||
| FreeTmpBuffer(); | |||||
| auto ret = MallocTmpBuffer(); | |||||
| if (ret != RET_OK) { | |||||
| FreeTmpBuffer(); | |||||
| return ret; | |||||
| } | |||||
| return ConcatBaseCPUKernel::ReSize(); | |||||
| } | |||||
| int ConcatFp16CPUKernel::ReSize() { return ConcatBaseCPUKernel::ReSize(); } | |||||
| int ConcatFp16CPUKernel::MallocTmpBuffer() { | int ConcatFp16CPUKernel::MallocTmpBuffer() { | ||||
| for (const auto &in_tensor : in_tensors_) { | for (const auto &in_tensor : in_tensors_) { | ||||
| @@ -105,6 +97,13 @@ int ConcatFp16CPUKernel::Run() { | |||||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | ||||
| return prepare_ret; | return prepare_ret; | ||||
| } | } | ||||
| auto ret = MallocTmpBuffer(); | |||||
| if (ret != RET_OK) { | |||||
| FreeTmpBuffer(); | |||||
| return ret; | |||||
| } | |||||
| auto input_num = in_tensors_.size(); | auto input_num = in_tensors_.size(); | ||||
| std::vector<int *> inputs_output_shape(input_num + 1, nullptr); | std::vector<int *> inputs_output_shape(input_num + 1, nullptr); | ||||
| @@ -58,17 +58,7 @@ int ReduceFp16CPUKernel::Init() { | |||||
| } | } | ||||
| int ReduceFp16CPUKernel::ReSize() { | int ReduceFp16CPUKernel::ReSize() { | ||||
| FreeTmpBuffer(); | |||||
| auto ret = ReduceBaseCPUKernel::ReSize(); | |||||
| if (ret != RET_OK) { | |||||
| return ret; | |||||
| } | |||||
| ret = MallocTmpBuffer(); | |||||
| if (ret != RET_OK) { | |||||
| FreeTmpBuffer(); | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | |||||
| return ReduceBaseCPUKernel::ReSize(); | |||||
| } | } | ||||
| int ReduceFp16CPUKernel::CallReduceUnit(int task_id) { | int ReduceFp16CPUKernel::CallReduceUnit(int task_id) { | ||||
| @@ -94,6 +84,12 @@ int ReduceFp16CPUKernel::Run() { | |||||
| return prepare_ret; | return prepare_ret; | ||||
| } | } | ||||
| auto ret = MallocTmpBuffer(); | |||||
| if (ret != RET_OK) { | |||||
| FreeTmpBuffer(); | |||||
| return ret; | |||||
| } | |||||
| tmp_shape_ = in_tensors_.at(0)->shape(); | tmp_shape_ = in_tensors_.at(0)->shape(); | ||||
| auto in_tensor = in_tensors_.at(0); | auto in_tensor = in_tensors_.at(0); | ||||
| if (in_tensor->data_type() == kNumberTypeFloat32 || in_tensor->data_type() == kNumberTypeFloat) { | if (in_tensor->data_type() == kNumberTypeFloat32 || in_tensor->data_type() == kNumberTypeFloat) { | ||||
| @@ -59,12 +59,6 @@ int TransposeFp16CPUKernel::ReSize() { | |||||
| param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1]; | param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1]; | ||||
| } | } | ||||
| FreeFp16Buffer(); | |||||
| auto ret = MallocFp16Buffer(); | |||||
| if (ret != RET_OK) { | |||||
| FreeFp16Buffer(); | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -149,10 +143,16 @@ int TransposeFp16CPUKernel::Run() { | |||||
| auto &out_tensor = out_tensors_.front(); | auto &out_tensor = out_tensors_.front(); | ||||
| if (in_tensor == nullptr || out_tensor == nullptr) { | if (in_tensor == nullptr || out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "null pointer referencing."; | MS_LOG(ERROR) << "null pointer referencing."; | ||||
| FreeFp16Buffer(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| // malloc when Run | |||||
| ret = MallocFp16Buffer(); | |||||
| if (ret != RET_OK) { | |||||
| FreeFp16Buffer(); | |||||
| return ret; | |||||
| } | |||||
| if (in_tensor->data_type() == kNumberTypeFloat || in_tensor->data_type() == kNumberTypeFloat32) { | if (in_tensor->data_type() == kNumberTypeFloat || in_tensor->data_type() == kNumberTypeFloat32) { | ||||
| in_data_ = reinterpret_cast<float *>(in_tensor->Data()); | in_data_ = reinterpret_cast<float *>(in_tensor->Data()); | ||||
| Float32ToFloat16(in_data_, fp16_in_data_, in_tensor->ElementsNum()); | Float32ToFloat16(in_data_, fp16_in_data_, in_tensor->ElementsNum()); | ||||