|
|
|
@@ -13,10 +13,10 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
#include "nnacl/fp16/cast_fp16.h" |
|
|
|
#include "nnacl/fp16/split_fp16.h" |
|
|
|
#include "src/runtime/kernel/arm/fp16/split_fp16.h" |
|
|
|
#include "src/runtime/kernel/arm/fp16/common_fp16.h" |
|
|
|
#include "src/runtime/kernel/arm/base/split_base.h" |
|
|
|
#include "nnacl/fp16/split_fp16.h" |
|
|
|
#include "nnacl/split.h" |
|
|
|
#include "nnacl/split_parameter.h" |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
@@ -36,9 +36,10 @@ int SplitFp16CPUKernel::Init() { |
|
|
|
if (ret != RET_OK) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
output_ptr_.resize(param->num_split_); |
|
|
|
|
|
|
|
for (size_t i = 0; i < output_ptr_.size(); i++) { |
|
|
|
output_ptr_[i] = nullptr; |
|
|
|
} |
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -79,48 +80,37 @@ int SplitFp16CPUKernel::Run() { |
|
|
|
MS_LOG(ERROR) << "Prepare failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto in_tensor = in_tensors_.front(); |
|
|
|
if (in_tensor->data_type() == kNumberTypeFloat32) { |
|
|
|
input_ptr_ = |
|
|
|
reinterpret_cast<float16_t *>(context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float16_t))); |
|
|
|
if (input_ptr_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "malloc input_ptr_ failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(in_tensor->MutableData()), input_ptr_, in_tensor->ElementsNum()); |
|
|
|
} else { |
|
|
|
input_ptr_ = reinterpret_cast<float16_t *>(in_tensor->MutableData()); |
|
|
|
input_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_); |
|
|
|
if (input_ptr_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "input or output is nullptr"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
for (int i = 0; i < param->num_split_; i++) { |
|
|
|
if (in_tensor->data_type() == kNumberTypeFloat32) { |
|
|
|
output_ptr_[i] = reinterpret_cast<float16_t *>( |
|
|
|
context_->allocator->Malloc(out_tensors_.at(i)->ElementsNum() * sizeof(float16_t))); |
|
|
|
if (output_ptr_[i] == nullptr) { |
|
|
|
MS_LOG(ERROR) << "malloc output_ptr_[" << i << "]" << " failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(out_tensors_.at(i)->MutableData()), output_ptr_[i], |
|
|
|
out_tensors_.at(i)->ElementsNum()); |
|
|
|
} else { |
|
|
|
output_ptr_[i] = reinterpret_cast<float16_t *>(out_tensors_.at(i)->MutableData()); |
|
|
|
output_ptr_[i] = MallocOutputFp16(out_tensors_.at(i), context_); |
|
|
|
if (output_ptr_[i] == nullptr) { |
|
|
|
FreeInputAndOutput(); |
|
|
|
MS_LOG(ERROR) << "input or output is nullptr"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
ret = ParallelLaunch(this->context_->thread_pool_, SplitRun, this, thread_n_num_); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "split error error_code[" << ret << "]"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (in_tensor->data_type() == kNumberTypeFloat32) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
void SplitFp16CPUKernel::FreeInputAndOutput() { |
|
|
|
if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) { |
|
|
|
context_->allocator->Free(input_ptr_); |
|
|
|
input_ptr_ = nullptr; |
|
|
|
} |
|
|
|
for (int i = 0; i < param->num_split_; i++) { |
|
|
|
if (in_tensor->data_type() == kNumberTypeFloat32) { |
|
|
|
if (out_tensors_.at(i)->data_type() == kNumberTypeFloat32) { |
|
|
|
context_->allocator->Free(output_ptr_[i]); |
|
|
|
output_ptr_[i] = nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
kernel::LiteKernel *CpuSplitFp16KernelCreator(const std::vector<lite::Tensor *> &inputs, |
|
|
|
|