Merge pull request !1051 from yao_yf/auto_parallel_reshape_reconstruct (tag: v0.3.0-alpha)
| @@ -1377,7 +1377,6 @@ Status CostGraph::InitSelectedStrategy() { | |||
| if (pre_iter != in_edges.end()) { | |||
| MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name(); | |||
| int32_t pre_index = reshape_info->pre_operator_index(); | |||
| Dimensions stra; | |||
| TensorInfo pre_info; | |||
| if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) { | |||
| pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index]; | |||
| @@ -1385,7 +1384,10 @@ Status CostGraph::InitSelectedStrategy() { | |||
| pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index]; | |||
| } | |||
| reshape_info->SetInputLayout(pre_info.tensor_layout()); | |||
| InferStraByTensorInfo(pre_info, &stra); | |||
| Dimensions stra = pre_info.InferStrategy(); | |||
| if (stra.empty()) { | |||
| MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; | |||
| } | |||
| std::vector<Dimensions> stra_inputs = {stra}; | |||
| StrategyPtr reshape_stra = | |||
| std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); | |||
| @@ -440,5 +440,57 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs, | |||
| const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, | |||
| int32_t out_index, int32_t in_index, bool is_prev_param) { | |||
| for (auto pre_stra_cost : pre_stra_costs) { | |||
| std::vector<TensorInfo> pre_out_tensor_infos; | |||
| if (is_prev_param) { | |||
| pre_out_tensor_infos = pre_stra_cost->inputs_ptr; | |||
| } else { | |||
| pre_out_tensor_infos = pre_stra_cost->outputs_ptr; | |||
| } | |||
| if (pre_out_tensor_infos.size() <= IntToSize(out_index)) { | |||
| MS_LOG(ERROR) << "out_index is out of range of the tensor_infos in setting reshape's input_layout"; | |||
| return FAILED; | |||
| } | |||
| TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; | |||
| TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout(); | |||
| SetInputLayout(pre_out_tensor_layout); | |||
| // infer pre_node output strategy from output_layout. | |||
| Dimensions stra = pre_out_tensor_info.InferStrategy(); | |||
| if (stra.empty()) { | |||
| MS_LOG(ERROR) << "Infer strategy by tensor_info failed"; | |||
| return FAILED; | |||
| } | |||
| std::vector<Dimensions> stra_inputs = {stra}; | |||
| StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); | |||
| if (next_stra_costs.empty()) { | |||
| if (Init(nullptr) == FAILED) { | |||
| MS_LOG(ERROR) << "Failure:operator reshape init failed"; | |||
| return FAILED; | |||
| } | |||
| SetCostForReshape(reshape_stra); | |||
| continue; | |||
| } | |||
| for (auto next_stra_cost : next_stra_costs) { | |||
| std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr; | |||
| if (next_in_tensor_infos.size() <= IntToSize(in_index)) { | |||
| MS_LOG(ERROR) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; | |||
| return FAILED; | |||
| } | |||
| TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; | |||
| TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout(); | |||
| SetOutputLayout(next_in_tensor_layout); | |||
| if (Init(nullptr) == FAILED) { | |||
| MS_LOG(ERROR) << "Failure:operator reshape init failed"; | |||
| return FAILED; | |||
| } | |||
| SetCostForReshape(reshape_stra); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -56,6 +56,9 @@ class ReshapeInfo : public OperatorInfo { | |||
| void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } | |||
| void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; } | |||
| void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; } | |||
| Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs, | |||
| const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int32_t out_index, | |||
| int32_t in_index, bool is_prev_param); | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| Status GenerateStrategies(int32_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -999,18 +999,6 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator | |||
| return false; | |||
| } | |||
| void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) { | |||
| Shape shape = pre_out_tensor_info.shape(); | |||
| Shape slice_shape = pre_out_tensor_info.slice_shape(); | |||
| for (size_t i = 0; i < shape.size(); ++i) { | |||
| if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) { | |||
| MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator"; | |||
| } | |||
| int32_t dim = (int32_t)(shape[i] / slice_shape[i]); | |||
| (*stra).push_back(dim); | |||
| } | |||
| } | |||
| void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||
| for (auto node : all_nodes) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -1054,46 +1042,10 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||
| reshape_info->set_next_operator_name(next_operator_info->name()); | |||
| reshape_info->set_next_operator_index(in_index); | |||
| } | |||
| for (auto pre_stra_cost : pre_stra_costs) { | |||
| std::vector<TensorInfo> pre_out_tensor_infos; | |||
| if (pre_node->isa<Parameter>()) { | |||
| pre_out_tensor_infos = pre_stra_cost->inputs_ptr; | |||
| } else { | |||
| pre_out_tensor_infos = pre_stra_cost->outputs_ptr; | |||
| } | |||
| if (pre_out_tensor_infos.size() <= IntToSize(out_index)) { | |||
| MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout"; | |||
| } | |||
| TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; | |||
| TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout(); | |||
| reshape_info->SetInputLayout(pre_out_tensor_layout); | |||
| // infer pre_node output strategy from output_layout. | |||
| Dimensions stra; | |||
| InferStraByTensorInfo(pre_out_tensor_info, &stra); | |||
| std::vector<Dimensions> stra_inputs = {stra}; | |||
| StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); | |||
| if (next_stra_costs.empty()) { | |||
| if (reshape_info->Init(nullptr) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure:operator reshape init failed"; | |||
| } | |||
| // set cost for each input_layout and output_layout pairs. | |||
| reshape_info->SetCostForReshape(reshape_stra); | |||
| continue; | |||
| } | |||
| for (auto next_stra_cost : next_stra_costs) { | |||
| std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr; | |||
| if (next_in_tensor_infos.size() <= IntToSize(in_index)) { | |||
| MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; | |||
| } | |||
| TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; | |||
| TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout(); | |||
| reshape_info->SetOutputLayout(next_in_tensor_layout); | |||
| if (reshape_info->Init(nullptr) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure:operator reshape init failed"; | |||
| } | |||
| // set cost for each input_layout and output_layout pairs. | |||
| reshape_info->SetCostForReshape(reshape_stra); | |||
| } | |||
| bool is_prev_param = pre_node->isa<Parameter>(); | |||
| if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) != | |||
| SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!"; | |||
| } | |||
| } | |||
| } | |||
| @@ -51,8 +51,6 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes); | |||
| void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes); | |||
| void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra); | |||
| Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| @@ -46,6 +46,17 @@ class TensorInfo { | |||
| Shape shape() const { return shape_; } | |||
| void set_reduce_dim(const std::vector<int32_t> &dim) { reduce_dim_ = dim; } | |||
| std::vector<int32_t> reduce_dim() const { return reduce_dim_; } | |||
| Dimensions InferStrategy() const { | |||
| Dimensions stra; | |||
| for (size_t i = 0; i < shape_.size(); ++i) { | |||
| if ((slice_shape_[i] == 0) || (shape_[i] % slice_shape_[i] != 0)) { | |||
| return stra; | |||
| } | |||
| int32_t dim = (int32_t)(shape_[i] / slice_shape_[i]); | |||
| stra.push_back(dim); | |||
| } | |||
| return stra; | |||
| } | |||
| private: | |||
| TensorLayout tensor_layout_; | |||
| @@ -86,6 +86,7 @@ def test_reshape_auto_1(): | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x) | |||
| def test_reshape_auto_2(): | |||
| @@ -112,6 +113,7 @@ def test_reshape_auto_2(): | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x) | |||
| def test_reshape_auto_3(): | |||
| @@ -135,6 +137,7 @@ def test_reshape_auto_3(): | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x) | |||
| def test_reshape_auto_4(): | |||
| @@ -159,6 +162,7 @@ def test_reshape_auto_4(): | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x) | |||
| @@ -208,6 +212,7 @@ def test_reshape_auto_5(): | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x, y) | |||
| def test_reshape_auto_6(): | |||
| @@ -254,4 +259,5 @@ def test_reshape_auto_6(): | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x, y) | |||