|
|
|
@@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co |
|
|
|
CheckGlobalDeviceManager(); |
|
|
|
MS_EXCEPTION_IF_NULL(g_device_manager); |
|
|
|
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); |
|
|
|
TensorRedistribution tensor_redistribution; |
|
|
|
TensorRedistribution tensor_redistribution(false, true); |
|
|
|
if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; |
|
|
|
} |
|
|
|
@@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp |
|
|
|
CheckGlobalDeviceManager(); |
|
|
|
MS_EXCEPTION_IF_NULL(g_device_manager); |
|
|
|
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); |
|
|
|
TensorRedistribution tensor_redistribution; |
|
|
|
TensorRedistribution tensor_redistribution(false, true); |
|
|
|
if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; |
|
|
|
} |
|
|
|
|