@@ -217,71 +217,105 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt)
namespace {
#ifndef SUPPORT_TRAIN
// Cast a const fp32/fp16 tensor to dst_data_type in place, keeping the origin
// buffer in restored_origin_tensors so the cast can be undone later.
int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) {
  MS_ASSERT(restored_origin_tensors != nullptr);
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
  MS_ASSERT(tensor != nullptr);
  MS_ASSERT(tensor->IsConst());
  MS_ASSERT(tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeFloat16);
  MS_ASSERT(dst_data_type == kNumberTypeFloat32 || dst_data_type == kNumberTypeFloat16);
  if (tensor->data_type() == dst_data_type) {
    return RET_OK;
  }
  auto origin_data = tensor->data_c();
  MS_ASSERT(origin_data != nullptr);
  // Detach the origin buffer into a shell tensor before re-typing.
  auto restore_tensor = Tensor::CopyTensor(*tensor, false);
  restore_tensor->set_data(origin_data);
  restore_tensor->set_own_data(tensor->own_data());
  tensor->set_data(nullptr);
  tensor->set_data_type(dst_data_type);
  auto ret = tensor->MallocData();
  if (RET_OK != ret) {
    MS_LOG(ERROR) << "malloc data failed";
    return ret;
  }
  auto new_tensor_data = tensor->data_c();
  MS_ASSERT(new_tensor_data != nullptr);
  if (dst_data_type == kNumberTypeFloat32) {
    Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
  } else {  // dst_data_type == kNumberTypeFloat16
    Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum());
  }
  if (restored_origin_tensors->find(tensor) != restored_origin_tensors->end()) {
    MS_LOG(ERROR) << "Tensor " << tensor->tensor_name() << " is already stored";
    return RET_ERROR;
  }
  (*restored_origin_tensors)[tensor] = restore_tensor;
  return RET_OK;
#else
  MS_LOG(ERROR) << "Unsupported dst data type: float16";
  return RET_NOT_SUPPORT;
#endif
}
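
// For reference: each map entry pairs a casted tensor with a detached shell
// that still holds (and possibly owns) the origin buffer. Below is a minimal
// sketch of the inverse restore step, using only the Tensor accessors seen
// above; the helper name is illustrative, not the scheduler's confirmed API.
void RestoreTensorDataSketch(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
  for (auto &pair : *restored_origin_tensors) {
    auto *tensor = pair.first;
    auto *restore_tensor = pair.second;
    tensor->FreeData();  // release the casted buffer allocated by MallocData
    tensor->set_data_type(restore_tensor->data_type());
    tensor->set_data(restore_tensor->data_c());
    tensor->set_own_data(restore_tensor->own_data());
    restore_tensor->set_data(nullptr);  // ownership moved back to tensor
    delete restore_tensor;
  }
  restored_origin_tensors->clear();
}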
int CastConstTensorsData(const std::vector<Tensor *> &tensors, std::map<Tensor *, Tensor *> *restored_origin_tensors,
                         TypeId dst_data_type) {
  MS_ASSERT(restored_origin_tensors != nullptr);
  if (dst_data_type != kNumberTypeFloat32 && dst_data_type != kNumberTypeFloat16) {
    MS_LOG(ERROR) << "Only support fp32 or fp16 as dst_data_type.";
    return RET_PARAM_INVALID;
  }
  for (auto *tensor : tensors) {
    MS_ASSERT(tensor != nullptr);
    // only cast const tensor
    // tensorlist not support fp16 now
    if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
      continue;
    }
    // only support fp32->fp16 or fp16->fp32
    if (tensor->data_type() != kNumberTypeFloat32 && tensor->data_type() != kNumberTypeFloat16) {
      continue;
    }
    if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) {
      auto ret = CastConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat16);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Cast const tensor from fp32 to fp16 failed, tensor name: " << tensor->tensor_name();
        return ret;
      }
    } else if (tensor->data_type() == kNumberTypeFloat16 && dst_data_type == kNumberTypeFloat32) {
      auto ret = CastConstTensorData(tensor, restored_origin_tensors, kNumberTypeFloat32);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Cast const tensor from fp16 to fp32 failed, tensor name: " << tensor->tensor_name();
        return ret;
      }
    } else {
      MS_LOG(DEBUG) << "No need to cast from " << tensor->data_type() << " to " << dst_data_type;
    }
  }
  return RET_OK;
}
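
// Counterpart for the path where the casted weights are kept: the shells in
// the map still hold the origin buffers and must be released. A sketch under
// the assumption that Tensor's destructor frees its buffer only when
// own_data() is true; again an illustrative helper, not confirmed API.
void FreeRestoreTensorDataSketch(std::map<Tensor *, Tensor *> *restored_origin_tensors) {
  for (auto &pair : *restored_origin_tensors) {
    delete pair.second;  // frees the origin buffer only if the shell owns it
  }
  restored_origin_tensors->clear();
}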
int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
  // packed kernels such as conv don't need to copy because weight will be packed in kernel
  if (IsPackedOp(op_type)) {
    return RET_OK;
  }
  for (auto *tensor : tensors) {
    // only copy const tensor
    // tensorlist not support fp16 now
    if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
      continue;
    }
    if (tensor->own_data()) {
      continue;
    }
    auto copy_tensor = Tensor::CopyTensor(*tensor, true);
    if (copy_tensor == nullptr) {
      MS_LOG(ERROR) << "Copy tensor failed";
      return RET_ERROR;
    }
    tensor->FreeData();
    tensor->set_data(copy_tensor->data_c());
    tensor->set_own_data(true);
    // Detach the buffer from the copy so its destructor does not free it.
    copy_tensor->set_data(nullptr);
    delete (copy_tensor);
  }
  return RET_OK;
}
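
// Illustrative only (this helper is not part of the patch): what the loop
// above guarantees for a const tensor that aliases the model buffer
// (own_data() == false) when op_type is not a packed op. CopyTensor duplicates
// the payload, the tensor adopts the duplicate's heap buffer, and the empty
// shell is deleted, so the model buffer is never freed and the copy is never
// double-freed.
void CopyConstTensorDataExample(Tensor *weight, int op_type) {
  MS_ASSERT(weight->IsConst() && !weight->own_data());  // data() aliases the model buffer
  if (!IsPackedOp(op_type) && CopyConstTensorData({weight}, op_type) == RET_OK) {
    MS_ASSERT(weight->own_data());  // weight now holds a private heap copy
  }
}
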
@@ -335,12 +369,16 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors,
  }
  std::map<Tensor *, Tensor *> restored_origin_tensors;
#ifndef SUPPORT_TRAIN
  ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
  if (ret != RET_OK) {
    MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
    return nullptr;
  }
  // we don't need to restore tensor for copy data
  ret = CopyConstTensorData(in_tensors, op_type);
  if (ret != RET_OK) {
    MS_LOG(DEBUG) << "CopyConstTensorData failed: " << ret;
    return nullptr;
  }
#endif
  auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
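  // Sketch of how the probe result and restored_origin_tensors would typically
  // be reconciled next (helper names refer to the illustrative sketches above,
  // not confirmed scheduler API):
  //
  //   if (kernel != nullptr) {
  //     FreeRestoreTensorDataSketch(&restored_origin_tensors);  // keep casted weights
  //   } else {
  //     RestoreTensorDataSketch(&restored_origin_tensors);      // roll back to origin data
  //   }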