|
|
@@ -133,9 +133,13 @@ Status ReshapeInfo::GetParameterInput() { |
|
|
|
|
|
|
|
|
Status ReshapeInfo::ComputeReplaceOp() { |
|
|
Status ReshapeInfo::ComputeReplaceOp() { |
|
|
RankList dev_list = global_device_list(); |
|
|
RankList dev_list = global_device_list(); |
|
|
TensorRedistribution tensor_redistribution(true, true); |
|
|
|
|
|
|
|
|
TensorRedistribution tensor_redistribution(!is_generating_costs_, true); |
|
|
if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) { |
|
|
if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) { |
|
|
MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; |
|
|
|
|
|
|
|
|
if (is_generating_costs_) { |
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed."; |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; |
|
|
|
|
|
} |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); |
|
|
MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); |
|
|
@@ -143,7 +147,11 @@ Status ReshapeInfo::ComputeReplaceOp() { |
|
|
MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); |
|
|
MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); |
|
|
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); |
|
|
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); |
|
|
if (redistribution_oplist_ptr == nullptr) { |
|
|
if (redistribution_oplist_ptr == nullptr) { |
|
|
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; |
|
|
|
|
|
|
|
|
if (is_generating_costs_) { |
|
|
|
|
|
MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed."; |
|
|
|
|
|
} else { |
|
|
|
|
|
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; |
|
|
|
|
|
} |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
replace_op_ = redistribution_oplist_ptr->first; |
|
|
replace_op_ = redistribution_oplist_ptr->first; |
|
|
@@ -444,6 +452,7 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { |
|
|
Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs, |
|
|
Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs, |
|
|
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, |
|
|
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, |
|
|
int32_t out_index, int32_t in_index, bool is_prev_param) { |
|
|
int32_t out_index, int32_t in_index, bool is_prev_param) { |
|
|
|
|
|
is_generating_costs_ = true; |
|
|
for (auto pre_stra_cost : pre_stra_costs) { |
|
|
for (auto pre_stra_cost : pre_stra_costs) { |
|
|
std::vector<TensorInfo> pre_out_tensor_infos; |
|
|
std::vector<TensorInfo> pre_out_tensor_infos; |
|
|
if (is_prev_param) { |
|
|
if (is_prev_param) { |
|
|
@@ -488,6 +497,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra |
|
|
SetCostForReshape(reshape_stra); |
|
|
SetCostForReshape(reshape_stra); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
is_generating_costs_ = false; |
|
|
if (strategy_cost_.empty()) { |
|
|
if (strategy_cost_.empty()) { |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
|