|
|
|
@@ -215,7 +215,6 @@ int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_inter |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
#ifndef SUPPORT_TRAIN |
|
|
|
int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origin_tensors, TypeId dst_data_type) { |
|
|
|
#if defined(ENABLE_ARM) && defined(ENABLE_FP16) |
|
|
|
MS_ASSERT(tensor != nullptr); |
|
|
|
@@ -319,7 +318,6 @@ int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) { |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
inline void FreeRestoreTensors(std::map<Tensor *, Tensor *> *restored_origin_tensors) { |
|
|
|
MS_ASSERT(restored_origin_tensors != nullptr); |
|
|
|
@@ -368,19 +366,20 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::map<Tensor *, Tensor *> restored_origin_tensors; |
|
|
|
#ifndef SUPPORT_TRAIN |
|
|
|
ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// we don't need to restore tensor for copy data |
|
|
|
ret = CopyConstTensorData(in_tensors, op_type); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret; |
|
|
|
return nullptr; |
|
|
|
|
|
|
|
if (!is_train_session_) { |
|
|
|
ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// we don't need to restore tensor for copy data |
|
|
|
ret = CopyConstTensorData(in_tensors, op_type); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
#endif |
|
|
|
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter); |
|
|
|
if (kernel != nullptr) { |
|
|
|
MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type); |
|
|
|
|