@@ -13,9 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "parallel/auto_parallel/graph_costmodel.h"
 #include <algorithm>
 #include <cstdlib>
 #include <iterator>
@@ -24,6 +21,10 @@
 #include <utility>
 #include <vector>
+#include "parallel/auto_parallel/graph_costmodel.h"
+#include "parallel/ops_info/reshape_info.h"
+#include "parallel/step_auto_parallel.h"

 namespace mindspore {
 namespace parallel {
 CostGraphPtr entire_costgraph = nullptr;
@@ -40,6 +41,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
 bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
 int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
+constexpr char RESHAPEINFO[] = "ReshapeInfo";

 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -182,6 +184,20 @@ bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
   return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
 }

+void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
+  std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
+  curr_edges.push_back(edge);
+  edges_[{u_node, v_node}] = curr_edges;
+
+  std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
+  curr_out_edges.push_back(edge);
+  out_edges_[u_node] = curr_out_edges;
+
+  std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
+  curr_in_edges.push_back(edge);
+  in_edges_[v_node] = curr_in_edges;
+}
+
 bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
   for (auto &edge_pair : edges_) {
     auto edges = edge_pair.second;
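A side note on the AddEdge bookkeeping just added: std::map::operator[] default-constructs an empty vector on first access, so the copy-append-assign sequence is behaviorally equivalent to appending in place. A self-contained sketch of the equivalent in-place form, with simplified stand-in types rather than the real cost-graph classes:

    #include <map>
    #include <memory>
    #include <utility>
    #include <vector>

    // Simplified stand-ins so the sketch compiles on its own.
    struct OperatorInfo {};
    struct Edge {};
    using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
    using EdgePtr = std::shared_ptr<Edge>;

    struct MiniCostGraph {
      std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
      std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
      std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;

      void AddEdge(const OperatorInfoPtr &u, const OperatorInfoPtr &v, const EdgePtr &edge) {
        edges_[{u, v}].push_back(edge);  // operator[] inserts an empty vector if the key is absent
        out_edges_[u].push_back(edge);   // forward adjacency: u -> its outgoing edges
        in_edges_[v].push_back(edge);    // reverse adjacency: v -> its incoming edges
      }
    };

The out_edges_/in_edges_ maps are what GetOriginalPrevEdges and GetOriginalNextEdges (declared in the header hunk below) read from.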
@@ -1338,11 +1354,51 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
 Status CostGraph::InitSelectedStrategy() {
   for (auto &op : ops_) {
     MS_EXCEPTION_IF_NULL(op);
+    if (op->name().find(RESHAPEINFO) != std::string::npos) {
+      continue;
+    }
     auto result = op->InitSelectedStrategy(op->selected_strategy());
     if (result != SUCCESS) {
       return result;
     }
   }
+  // Reshape init must be applied after the init of its previous and next nodes,
+  // since its input/output layouts are taken from those neighbors.
+  for (size_t i = 0; i < ops_.size(); ++i) {
+    if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
+      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
+      auto in_edges = GetOriginalPrevEdges(ops_[i]);
+      auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) {
+        return edge->prev_operator()->name() == reshape_info->pre_operator_name();
+      });
+      auto out_edges = GetOriginalNextEdges(ops_[i]);
+      auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
+        return edge->next_operator()->name() == reshape_info->next_operator_name();
+      });
+      if (pre_iter != in_edges.end()) {
+        MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
+        int32_t pre_index = reshape_info->pre_operator_index();
+        Dimensions stra;
+        TensorInfo pre_info;
+        if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
+          pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
+        } else {
+          pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
+        }
+        reshape_info->SetInputLayout(pre_info.tensor_layout());
+        InferStraByTensorInfo(pre_info, &stra);
+        std::vector<Dimensions> stra_inputs = {stra};
+        StrategyPtr reshape_stra =
+          std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
+        reshape_info->set_strategy(reshape_stra);
+      }
+      if (next_iter != out_edges.end()) {
+        MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
+        int32_t next_index = reshape_info->next_operator_index();
+        reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
+      }
+      auto result = reshape_info->Init(nullptr);
+      if (result != SUCCESS) {
+        return result;
+      }
+    }
+  }
   return SUCCESS;
 }
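The Dimensions fed into the Strategy above come from InferStraByTensorInfo (added to step_auto_parallel.cc further down): each dimension's split count is the full shape divided by the per-device slice shape. A minimal standalone sketch of that rule with simplified types (not the real TensorInfo API); for a tensor of shape {64, 28} whose slice on each device is {8, 28}, it returns the strategy {8, 1}:

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    using Shape = std::vector<int32_t>;  // simplified stand-in for parallel::Shape

    Shape DeriveStrategy(const Shape &shape, const Shape &slice_shape) {
      Shape stra;
      for (size_t i = 0; i < shape.size(); ++i) {
        // A zero or non-dividing slice dimension means the layout is inconsistent.
        if (slice_shape[i] == 0 || shape[i] % slice_shape[i] != 0) {
          throw std::runtime_error("slice_shape does not evenly divide shape");
        }
        stra.push_back(shape[i] / slice_shape[i]);  // split count of dimension i
      }
      return stra;
    }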
@@ -87,11 +87,9 @@ class CostGraph {
   void RemoveOperator(const OperatorInfoPtr &op);
   bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
   // the edge is in the form: u --> v
-  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
-    std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
-    curr_edges.push_back(edge);
-    edges_[{u_node, v_node}] = curr_edges;
-  }
+  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
+  std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }
+  std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; }
   // An edge is uniquely identified by its name, and its output index and input index.
   bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
@@ -219,6 +217,8 @@ class CostGraph {
   std::vector<OperatorInfoPtr> ops_;
   std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
   std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
+  std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
+  std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -111,6 +111,7 @@ class OperatorInfo {
   Shape dev_matrix_shape() const { return dev_matrix_shape_; }
   std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
   std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
+  std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
   const std::string &name() const { return name_; }
   void set_name(const std::string &name) { name_ = name; }
   RankList global_device_list() const { return global_device_list_; }
@@ -22,6 +22,7 @@
 #include "parallel/device_manager.h"
 #include "parallel/device_matrix.h"
 #include "parallel/step_parallel.h"
+#include "parallel/auto_parallel/graph_costmodel.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
@@ -46,26 +47,6 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
     }
     return FAILED;
   }
-  std::vector<Dimensions> stra = strategy->GetInputDim();
-  for (size_t i = 0; i < strategy_size; ++i) {
-    Shape sub_strategy = stra.at(i);
-    size_t strategy_len = sub_strategy.size();
-    bool flag = false;
-    for (size_t j = 0; j < strategy_len; ++j) {
-      int32_t strategy_value = sub_strategy.at(j);
-      if (strategy_value > 1) {
-        if (flag) {
-          if (is_auto_parallel_) {
-            MS_LOG(DEBUG) << name_ << ": Only support batch parallel strategy.";
-          } else {
-            MS_LOG(ERROR) << name_ << ": Only support batch parallel strategy.";
-          }
-          return FAILED;
-        }
-        flag = true;
-      }
-    }
-  }
   return SUCCESS;
 }
@@ -402,6 +383,41 @@ Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr
   return SUCCESS;
 }

+void ReshapeInfo::SetCostForReshapeWithParameter() {
+  size_t success = 0;
+  for (auto &sp : sp_vector_) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
+      PrintStrategy(sp);
+    }
+  }
+}
+
+void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  int32_t stage_id = strategy->GetInputStage();
+  double computation_cost =
+    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
+  result->communication_without_parameter_ =
+    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  result->communication_with_partial_para_ =
+    result->communication_without_parameter_ +
+    COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
+  // Breaking ties for preferring data parallelization
+  BreakingTiesForPerferringDataParallel(strategy, result);
+  // refine communication cost calculation for practice
+  RefineForPracticalCost(result, false);
+  std::shared_ptr<StrategyWithCost> swc =
+    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
+  swc->cost_list.push_back(result);
+  strategy_cost_.emplace_back(swc);
+}
+
 Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": GetAttrs failed.";
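On the communication_with_partial_para_ term computed above: it interpolates between forward-only communication and total communication (forward plus parameter synchronization) using the cost-model factor COST_MODEL_GAMMA, i.e. comm_partial = comm_forward + gamma * (comm_total - comm_forward). A tiny numeric sketch (the values and the gamma here are illustrative, not taken from the code):

    #include <iostream>

    int main() {
      const double comm_total = 10.0;   // would come from GetCommCost(...)
      const double comm_forward = 4.0;  // would come from GetForwardCommCost(...)
      const double gamma = 0.1;         // stands in for COST_MODEL_GAMMA

      // gamma = 0 counts only forward communication; gamma = 1 counts all of it.
      double comm_partial = comm_forward + gamma * (comm_total - comm_forward);
      std::cout << comm_partial << '\n';  // prints 4.6
      return 0;
    }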
@@ -414,22 +430,14 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
   }
   is_auto_parallel_ = true;
   Shape input0_split;
-  input0_split.emplace_back(1);
-  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 1, 0);
+  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1);
   Shapes splittable_inputs = {input0_split};
-  std::vector<StrategyPtr> sp_vector;
-  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+  // The generated strategies are only used when the input node is a parameter;
+  // otherwise, the previous node's output_layout is used as the input_layout.
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
     return FAILED;
   }
-  size_t success = 0;
-  for (auto &sp : sp_vector) {
-    if (SetCostUnderStrategy(sp) == SUCCESS) {
-      success++;
-      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
-      PrintStrategy(sp);
-    }
-  }
   return SUCCESS;
 }
 }  // namespace parallel
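With the whole input marked splittable (all ones in input0_split) instead of batch-dim only, the candidate strategies grow from splitting a single dimension to every way of distributing the device count across dimensions, subject to the shape-divisibility checks in GenerateStrategiesForIndependentInputs. For intuition only, a small standalone counter of ordered factorizations (it ignores those divisibility checks):

    #include <cstdint>
    #include <iostream>

    // Count ordered ways to write `devices` as a product over `dims` dimensions,
    // e.g. 8 over 4 dims: {8,1,1,1}, {1,8,1,1}, ..., {2,2,2,1}, ...
    int64_t CountSplits(int64_t devices, int dims) {
      if (dims == 1) {
        return 1;  // the last dimension takes whatever factor remains
      }
      int64_t total = 0;
      for (int64_t d = 1; d <= devices; ++d) {
        if (devices % d == 0) {
          total += CountSplits(devices / d, dims - 1);
        }
      }
      return total;
    }

    int main() {
      std::cout << CountSplits(8, 4) << '\n';  // 20 candidate split shapes for 8 devices, 4 dims
      return 0;
    }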
@@ -50,9 +50,19 @@ class ReshapeInfo : public OperatorInfo {
     output_layout_ = output_layout;
     output_layout_set_flag_ = true;
   }
+  void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy);
+  void SetCostForReshapeWithParameter();
+  void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; }
+  void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
+  void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; }
+  void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; }
   Status InitForCostModel(const StrategyPtr &strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+  std::string pre_operator_name() const { return pre_operator_name_; }
+  std::string next_operator_name() const { return next_operator_name_; }
+  int32_t pre_operator_index() const { return pre_operator_index_; }
+  int32_t next_operator_index() const { return next_operator_index_; }

  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -73,12 +83,17 @@ class ReshapeInfo : public OperatorInfo {
   Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout);
   int32_t dev_num_;
+  int32_t pre_operator_index_;
+  int32_t next_operator_index_;
   std::vector<int32_t> parameter_input_v_;
+  std::vector<StrategyPtr> sp_vector_;
   Dimensions input_strategy_;
   TensorLayout input_layout_;
   TensorLayout output_layout_;
   bool input_layout_set_flag_;
   bool output_layout_set_flag_;
+  std::string pre_operator_name_;
+  std::string next_operator_name_;
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -39,6 +39,7 @@
 #include "parallel/auto_parallel/rec_core/rec_partition.h"
 #include "parallel/context.h"
 #include "parallel/ops_info/tmp_identity_info.h"
+#include "parallel/ops_info/reshape_info.h"
 #include "parallel/step_parallel.h"
 #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "pipeline/parse/python_adapter.h"
@@ -608,7 +609,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
       EdgePtr edge_ptr;
       MS_LOG(INFO) << "Creating edge: " << edge_name;
-      bool follow_strategy = ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name());
+      bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
+                             (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
       if (follow_strategy) {
         // Redistribution is not allowed on the edge.
         // Elementwise operators have the same strategy as their previous operators.
@@ -893,6 +895,209 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }

+bool FindReshape(const CNodePtr &cnode) {
+  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+  if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
+    return false;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  return prim->name() == RESHAPE;
+}
+
+// Find the previous node and obtain its strategy_cost_ vector to get its layout vector.
+bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
+  // If the previous node is a parameter, it is handled outside this function.
+  if (node->isa<Parameter>()) {
+    return false;
+  }
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  CNodePtr cnode = node->cast<CNodePtr>();
+  if (!IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
+    *pre_operator_info = cnode->operator_info();
+    *out_index = 0;
+    return true;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+  if (prim->name() == TUPLE_GETITEM) {
+    *out_index = GetTupleGetItemIndex(cnode);
+    // find tuple_getitem's previous node
+    auto pre_node = cnode->input(1);
+    if (!pre_node->isa<CNode>()) {
+      MS_LOG(EXCEPTION) << "tuple_getitem's second input is not a cnode";
+    }
+    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
+    if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
+      *pre_operator_info = pre_cnode->operator_info();
+      return true;
+    }
+    return false;
+  }
+  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
+    if (prim->name() == DEPEND && index != 1) {
+      continue;
+    }
+    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
+      continue;
+    }
+    return true;
+  }
+  MS_LOG(WARNING) << "FindPreNodeStraCosts failed; unless reshape is the first primitive, this indicates an error";
+  return false;
+}
+
+// Find the next node and obtain its strategy_cost_ vector to get its layout vector.
+// If reshape's output connects to several primitives, return the first layout found.
+bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(cnode->func_graph());
+  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[cnode];
+  for (auto &node_pair : node_set) {
+    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
+    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_anf_node);
+    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    MS_LOG(INFO) << "FindNextNodeStraCosts prim " << node_prim->name();
+    if (node_prim->name() == DEPEND && node_pair.second != 1) {
+      continue;
+    }
+    if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
+      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
+      *next_operator_info = use_apply->operator_info();
+      *in_index = node_pair.second - 1;
+      return true;
+    }
+    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
+                  << " " << (use_apply->operator_info() != nullptr);
+    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Derive a strategy from a tensor's layout: each dimension's split count is
+// the full shape divided by the per-device slice shape.
+void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) {
+  Shape shape = pre_out_tensor_info.shape();
+  Shape slice_shape = pre_out_tensor_info.slice_shape();
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) {
+      MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator";
+    }
+    int32_t dim = (int32_t)(shape[i] / slice_shape[i]);
+    (*stra).push_back(dim);
+  }
+}
+
+void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
+  for (auto node : all_nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if (!FindReshape(cnode)) {
+      continue;
+    }
+    MS_ASSERT(cnode->inputs().size() == 3);
+    // get the previous node's strategy_cost_
+    auto pre_node = cnode->input(1);
+    int32_t out_index = 0;
+    OperatorInfoPtr pre_operator_info;
+    std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
+    if (pre_node->isa<Parameter>()) {
+      OperatorInfoPtr operator_info = cnode->operator_info();
+      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+      reshape_info->SetCostForReshapeWithParameter();
+      pre_operator_info = reshape_info;
+      pre_stra_costs = reshape_info->strategy_cost();
+    } else {
+      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
+        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
+      }
+      pre_stra_costs = pre_operator_info->strategy_cost();
+    }
+    // get the next node's strategy_cost_
+    int32_t in_index = 0;
+    OperatorInfoPtr next_operator_info;
+    std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
+    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
+    if (!find_next_node) {
+      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
+    }
+    // set input_layout and output_layout for reshape;
+    // init reshape and set a cost for each (input_layout, output_layout) pair.
+    OperatorInfoPtr operator_info = cnode->operator_info();
+    auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+    reshape_info->set_pre_operator_name(pre_operator_info->name());
+    reshape_info->set_pre_operator_index(out_index);
+    if (find_next_node) {
+      next_stra_costs = next_operator_info->strategy_cost();
+      reshape_info->set_next_operator_name(next_operator_info->name());
+      reshape_info->set_next_operator_index(in_index);
+    }
+    for (auto pre_stra_cost : pre_stra_costs) {
+      std::vector<TensorInfo> pre_out_tensor_infos;
+      if (pre_node->isa<Parameter>()) {
+        pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
+      } else {
+        pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
+      }
+      if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
+        MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
+      }
+      TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
+      TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
+      reshape_info->SetInputLayout(pre_out_tensor_layout);
+      // infer the strategy of pre_node's output from its tensor info (shape / slice_shape)
+      Dimensions stra;
+      InferStraByTensorInfo(pre_out_tensor_info, &stra);
+      std::vector<Dimensions> stra_inputs = {stra};
+      StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
+      if (next_stra_costs.empty()) {
+        if (reshape_info->Init(nullptr) == FAILED) {
+          MS_LOG(EXCEPTION) << "Failure: operator reshape init failed";
+        }
+        // set a cost for this (input_layout, output_layout) pair
+        reshape_info->SetCostForReshape(reshape_stra);
+        continue;
+      }
+      for (auto next_stra_cost : next_stra_costs) {
+        std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
+        if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
+          MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
+        }
+        TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
+        TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
+        reshape_info->SetOutputLayout(next_in_tensor_layout);
+        if (reshape_info->Init(nullptr) == FAILED) {
+          MS_LOG(EXCEPTION) << "Failure: operator reshape init failed";
+        }
+        // set a cost for this (input_layout, output_layout) pair
+        reshape_info->SetCostForReshape(reshape_stra);
+      }
+    }
+  }
+}
+
 Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
   // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
   // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
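Net effect of the nested loops in ReshapeCostCompute: the reshape accumulates one StrategyWithCost entry per (previous-layout, next-layout) pair, falling back to one entry per previous layout when no parallel-care successor is found. A toy sketch of that counting, with hypothetical placeholder types instead of the real StrategyWithCost:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct LayoutCandidate { int id; };  // placeholder for one neighbor layout choice

    // Mirrors the loop structure above: each previous layout either pairs with
    // every next layout, or stands alone when there are no next candidates.
    size_t CountReshapeEntries(const std::vector<LayoutCandidate> &pre,
                               const std::vector<LayoutCandidate> &next) {
      size_t entries = 0;
      for (size_t i = 0; i < pre.size(); ++i) {
        entries += next.empty() ? 1 : next.size();
      }
      return entries;
    }

    int main() {
      std::vector<LayoutCandidate> pre(3), next(4);
      std::cout << CountReshapeEntries(pre, next) << '\n';  // 12 = 3 x 4
      return 0;
    }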
@@ -930,7 +1135,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
     }
   }
+  // The reshape operator takes the previous node's output_layout as its input_layout
+  // and the next node's input_layout as its output_layout.
+  ReshapeCostCompute(all_nodes);
   // Step 2
   ConstructCostGraphEdges(all_nodes);
   MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
@@ -51,6 +51,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);
 void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes);
+void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra);
+
 Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
@@ -219,22 +219,5 @@ TEST_F(TestReshapeInfo, CheckStrategy3) {
   Status ret = reshape->Init(strategy);
   ASSERT_EQ(ret, SUCCESS);
 }
-
-TEST_F(TestReshapeInfo, AutoStrategy1) {
-  ASSERT_EQ(reshape->GenerateStrategies(0), Status::SUCCESS);
-  std::vector<std::shared_ptr<StrategyWithCost>> sc = reshape->GetStrategyCost();
-  Shapes splittable_inputs = {{1, 0, 0, 0}};
-  std::vector<StrategyPtr> sp_vector;
-  Shapes inputs_shape = {{32, 512, 7, 7}};
-  GenerateStrategiesForIndependentInputs(0, inputs_shape, splittable_inputs, &sp_vector);
-  ASSERT_EQ(sc.size(), sp_vector.size());
-  for (auto stra : sp_vector) {
-    auto stra0 = stra->GetInputDim()[0];
-    ASSERT_EQ(stra0[1], 1);
-    ASSERT_EQ(stra0[2], 1);
-    ASSERT_EQ(stra0[3], 1);
-  }
-}
 }  // namespace parallel
 }  // namespace mindspore
@@ -65,6 +65,193 @@ def test_reshape_matmul():
     net.set_auto_parallel()
     _executor.compile(net, x)

-if __name__ == '__main__':
-    test_reshape_matmul()
+
+
+def test_reshape_auto_1():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul()
+            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            out = self.relu(x)
+            out = self.reshape(out, (64, 28))
+            out = self.matmul(out, self.matmul_weight)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x)
+
+
+def test_reshape_auto_2():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul()
+            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
+            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            out = self.relu(x)
+            out = self.reshape(out, (64, 28))
+            out = self.matmul(out, self.matmul_weight)
+            out = self.reshape(out, (128, 32))
+            out = out + self.add_weight
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x)
+
+
+def test_reshape_auto_3():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul()
+            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            out = self.relu(x)
+            out = self.matmul(out, self.matmul_weight)
+            out = self.reshape(out, (8, 8, 8, 8))
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8 * size, 28]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x)
+
+
+def test_reshape_auto_4():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul()
+            self.matmul_weight = Parameter(Tensor(np.ones([28 * 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            out = self.relu(x)
+            out = self.reshape(out, (64, 28))
+            w = self.reshape(self.matmul_weight, (28, 64))
+            out = self.matmul(out, w)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x)
+
+
+def test_reshape_auto_5():
+    class NetWithLoss(nn.Cell):
+        def __init__(self, network):
+            super(NetWithLoss, self).__init__()
+            self.loss = VirtualLoss()
+            self.network = network
+
+        def construct(self, x, y):
+            predict = self.network(x, y)
+            return self.loss(predict)
+
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+
+        def construct(self, x, y):
+            return C.grad_all(self.network)(x, y)
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.mul = P.Mul()
+            self.reshape = P.Reshape()
+            self.reduce_sum = P.ReduceSum()
+            self.wide_w = Parameter(Tensor(np.ones([4, 1024 * 8, 64]), dtype=ms.float32), name="weight")
+
+        def construct(self, x, y):
+            mask = self.reshape(y, (4, 1024 * 8, 1))
+            w_id = self.relu(x)
+            wx = self.mul(w_id, mask)
+            wide_out = self.reshape(self.reduce_sum(wx, 1), (-1, 1))
+            deep_id = x + self.wide_w
+            vx = self.mul(deep_id, mask)
+            deep_in = self.reshape(vx, (-1, 1024 * 8 * 64))
+            out = wide_out + deep_in
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
+    y = Tensor(np.ones([4, 1024 * size]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x, y)
+
+
+def test_reshape_auto_6():
+    class NetWithLoss(nn.Cell):
+        def __init__(self, network):
+            super(NetWithLoss, self).__init__()
+            self.loss = VirtualLoss()
+            self.network = network
+
+        def construct(self, x, y):
+            predict = self.network(x, y)
+            return self.loss(predict)
+
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+
+        def construct(self, x, y):
+            return C.grad_all(self.network)(x, y)
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.mul = P.Mul()
+            self.reshape = P.Reshape()
+            self.reduce_mean = P.ReduceMean()
+            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")
+
+        def construct(self, x, y):
+            out1 = x + self.wide_w
+            w = self.reshape(self.wide_w, (4, 1024))
+            out1 = self.reduce_mean(out1, 1)
+            out1 = out1 - w
+            out2 = self.mul(y, w)
+            out = out1 + out2
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
+    y = Tensor(np.ones([4, 1024]), dtype=ms.float32)
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    _executor.compile(net, x, y)