| @@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co | |||||
| CheckGlobalDeviceManager(); | CheckGlobalDeviceManager(); | ||||
| MS_EXCEPTION_IF_NULL(g_device_manager); | MS_EXCEPTION_IF_NULL(g_device_manager); | ||||
| RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); | RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); | ||||
| TensorRedistribution tensor_redistribution; | |||||
| TensorRedistribution tensor_redistribution(false, true); | |||||
| if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { | if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { | ||||
| MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; | MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; | ||||
| } | } | ||||
| @@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp | |||||
| CheckGlobalDeviceManager(); | CheckGlobalDeviceManager(); | ||||
| MS_EXCEPTION_IF_NULL(g_device_manager); | MS_EXCEPTION_IF_NULL(g_device_manager); | ||||
| RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); | RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); | ||||
| TensorRedistribution tensor_redistribution; | |||||
| TensorRedistribution tensor_redistribution(false, true); | |||||
| if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { | if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { | ||||
| MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; | MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; | ||||
| } | } | ||||