|
|
|
@@ -17,6 +17,7 @@ |
|
|
|
#include "src/runtime/kernel/arm/fp16/common_fp16.h" |
|
|
|
#include "src/runtime/kernel/arm/base/split_base.h" |
|
|
|
#include "nnacl/fp16/split_fp16.h" |
|
|
|
#include "nnacl/fp16/cast_fp16.h" |
|
|
|
#include "nnacl/split.h" |
|
|
|
#include "nnacl/split_parameter.h" |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
@@ -94,6 +95,13 @@ int SplitFp16CPUKernel::Run() { |
|
|
|
} |
|
|
|
} |
|
|
|
ret = ParallelLaunch(this->context_->thread_pool_, SplitRun, this, thread_n_num_); |
|
|
|
for (int i = 0; i < param->num_split_; i++) { |
|
|
|
if (out_tensors_.at(i)->data_type() == kNumberTypeFloat32) { |
|
|
|
Float16ToFloat32(output_ptr_[i], reinterpret_cast<float *>(out_tensors_.at(i)->MutableData()), |
|
|
|
out_tensors_.at(i)->ElementsNum()); |
|
|
|
} |
|
|
|
} |
|
|
|
FreeInputAndOutput(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "split error error_code[" << ret << "]"; |
|
|
|
} |
|
|
|
|