Merge pull request !7495 from Xiaoda/27-star-elimination-overloadtags/v1.1.0
| @@ -201,7 +201,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||||
| right_edge->set_selected_cost(decision->right_edge_cost_); | right_edge->set_selected_cost(decision->right_edge_cost_); | ||||
| // 'left_node' recovers the strategy. | // 'left_node' recovers the strategy. | ||||
| left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); | left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); | ||||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||||
| // 'right_node' recovers the strategy. | // 'right_node' recovers the strategy. | ||||
| MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination."; | MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination."; | ||||
| right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_); | right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_); | ||||
| @@ -225,10 +225,16 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||||
| MS_EXCEPTION_IF_NULL(succ_nodes[0]); | MS_EXCEPTION_IF_NULL(succ_nodes[0]); | ||||
| MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]); | MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]); | ||||
| MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); | MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); | ||||
| // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy. | |||||
| // Star is eliminated into 'succ_nodes[0]' | |||||
| succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); | succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); | ||||
| for (size_t k = 1; k < succ_nodes.size(); ++k) { | for (size_t k = 1; k < succ_nodes.size(); ++k) { | ||||
| succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); | |||||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||||
| // 'succ_nodes[k]' is overwritten strategy and cost. | |||||
| succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]); | |||||
| } else { | |||||
| // In this case, 'succ_nodes[k]' is NOT overwritten strategy and cost, however, it checks the strategy. | |||||
| succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); | |||||
| } | |||||
| } | } | ||||
| MS_LOG(INFO) << "Recover starElimination succeeded."; | MS_LOG(INFO) << "Recover starElimination succeeded."; | ||||
| } else { | } else { | ||||
| @@ -20,9 +20,9 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "ir/value.h" | |||||
| #include "frontend/parallel/auto_parallel/edge_costmodel.h" | #include "frontend/parallel/auto_parallel/edge_costmodel.h" | ||||
| #include "frontend/parallel/auto_parallel/graph_costmodel.h" | #include "frontend/parallel/auto_parallel/graph_costmodel.h" | ||||
| #include "ir/value.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -22,11 +22,11 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "utils/ms_utils.h" | |||||
| #include "frontend/parallel/auto_parallel/costmodel.h" | #include "frontend/parallel/auto_parallel/costmodel.h" | ||||
| #include "frontend/parallel/ops_info/operator_info.h" | #include "frontend/parallel/ops_info/operator_info.h" | ||||
| #include "frontend/parallel/tensor_layout/tensor_info.h" | #include "frontend/parallel/tensor_layout/tensor_info.h" | ||||
| #include "frontend/parallel/tensor_layout/tensor_layout.h" | #include "frontend/parallel/tensor_layout/tensor_layout.h" | ||||
| #include "utils/ms_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -40,7 +40,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; | |||||
| bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | ||||
| bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | ||||
| int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | ||||
| bool TRIANGLE_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE; | |||||
| bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | void CostGraph::SetDeviceMemoryAndCostParameter() { | ||||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | ||||
| @@ -155,12 +155,12 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { | |||||
| MS_LOG(INFO) << "multi_subgraphs: false."; | MS_LOG(INFO) << "multi_subgraphs: false."; | ||||
| } | } | ||||
| auto overwrite = CostModelContext::GetInstance()->triangle_strategy_overwrite(); | |||||
| TRIANGLE_STRATEGY_OVERWRITE = overwrite; | |||||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||||
| MS_LOG(INFO) << "triangle_strategy_overwrite: true."; | |||||
| auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); | |||||
| TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite; | |||||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; | |||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "triangle_strategy_overwrite: false."; | |||||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; | |||||
| } | } | ||||
| // RUN_PHASE | // RUN_PHASE | ||||
| @@ -1303,7 +1303,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||||
| elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | ||||
| left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | ||||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||||
| new_computation += right_op_cost->computation_cost_; | new_computation += right_op_cost->computation_cost_; | ||||
| new_memory += right_op_cost->memory_with_reuse_; | new_memory += right_op_cost->memory_with_reuse_; | ||||
| new_commu_cost += right_op_cost->communication_cost_; | new_commu_cost += right_op_cost->communication_cost_; | ||||
| @@ -1399,7 +1399,9 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, | |||||
| } | } | ||||
| if (!valid) { | if (!valid) { | ||||
| MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed."; | |||||
| MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() | |||||
| << " failed. It may be caused by " | |||||
| "configuring inconsistent strategies for operators."; | |||||
| } | } | ||||
| elimi_op->SetNotAlive(); | elimi_op->SetNotAlive(); | ||||
| MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded."; | MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded."; | ||||
| @@ -1440,6 +1442,13 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||||
| commu_cost += succ_edges_costs[i]->communication_cost_; | commu_cost += succ_edges_costs[i]->communication_cost_; | ||||
| commu_forward += succ_edges_costs[i]->communication_forward_; | commu_forward += succ_edges_costs[i]->communication_forward_; | ||||
| commu_without += succ_edges_costs[i]->communication_without_parameter_; | commu_without += succ_edges_costs[i]->communication_without_parameter_; | ||||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||||
| computation_cost += succ_nodes_costs[i]->computation_cost_; | |||||
| memory_cost += succ_nodes_costs[i]->memory_with_reuse_; | |||||
| commu_cost += succ_nodes_costs[i]->communication_cost_; | |||||
| commu_forward += succ_nodes_costs[i]->communication_forward_; | |||||
| commu_without += succ_nodes_costs[i]->communication_without_parameter_; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1544,7 +1553,9 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo | |||||
| } | } | ||||
| if (!valid) { | if (!valid) { | ||||
| MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed."; | |||||
| MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() | |||||
| << " failed. It may be caused by " | |||||
| "configuring inconsistent strategies for operators."; | |||||
| } | } | ||||
| merged_op->SetNotAlive(); | merged_op->SetNotAlive(); | ||||
| @@ -22,11 +22,11 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "utils/ms_utils.h" | |||||
| #include "frontend/parallel/auto_parallel/edge_costmodel.h" | #include "frontend/parallel/auto_parallel/edge_costmodel.h" | ||||
| #include "frontend/parallel/costmodel_context.h" | #include "frontend/parallel/costmodel_context.h" | ||||
| #include "frontend/parallel/ops_info/operator_info.h" | #include "frontend/parallel/ops_info/operator_info.h" | ||||
| #include "frontend/parallel/ops_info/tmp_identity_info.h" | #include "frontend/parallel/ops_info/tmp_identity_info.h" | ||||
| #include "utils/ms_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -46,7 +46,7 @@ extern bool FULLY_USE_DEVICES; | |||||
| extern bool ELEMENTWISE_OP_STRA_FOLLOW; | extern bool ELEMENTWISE_OP_STRA_FOLLOW; | ||||
| extern bool MULTI_SUBGRAPHS; | extern bool MULTI_SUBGRAPHS; | ||||
| extern int32_t RUN_PHASE; | extern int32_t RUN_PHASE; | ||||
| extern bool TRIANGLE_STRATEGY_OVERWRITE; | |||||
| extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||||
| class CostGraph { | class CostGraph { | ||||
| // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have | // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have | ||||
| @@ -16,8 +16,8 @@ | |||||
| #include "frontend/parallel/auto_parallel/operator_costmodel.h" | #include "frontend/parallel/auto_parallel/operator_costmodel.h" | ||||
| #include <random> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <random> | |||||
| #include "frontend/parallel/device_matrix.h" | #include "frontend/parallel/device_matrix.h" | ||||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | ||||
| @@ -19,9 +19,9 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <cstdint> | #include <cstdint> | ||||
| #include <functional> | #include <functional> | ||||
| #include <map> | |||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | #include <utility> | ||||
| #include <map> | |||||
| #include "frontend/parallel/device_manager.h" | #include "frontend/parallel/device_manager.h" | ||||
| @@ -18,18 +18,18 @@ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_ | #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_ | ||||
| #include <cstdint> | #include <cstdint> | ||||
| #include <memory> | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "abstract/abstract_value.h" | |||||
| #include "frontend/parallel/ops_info/ops_utils.h" | #include "frontend/parallel/ops_info/ops_utils.h" | ||||
| #include "frontend/parallel/status.h" | #include "frontend/parallel/status.h" | ||||
| #include "utils/convert_utils.h" | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "utils/convert_utils.h" | |||||
| #include "utils/info.h" | #include "utils/info.h" | ||||
| #include "abstract/abstract_value.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -64,7 +64,7 @@ void CostModelContext::ResetAlgoParameters() { | |||||
| tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | ||||
| fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; | fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; | ||||
| elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | ||||
| triangle_strategy_overwrite_ = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE; | |||||
| triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||||
| } | } | ||||
| void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { | void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { | ||||
| @@ -134,7 +134,9 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | |||||
| elementwise_stra_follow_ = elementwise_follow; | elementwise_stra_follow_ = elementwise_follow; | ||||
| } | } | ||||
| void CostModelContext::set_triangle_strategy_overwrite(bool overwrite) { triangle_strategy_overwrite_ = overwrite; } | |||||
| void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { | |||||
| triangle_star_strategy_overwrite_ = overwrite; | |||||
| } | |||||
| void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } | void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } | ||||
| @@ -44,7 +44,7 @@ namespace parallel { | |||||
| #define DEFAULT_RUN_PHASE 0 | #define DEFAULT_RUN_PHASE 0 | ||||
| #define TRAINING_PHASE 0 | #define TRAINING_PHASE 0 | ||||
| #define INFERENCE_PHASE 1 | #define INFERENCE_PHASE 1 | ||||
| #define DEFAULT_TRIANGLE_STRATEGY_OVERWRITE true; | |||||
| #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true; | |||||
| class CostModelContext { | class CostModelContext { | ||||
| public: | public: | ||||
| @@ -135,8 +135,8 @@ class CostModelContext { | |||||
| void set_elementwise_stra_follow(bool); | void set_elementwise_stra_follow(bool); | ||||
| bool elementwise_stra_follow() const { return elementwise_stra_follow_; } | bool elementwise_stra_follow() const { return elementwise_stra_follow_; } | ||||
| void set_triangle_strategy_overwrite(bool); | |||||
| bool triangle_strategy_overwrite() const { return triangle_strategy_overwrite_; } | |||||
| void set_triangle_star_strategy_overwrite(bool); | |||||
| bool triangle_star_strategy_overwrite() const { return triangle_star_strategy_overwrite_; } | |||||
| void set_run_phase(int32_t); | void set_run_phase(int32_t); | ||||
| int32_t run_phase() const { return run_phase_; } | int32_t run_phase() const { return run_phase_; } | ||||
| @@ -172,9 +172,9 @@ class CostModelContext { | |||||
| // MULTI_SUBGRAPHS | // MULTI_SUBGRAPHS | ||||
| bool is_multi_subgraphs_; | bool is_multi_subgraphs_; | ||||
| // In the recovery phase of DP algorithm, when encountering triangle structure, | |||||
| // In the recovery phase of DP algorithm, when encountering triangle structure and star structure, | |||||
| // whether overwrite the right-node strategy | // whether overwrite the right-node strategy | ||||
| bool triangle_strategy_overwrite_; | |||||
| bool triangle_star_strategy_overwrite_; | |||||
| int32_t run_phase_; // 0: 'training', 1: 'inference' | int32_t run_phase_; // 0: 'training', 1: 'inference' | ||||
| @@ -25,13 +25,13 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "utils/ms_utils.h" | |||||
| #include "frontend/parallel/device.h" | #include "frontend/parallel/device.h" | ||||
| #include "frontend/parallel/device_matrix.h" | #include "frontend/parallel/device_matrix.h" | ||||
| #include "frontend/parallel/group_manager.h" | #include "frontend/parallel/group_manager.h" | ||||
| #include "frontend/parallel/status.h" | #include "frontend/parallel/status.h" | ||||
| #include "frontend/parallel/strategy.h" | #include "frontend/parallel/strategy.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "utils/ms_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -17,8 +17,8 @@ | |||||
| #include "frontend/parallel/group_manager.h" | #include "frontend/parallel/group_manager.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <vector> | #include <vector> | ||||
| #include "frontend/parallel/device_manager.h" | |||||
| #include "backend/session/executor_manager.h" | #include "backend/session/executor_manager.h" | ||||
| #include "frontend/parallel/device_manager.h" | |||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| @@ -26,10 +26,8 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_set> | |||||
| #include "ir/anf.h" | |||||
| #include "ir/param_info.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "frontend/optimizer/opt.h" | #include "frontend/optimizer/opt.h" | ||||
| #include "frontend/optimizer/optimizer.h" | #include "frontend/optimizer/optimizer.h" | ||||
| #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h" | #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h" | ||||
| @@ -39,11 +37,14 @@ | |||||
| #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" | #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" | ||||
| #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" | #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| #include "frontend/parallel/ops_info/tmp_identity_info.h" | |||||
| #include "frontend/parallel/ops_info/reshape_info.h" | |||||
| #include "frontend/parallel/graph_util/node_info.h" | #include "frontend/parallel/graph_util/node_info.h" | ||||
| #include "frontend/parallel/ops_info/reshape_info.h" | |||||
| #include "frontend/parallel/ops_info/tmp_identity_info.h" | |||||
| #include "frontend/parallel/step_parallel.h" | #include "frontend/parallel/step_parallel.h" | ||||
| #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | ||||
| #include "ir/anf.h" | |||||
| #include "ir/param_info.h" | |||||
| #include "ir/tensor.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -21,9 +21,9 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "ir/anf.h" | |||||
| #include "frontend/optimizer/opt.h" | #include "frontend/optimizer/opt.h" | ||||
| #include "frontend/parallel/status.h" | #include "frontend/parallel/status.h" | ||||
| #include "ir/anf.h" | |||||
| #include "pipeline/jit/pipeline.h" | #include "pipeline/jit/pipeline.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -27,8 +27,6 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include "ir/tensor.h" | |||||
| #include "ir/param_info.h" | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "frontend/optimizer/optimizer.h" | #include "frontend/optimizer/optimizer.h" | ||||
| #include "frontend/parallel/auto_parallel/graph_costmodel.h" | #include "frontend/parallel/auto_parallel/graph_costmodel.h" | ||||
| @@ -41,9 +39,11 @@ | |||||
| #include "frontend/parallel/node_check.h" | #include "frontend/parallel/node_check.h" | ||||
| #include "frontend/parallel/ops_info/matmul_info.h" | #include "frontend/parallel/ops_info/matmul_info.h" | ||||
| #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | ||||
| #include "ir/param_info.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/symbolic.h" | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/symbolic.h" | |||||
| using mindspore::tensor::Tensor; | using mindspore::tensor::Tensor; | ||||
| @@ -21,10 +21,10 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <set> | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include <set> | |||||
| #include "frontend/optimizer/opt.h" | #include "frontend/optimizer/opt.h" | ||||
| #include "frontend/parallel/strategy.h" | #include "frontend/parallel/strategy.h" | ||||
| @@ -23,8 +23,8 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "frontend/parallel/status.h" | |||||
| #include "frontend/parallel/device_matrix.h" | #include "frontend/parallel/device_matrix.h" | ||||
| #include "frontend/parallel/status.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -0,0 +1,134 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore as ms | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.api import _executor | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.parallel._utils import _reset_op_id as reset_op_id | |||||
| from mindspore import context, Tensor, Parameter | |||||
| from mindspore.parallel import set_algo_parameters | |||||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||||
| grad_all = C.GradOperation(get_all=True) | |||||
| class NetWithLoss(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(NetWithLoss, self).__init__() | |||||
| self.loss = VirtualLoss() | |||||
| self.network = network | |||||
| def construct(self, x): | |||||
| predict = self.network(x) | |||||
| return self.loss(predict) | |||||
| class GradWarp(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(GradWarp, self).__init__() | |||||
| self.network = network | |||||
| def construct(self, x): | |||||
| return grad_all(self.network)(x) | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, strategy_dict=None): | |||||
| super(Net, self).__init__() | |||||
| self.mul1 = P.Mul() | |||||
| self.mul2 = P.Mul() | |||||
| self.mul3 = P.Mul() | |||||
| self.mul4 = P.Mul() | |||||
| self.relu1 = P.ReLU() | |||||
| self.relu2 = P.ReLU() | |||||
| self.ba1 = P.BiasAdd() | |||||
| self.add = P.TensorAdd() | |||||
| self.weight = Parameter(Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight") | |||||
| self.bias = Parameter(Tensor(np.ones([1000]), dtype=ms.float32), name="bias") | |||||
| if strategy_dict is not None: | |||||
| self.mul1.shard(strategy_dict["mul1"]) | |||||
| self.mul2.shard(strategy_dict["mul2"]) | |||||
| self.relu1.shard(strategy_dict["relu1"]) | |||||
| self.relu2.shard(strategy_dict["relu2"]) | |||||
| self.ba1.shard(strategy_dict["bias_add"]) | |||||
| self.add.shard(strategy_dict["add"]) | |||||
| def construct(self, inputs): | |||||
| x = self.mul1(inputs, self.weight) | |||||
| y = self.relu1(x) | |||||
| y = self.mul2(y, self.weight) | |||||
| z = self.mul3(x, self.weight) | |||||
| z = self.ba1(z, self.bias) | |||||
| x = self.add(y, z) | |||||
| x = self.mul4(x, self.weight) | |||||
| x = self.relu2(x) | |||||
| return x | |||||
| def test_star_strategy_consistency1(): | |||||
| size = 8 | |||||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||||
| set_algo_parameters(fully_use_devices=False) | |||||
| x = Tensor(np.ones([128, 1000]), dtype=ms.float32) | |||||
| strategy_dict = {"mul1": ((2, 4), (2, 4)), "mul2": None, "relu1": ((4, 1),), "bias_add": ((8, 1), (1,)), | |||||
| "relu2": ((2, 2),), "add": ((1, 8), (1, 8))} | |||||
| net = NetWithLoss(Net(strategy_dict)) | |||||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||||
| net.set_auto_parallel() | |||||
| reset_op_id() | |||||
| _executor.compile(net, x, phase='train') | |||||
| def test_star_strategy_consistency2(): | |||||
| size = 8 | |||||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||||
| set_algo_parameters(fully_use_devices=False) | |||||
| x = Tensor(np.ones([128, 1000]), dtype=ms.float32) | |||||
| strategy_dict = {"mul1": None, "mul2": ((1, 4), (1, 4)), "relu1": ((2, 1),), "bias_add": ((4, 2), (2,)), | |||||
| "relu2": ((2, 2),), "add": ((8, 1), (8, 1))} | |||||
| net = NetWithLoss(Net(strategy_dict)) | |||||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||||
| net.set_auto_parallel() | |||||
| reset_op_id() | |||||
| _executor.compile(net, x, phase='train') | |||||
| def test_star_strategy_consistency3(): | |||||
| size = 8 | |||||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||||
| set_algo_parameters(fully_use_devices=False) | |||||
| x = Tensor(np.ones([128, 1000]), dtype=ms.float32) | |||||
| strategy_dict = {"mul1": None, "mul2": None, "relu1": ((8, 1),), "bias_add": ((1, 4), (4,)), | |||||
| "relu2": ((4, 1),), "add": ((2, 2), (2, 2))} | |||||
| net = NetWithLoss(Net(strategy_dict)) | |||||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||||
| net.set_auto_parallel() | |||||
| reset_op_id() | |||||
| _executor.compile(net, x, phase='train') | |||||
| def test_star_strategy_consistency4(): | |||||
| size = 8 | |||||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||||
| set_algo_parameters(fully_use_devices=False) | |||||
| x = Tensor(np.ones([128, 1000]), dtype=ms.float32) | |||||
| strategy_dict = {"mul1": ((1, 8), (1, 8)), "mul2": ((1, 4), (1, 4)), "relu1": None, "bias_add": None, | |||||
| "relu2": None, "add": None} | |||||
| net = NetWithLoss(Net(strategy_dict)) | |||||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||||
| net.set_auto_parallel() | |||||
| reset_op_id() | |||||
| with pytest.raises(RuntimeError): | |||||
| _executor.compile(net, x, phase='train') | |||||