|
|
|
@@ -423,7 +423,6 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra |
|
|
|
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_); |
|
|
|
swc->cost_list.push_back(result); |
|
|
|
strategy_cost_.emplace_back(swc); |
|
|
|
ResetQueueMember(); |
|
|
|
} |
|
|
|
|
|
|
|
Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { |
|
|
|
@@ -489,6 +488,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra |
|
|
|
} |
|
|
|
TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; |
|
|
|
SetOutputLayout(next_in_tensor_info.tensor_layout()); |
|
|
|
ResetQueueMember(); |
|
|
|
InferTensorInfoByLayout(); |
|
|
|
SetCostForReshape(reshape_stra); |
|
|
|
} |
|
|
|
|