Merge pull request !977 from yao_yf/reshape_auto_parallel_strategy_searchtags/v0.3.0-alpha
| @@ -13,9 +13,6 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "parallel/auto_parallel/graph_costmodel.h" | |||
| #include <algorithm> | |||
| #include <cstdlib> | |||
| #include <iterator> | |||
| @@ -24,6 +21,10 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "parallel/auto_parallel/graph_costmodel.h" | |||
| #include "parallel/ops_info/reshape_info.h" | |||
| #include "parallel/step_auto_parallel.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| CostGraphPtr entire_costgraph = nullptr; | |||
| @@ -40,6 +41,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; | |||
| bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | |||
| constexpr char RESHAPEINFO[] = "ReshapeInfo"; | |||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | |||
| @@ -182,6 +184,20 @@ bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) { | |||
| return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test)); | |||
| } | |||
| void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) { | |||
| std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]); | |||
| curr_edges.push_back(edge); | |||
| edges_[{u_node, v_node}] = curr_edges; | |||
| std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]); | |||
| curr_out_edges.push_back(edge); | |||
| out_edges_[u_node] = curr_out_edges; | |||
| std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]); | |||
| curr_in_edges.push_back(edge); | |||
| in_edges_[v_node] = curr_in_edges; | |||
| } | |||
| bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) { | |||
| for (auto &edge_pair : edges_) { | |||
| auto edges = edge_pair.second; | |||
| @@ -1338,11 +1354,51 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo | |||
| Status CostGraph::InitSelectedStrategy() { | |||
| for (auto &op : ops_) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| if (op->name().find(RESHAPEINFO) != std::string::npos) { | |||
| continue; | |||
| } | |||
| auto result = op->InitSelectedStrategy(op->selected_strategy()); | |||
| if (result != SUCCESS) { | |||
| return result; | |||
| } | |||
| } | |||
| // reshape init should be apply after the init of it's previous node and next node. | |||
| for (size_t i = 0; i < ops_.size(); ++i) { | |||
| if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) { | |||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]); | |||
| auto in_edges = GetOriginalPrevEdges(ops_[i]); | |||
| auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) { | |||
| return edge->prev_operator()->name() == reshape_info->pre_operator_name(); | |||
| }); | |||
| auto out_edges = GetOriginalNextEdges(ops_[i]); | |||
| auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) { | |||
| return edge->next_operator()->name() == reshape_info->next_operator_name(); | |||
| }); | |||
| if (pre_iter != in_edges.end()) { | |||
| MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name(); | |||
| int32_t pre_index = reshape_info->pre_operator_index(); | |||
| Dimensions stra; | |||
| TensorInfo pre_info; | |||
| if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) { | |||
| pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index]; | |||
| } else { | |||
| pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index]; | |||
| } | |||
| reshape_info->SetInputLayout(pre_info.tensor_layout()); | |||
| InferStraByTensorInfo(pre_info, &stra); | |||
| std::vector<Dimensions> stra_inputs = {stra}; | |||
| StrategyPtr reshape_stra = | |||
| std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); | |||
| reshape_info->set_strategy(reshape_stra); | |||
| } | |||
| if (next_iter != out_edges.end()) { | |||
| MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name(); | |||
| int32_t next_index = reshape_info->next_operator_index(); | |||
| reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout()); | |||
| } | |||
| return reshape_info->Init(nullptr); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -87,11 +87,9 @@ class CostGraph { | |||
| void RemoveOperator(const OperatorInfoPtr &op); | |||
| bool IsOperatorInCostGraph(const OperatorInfoPtr &op); | |||
| // the edge is in the form: u --> v | |||
| void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) { | |||
| std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]); | |||
| curr_edges.push_back(edge); | |||
| edges_[{u_node, v_node}] = curr_edges; | |||
| } | |||
| void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge); | |||
| std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; } | |||
| std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; } | |||
| // An edge is uniquely identified by its name, and its output index and input index. | |||
| bool IsEdgeInCostGraph(const std::string &, size_t, size_t); | |||
| @@ -219,6 +217,8 @@ class CostGraph { | |||
| std::vector<OperatorInfoPtr> ops_; | |||
| std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_; | |||
| std::vector<std::shared_ptr<CostGraph>> connected_compoents_; | |||
| std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_; | |||
| std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -111,6 +111,7 @@ class OperatorInfo { | |||
| Shape dev_matrix_shape() const { return dev_matrix_shape_; } | |||
| std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; } | |||
| std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; } | |||
| std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; } | |||
| const std::string &name() const { return name_; } | |||
| void set_name(const std::string &name) { name_ = name; } | |||
| RankList global_device_list() const { return global_device_list_; } | |||
| @@ -22,6 +22,7 @@ | |||
| #include "parallel/device_manager.h" | |||
| #include "parallel/device_matrix.h" | |||
| #include "parallel/step_parallel.h" | |||
| #include "parallel/auto_parallel/graph_costmodel.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -46,26 +47,6 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| } | |||
| return FAILED; | |||
| } | |||
| std::vector<Dimensions> stra = strategy->GetInputDim(); | |||
| for (size_t i = 0; i < strategy_size; ++i) { | |||
| Shape sub_strategy = stra.at(i); | |||
| size_t strategy_len = sub_strategy.size(); | |||
| bool flag = false; | |||
| for (size_t j = 0; j < strategy_len; ++j) { | |||
| int32_t strategy_value = sub_strategy.at(j); | |||
| if (strategy_value > 1) { | |||
| if (flag) { | |||
| if (is_auto_parallel_) { | |||
| MS_LOG(DEBUG) << name_ << ": Only support batch parallel strategy."; | |||
| } else { | |||
| MS_LOG(ERROR) << name_ << ": Only support batch parallel strategy."; | |||
| } | |||
| return FAILED; | |||
| } | |||
| flag = true; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -402,6 +383,41 @@ Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr | |||
| return SUCCESS; | |||
| } | |||
| void ReshapeInfo::SetCostForReshapeWithParameter() { | |||
| size_t success = 0; | |||
| for (auto &sp : sp_vector_) { | |||
| if (SetCostUnderStrategy(sp) == SUCCESS) { | |||
| success++; | |||
| MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; | |||
| PrintStrategy(sp); | |||
| } | |||
| } | |||
| } | |||
| void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) { | |||
| MS_EXCEPTION_IF_NULL(strategy); | |||
| int32_t stage_id = strategy->GetInputStage(); | |||
| double computation_cost = | |||
| operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | |||
| result->communication_without_parameter_ = | |||
| operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| result->communication_with_partial_para_ = | |||
| result->communication_without_parameter_ + | |||
| COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); | |||
| // Breaking ties for preferring data parallelization | |||
| BreakingTiesForPerferringDataParallel(strategy, result); | |||
| // refine communication cost calculation for practice | |||
| RefineForPracticalCost(result, false); | |||
| std::shared_ptr<StrategyWithCost> swc = | |||
| std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_); | |||
| swc->cost_list.push_back(result); | |||
| strategy_cost_.emplace_back(swc); | |||
| } | |||
| Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { | |||
| if (GetAttrs() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": GetAttrs failed."; | |||
| @@ -414,22 +430,14 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { | |||
| } | |||
| is_auto_parallel_ = true; | |||
| Shape input0_split; | |||
| input0_split.emplace_back(1); | |||
| (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 1, 0); | |||
| (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1); | |||
| Shapes splittable_inputs = {input0_split}; | |||
| std::vector<StrategyPtr> sp_vector; | |||
| if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { | |||
| // strategy used only in the input node is parameter, | |||
| // in other case, use the input node's output_layout as input_layout. | |||
| if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; | |||
| return FAILED; | |||
| } | |||
| size_t success = 0; | |||
| for (auto &sp : sp_vector) { | |||
| if (SetCostUnderStrategy(sp) == SUCCESS) { | |||
| success++; | |||
| MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; | |||
| PrintStrategy(sp); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| @@ -50,9 +50,19 @@ class ReshapeInfo : public OperatorInfo { | |||
| output_layout_ = output_layout; | |||
| output_layout_set_flag_ = true; | |||
| } | |||
| void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy); | |||
| void SetCostForReshapeWithParameter(); | |||
| void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; } | |||
| void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } | |||
| void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; } | |||
| void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; } | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| Status GenerateStrategies(int32_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| std::string pre_operator_name() const { return pre_operator_name_; } | |||
| std::string next_operator_name() const { return next_operator_name_; } | |||
| int32_t pre_operator_index() const { return pre_operator_index_; } | |||
| int32_t next_operator_index() const { return next_operator_index_; } | |||
| protected: | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| @@ -73,12 +83,17 @@ class ReshapeInfo : public OperatorInfo { | |||
| Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); | |||
| int32_t dev_num_; | |||
| int32_t pre_operator_index_; | |||
| int32_t next_operator_index_; | |||
| std::vector<int32_t> parameter_input_v_; | |||
| std::vector<StrategyPtr> sp_vector_; | |||
| Dimensions input_strategy_; | |||
| TensorLayout input_layout_; | |||
| TensorLayout output_layout_; | |||
| bool input_layout_set_flag_; | |||
| bool output_layout_set_flag_; | |||
| std::string pre_operator_name_; | |||
| std::string next_operator_name_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -39,6 +39,7 @@ | |||
| #include "parallel/auto_parallel/rec_core/rec_partition.h" | |||
| #include "parallel/context.h" | |||
| #include "parallel/ops_info/tmp_identity_info.h" | |||
| #include "parallel/ops_info/reshape_info.h" | |||
| #include "parallel/step_parallel.h" | |||
| #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| @@ -608,7 +609,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| EdgePtr edge_ptr; | |||
| MS_LOG(INFO) << "Creating edge: " << edge_name; | |||
| bool follow_strategy = ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()); | |||
| bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || | |||
| (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); | |||
| if (follow_strategy) { | |||
| // Redistribution in not allowed on the edge. | |||
| // Elementwise operators have the same strategy as their previous operators. | |||
| @@ -893,6 +895,209 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| } | |||
| bool FindReshape(const CNodePtr &cnode) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| return false; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { | |||
| return false; | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| if (operator_info == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | |||
| } | |||
| if (prim->name() != RESHAPE) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| // find previous node, then obtain its strategy_cost_ vector to get its layout vector. | |||
| bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) { | |||
| // if previous node is a parameter, handle it in the outsize. | |||
| if (node->isa<Parameter>()) { | |||
| return false; | |||
| } | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return false; | |||
| } | |||
| if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { | |||
| *pre_operator_info = cnode->operator_info(); | |||
| *out_index = 0; | |||
| return true; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| if (prim->name() == TUPLE_GETITEM) { | |||
| *out_index = GetTupleGetItemIndex(cnode); | |||
| // find tuple_get_item's previous node | |||
| auto pre_node = cnode->input(1); | |||
| if (!pre_node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | |||
| } | |||
| CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | |||
| if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) { | |||
| *pre_operator_info = pre_cnode->operator_info(); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| for (size_t index = 0; index < cnode->inputs().size(); ++index) { | |||
| if (prim->name() == DEPEND && index != 1) { | |||
| continue; | |||
| } | |||
| if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) { | |||
| continue; | |||
| } | |||
| return true; | |||
| } | |||
| MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error"; | |||
| return false; | |||
| } | |||
| // find next node, then obtain its strategy_cost_ vector to get its layout vector. | |||
| // if reshape's output connect to several primitive, return the first layout found | |||
| bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()); | |||
| FuncGraphManagerPtr manager = cnode->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| AnfNodeIndexSet node_set = manager->node_users()[cnode]; | |||
| for (auto &node_pair : node_set) { | |||
| CNodePtr use_apply = node_pair.first->cast<CNodePtr>(); | |||
| if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) { | |||
| continue; | |||
| } | |||
| ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim_anf_node); | |||
| PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(node_prim); | |||
| MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { | |||
| MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); | |||
| *next_operator_info = use_apply->operator_info(); | |||
| *in_index = node_pair.second - 1; | |||
| return true; | |||
| } | |||
| MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | |||
| << " " << (use_apply->operator_info() != nullptr); | |||
| if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) { | |||
| Shape shape = pre_out_tensor_info.shape(); | |||
| Shape slice_shape = pre_out_tensor_info.slice_shape(); | |||
| for (size_t i = 0; i < shape.size(); ++i) { | |||
| if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) { | |||
| MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator"; | |||
| } | |||
| int32_t dim = (int32_t)(shape[i] / slice_shape[i]); | |||
| (*stra).push_back(dim); | |||
| } | |||
| } | |||
| void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||
| for (auto node : all_nodes) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (!FindReshape(cnode)) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(cnode->inputs().size() == 3); | |||
| // get previous node's strategy_cost_ | |||
| auto pre_node = cnode->input(1); | |||
| int32_t out_index = 0; | |||
| OperatorInfoPtr pre_operator_info; | |||
| std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs; | |||
| if (pre_node->isa<Parameter>()) { | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | |||
| reshape_info->SetCostForReshapeWithParameter(); | |||
| pre_operator_info = reshape_info; | |||
| pre_stra_costs = reshape_info->strategy_cost(); | |||
| } else { | |||
| if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) { | |||
| MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed"; | |||
| } | |||
| pre_stra_costs = pre_operator_info->strategy_cost(); | |||
| } | |||
| // get next node's strategy_cost_ | |||
| int32_t in_index = 0; | |||
| OperatorInfoPtr next_operator_info; | |||
| std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs; | |||
| bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index); | |||
| if (!find_next_node) { | |||
| MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed"; | |||
| } | |||
| // set input_layout and output_layout for reshape. | |||
| // init reshape and set cost for each input_layout and output_layout. | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | |||
| reshape_info->set_pre_operator_name(pre_operator_info->name()); | |||
| reshape_info->set_pre_operator_index(out_index); | |||
| if (find_next_node) { | |||
| next_stra_costs = next_operator_info->strategy_cost(); | |||
| reshape_info->set_next_operator_name(next_operator_info->name()); | |||
| reshape_info->set_next_operator_index(in_index); | |||
| } | |||
| for (auto pre_stra_cost : pre_stra_costs) { | |||
| std::vector<TensorInfo> pre_out_tensor_infos; | |||
| if (pre_node->isa<Parameter>()) { | |||
| pre_out_tensor_infos = pre_stra_cost->inputs_ptr; | |||
| } else { | |||
| pre_out_tensor_infos = pre_stra_cost->outputs_ptr; | |||
| } | |||
| if (pre_out_tensor_infos.size() <= IntToSize(out_index)) { | |||
| MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout"; | |||
| } | |||
| TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; | |||
| TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout(); | |||
| reshape_info->SetInputLayout(pre_out_tensor_layout); | |||
| // infer pre_node output strategy from output_layout. | |||
| Dimensions stra; | |||
| InferStraByTensorInfo(pre_out_tensor_info, &stra); | |||
| std::vector<Dimensions> stra_inputs = {stra}; | |||
| StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); | |||
| if (next_stra_costs.empty()) { | |||
| if (reshape_info->Init(nullptr) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure:operator reshape init failed"; | |||
| } | |||
| // set cost for each input_layout and output_layout pairs. | |||
| reshape_info->SetCostForReshape(reshape_stra); | |||
| continue; | |||
| } | |||
| for (auto next_stra_cost : next_stra_costs) { | |||
| std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr; | |||
| if (next_in_tensor_infos.size() <= IntToSize(in_index)) { | |||
| MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; | |||
| } | |||
| TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; | |||
| TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout(); | |||
| reshape_info->SetOutputLayout(next_in_tensor_layout); | |||
| if (reshape_info->Init(nullptr) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure:operator reshape init failed"; | |||
| } | |||
| // set cost for each input_layout and output_layout pairs. | |||
| reshape_info->SetCostForReshape(reshape_stra); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { | |||
| // There are 4 meta-steps to determine the parallelization strategy for the ANF graph. | |||
| // Step 1: Traverse the ANF graph, and create NODEs for costgraph: | |||
| @@ -930,7 +1135,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||
| MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; | |||
| } | |||
| } | |||
| // reshape operator needs the next node's input_layout as its output_layout. | |||
| // and needs the previous node's output_layout as its input_layout. | |||
| ReshapeCostCompute(all_nodes); | |||
| // Step 2 | |||
| ConstructCostGraphEdges(all_nodes); | |||
| MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() | |||
| @@ -51,6 +51,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes); | |||
| void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes); | |||
| void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra); | |||
| Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| @@ -219,22 +219,5 @@ TEST_F(TestReshapeInfo, CheckStrategy3) { | |||
| Status ret = reshape->Init(strategy); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| } | |||
| TEST_F(TestReshapeInfo, AutoStrategy1) { | |||
| ASSERT_EQ(reshape->GenerateStrategies(0), Status::SUCCESS); | |||
| std::vector<std::shared_ptr<StrategyWithCost>> sc = reshape->GetStrategyCost(); | |||
| Shapes splittable_inputs = {{1, 0, 0, 0}}; | |||
| std::vector<StrategyPtr> sp_vector; | |||
| Shapes inputs_shape = {{32, 512, 7, 7}}; | |||
| GenerateStrategiesForIndependentInputs(0, inputs_shape, splittable_inputs, &sp_vector); | |||
| ASSERT_EQ(sc.size(), sp_vector.size()); | |||
| for (auto stra : sp_vector) { | |||
| auto stra0 = stra->GetInputDim()[0]; | |||
| ASSERT_EQ(stra0[1], 1); | |||
| ASSERT_EQ(stra0[2], 1); | |||
| ASSERT_EQ(stra0[3], 1); | |||
| } | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -65,6 +65,193 @@ def test_reshape_matmul(): | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x) | |||
def test_reshape_auto_1():
    """Compile ReLU -> Reshape -> MatMul in auto_parallel mode on 8 devices."""

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            activated = self.relu(x)
            flattened = self.reshape(activated, (64, 28))
            return self.matmul(flattened, self.matmul_weight)

    device_num = 8
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    x = Tensor(np.ones([8 * device_num, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)
def test_reshape_auto_2():
    """Compile ReLU -> Reshape -> MatMul -> Reshape -> Add in auto_parallel mode on 8 devices."""

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            activated = self.relu(x)
            flattened = self.reshape(activated, (64, 28))
            product = self.matmul(flattened, self.matmul_weight)
            regrouped = self.reshape(product, (128, 32))
            return regrouped + self.add_weight

    device_num = 8
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    x = Tensor(np.ones([8 * device_num, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)
def test_reshape_auto_3():
    """Compile ReLU -> MatMul -> Reshape (reshape is the last operator) in auto_parallel mode on 8 devices."""

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            activated = self.relu(x)
            product = self.matmul(activated, self.matmul_weight)
            return self.reshape(product, (8, 8, 8, 8))

    device_num = 8
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    x = Tensor(np.ones([8 * device_num, 28]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x)
| def test_reshape_auto_4(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.relu = P.ReLU() | |||
| self.reshape = P.Reshape() | |||
| self.matmul = P.MatMul() | |||
| self.matmul_weight = Parameter(Tensor(np.ones([28*64]), dtype=ms.float32), name="weight") | |||
| def construct(self, x): | |||
| out = self.relu(x) | |||
| out = self.reshape(out, (64, 28)) | |||
| w = self.reshape(self.matmul_weight, (28, 64)) | |||
| out = self.matmul(out, w) | |||
| return out | |||
| if __name__ == '__main__': | |||
| test_reshape_matmul() | |||
| size = 8 | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||
| x = Tensor(np.ones([8*size, 28, 1, 1]), dtype=ms.float32) | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| _executor.compile(net, x) | |||
def test_reshape_auto_5():
    """Compile a two-input network with three reshapes (one directly on an input) in auto_parallel mode."""

    class NetWithLoss(nn.Cell):
        """Two-input wrapper that feeds the network output into VirtualLoss."""

        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        """Two-input wrapper that takes the gradients of the wrapped network w.r.t. all inputs."""

        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            return C.grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_sum = P.ReduceSum()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024*8, 64]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            mask = self.reshape(y, (4, 1024*8, 1))
            wide_id = self.relu(x)
            wide_masked = self.mul(wide_id, mask)
            wide_out = self.reshape(self.reduce_sum(wide_masked, 1), (-1, 1))
            deep_id = x + self.wide_w
            deep_masked = self.mul(deep_id, mask)
            deep_in = self.reshape(deep_masked, (-1, 1024*8*64))
            return wide_out + deep_in

    device_num = 8
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    x = Tensor(np.ones([4, 1024 * device_num, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024 * device_num]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x, y)
def test_reshape_auto_6():
    """Compile a two-input network whose reshape is applied directly to a Parameter, in auto_parallel mode."""

    class NetWithLoss(nn.Cell):
        """Two-input wrapper that feeds the network output into VirtualLoss."""

        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        """Two-input wrapper that takes the gradients of the wrapped network w.r.t. all inputs."""

        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x, y):
            return C.grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            biased = x + self.wide_w
            w = self.reshape(self.wide_w, (4, 1024))
            reduced = self.reduce_mean(biased, 1)
            left = reduced - w
            right = self.mul(y, w)
            return left + right

    device_num = 8
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    _executor.compile(net, x, y)