From: @yao_yf Reviewed-by: @yangzhenzhang, @stsuteng Signed-off-by: @stsuteng pull/15396/MERGE
| @@ -201,6 +201,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| // Virtual Dataset | |||
| virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(), | |||
| "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | |||
| // Virtual Output | |||
| virtual_output_eliminate_ = | |||
| MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput); | |||
| // Receive | |||
| receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive); | |||
| @@ -118,6 +118,8 @@ class OptimizeIRPassLib { | |||
| // virtual dataset | |||
| SubstitutionPtr virtual_dataset_eliminate_; | |||
| // virtual output | |||
| SubstitutionPtr virtual_output_eliminate_; | |||
| // Receive | |||
| SubstitutionPtr receive_eliminate_; | |||
| @@ -99,6 +99,24 @@ class VirtualDatasetEliminater : public AnfVisitor { | |||
| void Visit(const AnfNodePtr &) override {} | |||
| }; | |||
| // {prim::kPrimVirtualOutput, X} -> X | |||
| class VirtualOutputEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimVirtualOutput) || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() <= 1) { | |||
| return nullptr; | |||
| } | |||
| return cnode->input(1); | |||
| } | |||
| void Visit(const AnfNodePtr &) override {} | |||
| }; | |||
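The new pass mirrors the virtual-dataset eliminator above: once parallel handling is finished, a `{_VirtualOutput, X}` node is replaced by its single input `X`. A minimal toy sketch of that rewrite (plain Python tuples standing in for ANF nodes; this is an illustration, not the MindSpore IR API):

```python
# Toy model of the {prim::kPrimVirtualOutput, X} -> X rewrite; nodes are tuples of
# (op_name, *inputs), which only illustrates the pattern, not the real AnfNode structure.
def eliminate_virtual_output(node):
    if isinstance(node, tuple) and len(node) > 1 and node[0] == "_VirtualOutput":
        return node[1]          # forward the wrapped input, dropping the marker
    return node                 # any other node is left untouched

matmul = ("MatMul", "x", "w")
assert eliminate_virtual_output(("_VirtualOutput", matmul)) == matmul
assert eliminate_virtual_output(matmul) == matmul
```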
| // {prim::kPrimReceive, X} -> prim::kPrimReceive | |||
| class ReceiveEliminater : public AnfVisitor { | |||
| public: | |||
| @@ -127,7 +127,7 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, co | |||
| // this operator uses | |||
| double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, | |||
| const std::vector<TensorInfo> &outputs, int64_t) const { | |||
| // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) | |||
| // In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) | |||
| double result = 0.0; | |||
| TensorInfo output0 = outputs[0]; | |||
| Shape input0_slice_shape = inputs[0].slice_shape(); | |||
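As a rough worked example of the corrected comment above, the forward cost adds the sizes of the two input slices, plus an allreduce term only when the reduced dimension is split. The sketch below counts elements only; my understanding is that the real MatMulCost also weights each term by the input's per-element byte size, so treat the numbers as illustrative:

```python
# Rough illustration of "computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))"
# using element counts only (the real cost model also scales terms by type byte lengths).
from functools import reduce
from operator import mul

def num_elements(shape):
    return reduce(mul, shape, 1)

# A: [128, 64] split 8 ways along the batch dimension, B: [64, 32] replicated on every device.
slice_a, slice_b = [16, 64], [64, 32]
forward_cost = num_elements(slice_a) + num_elements(slice_b)   # no allreduce: reduced dim not split
print(forward_cost)   # 1024 + 2048 = 3072
```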
| @@ -368,7 +368,7 @@ void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_outpu | |||
| // Taking account of input | |||
| void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) { | |||
| // When calulating 'dx', taking account of 'y' | |||
| // When calculating 'dx', taking account of 'y' | |||
| if (is_parameter_[0]) { | |||
| is_inputs_should_in_memory_[0] = true; | |||
| if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) { | |||
| @@ -195,6 +195,7 @@ REGISTER(SelectInfo); | |||
| REGISTER(GatherNdInfo); | |||
| REGISTER(TopKInfo); | |||
| REGISTER(ScatterUpdateInfo); | |||
| REGISTER(VirtualOutputInfo); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -54,5 +54,6 @@ | |||
| #include "frontend/parallel/ops_info/gathernd_info.h" | |||
| #include "frontend/parallel/ops_info/topk_info.h" | |||
| #include "frontend/parallel/ops_info/scatter_update_info.h" | |||
| #include "frontend/parallel/ops_info/virtual_output_info.h" | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_ | |||
| @@ -218,6 +218,7 @@ constexpr char ASSIGN_SUB[] = "AssignSub"; | |||
| constexpr char GREATER[] = "Greater"; | |||
| constexpr char UNIFORM_CANDIDATE_SAMPLER[] = "UniformCandidateSampler"; | |||
| constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset"; | |||
| constexpr char VIRTUAL_OUTPUT[] = "_VirtualOutput"; | |||
| constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo"; | |||
| constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits"; | |||
| constexpr char RELU[] = "ReLU"; | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/parallel/ops_info/virtual_output_info.h" | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "frontend/parallel/device_manager.h" | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/step_parallel.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||
| return FAILED; | |||
| } | |||
| Strategys stra = strategy->GetInputDim(); | |||
| if (stra.size() != 1) { | |||
| MS_LOG(ERROR) << name_ << ": Strategys size must be 1."; | |||
| return FAILED; | |||
| } | |||
| Dimensions strategy_first = stra.at(0); | |||
| for (auto dim = strategy_first.begin() + 1; dim != strategy_first.end(); ++dim) { | |||
| if (*dim != 1) { | |||
| MS_LOG(ERROR) << name_ << ": All dimensions except the first dimension of the strategy must be 1."; | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
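A small Python mirror of the check above makes the accepted shape of a `_VirtualOutput` strategy clearer: exactly one input strategy, and only its first (batch) dimension may be split. This is a sketch for illustration, not the MindSpore API:

```python
# Toy mirror of VirtualOutputInfo::CheckStrategy: one strategy tuple, and every
# dimension after the first must be 1 (only the batch dimension may be split).
def check_virtual_output_strategy(strategy):
    if len(strategy) != 1:
        return False
    return all(dim == 1 for dim in strategy[0][1:])

assert check_virtual_output_strategy([(8, 1)])              # batch split across 8 devices: OK
assert not check_virtual_output_strategy([(2, 4)])          # non-batch dimension split: rejected
assert not check_virtual_output_strategy([(8, 1), (1, 1)])  # more than one input strategy: rejected
```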
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_ | |||
| #define PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "ir/value.h" | |||
| #include "frontend/parallel/ops_info/operator_info.h" | |||
| #include "frontend/parallel/ops_info/virtual_dataset_info.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| class VirtualOutputInfo : public VirtualDatasetInfo { | |||
| public: | |||
| VirtualOutputInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : VirtualDatasetInfo(name, inputs_shape, outputs_shape, attrs) {} | |||
| ~VirtualOutputInfo() override = default; | |||
| protected: | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_ | |||
| @@ -69,6 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | |||
| root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) { | |||
| return changes; | |||
| } | |||
| // check whether strategy_search_mode is valid | |||
| std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode(); | |||
| if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) { | |||
| @@ -87,14 +88,17 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | |||
| TOTAL_OPS = 0; | |||
| AnfNodePtr ret = root->get_return(); | |||
| std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); | |||
| if (ParallelInit() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Parallel init failed"; | |||
| } | |||
| // mark the forward cnodes, parallel only care these nodes | |||
| MarkForwardCNode(root); | |||
| if (!root->has_flag(TRAINING)) { | |||
| InsertVirtualOutput(root, all_nodes); | |||
| AnfNodePtr ret_after = root->get_return(); | |||
| MS_EXCEPTION_IF_NULL(ret_after); | |||
| all_nodes = DeepScopedGraphSearch(ret_after); | |||
| } | |||
| if (FindCommunicationOp(all_nodes)) { | |||
| MS_LOG(EXCEPTION) << "The graph contain communication op"; | |||
| } | |||
| @@ -163,7 +167,7 @@ bool IsSplittableOperator(const std::string &op_name) { | |||
| BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, | |||
| SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, | |||
| UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, | |||
| UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE}; | |||
| UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT}; | |||
| // clang-format on | |||
| auto iter = splittable_op.find(op_name); | |||
| @@ -239,13 +243,7 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive | |||
| StrategyMap *stra_map, const std::string &strategy_key_name) { | |||
| // In this case, the configured strategy should be extracted to help setting cost | |||
| StrategyPtr strategyPtr; | |||
| if (is_last_nodes) { | |||
| bool full_batch = ParallelContext::GetInstance()->full_batch(); | |||
| strategyPtr = GenerateBatchParallelStrategy(operator_info, prim); | |||
| if (full_batch) { | |||
| SetLastNodeStrategy(strategyPtr); | |||
| } | |||
| } else if (StrategyFound(attrs)) { | |||
| if (StrategyFound(attrs)) { | |||
| strategyPtr = parallel::ExtractStrategy(attrs); | |||
| } else { | |||
| strategyPtr = (*stra_map)[strategy_key_name]; | |||
| @@ -332,10 +330,9 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); | |||
| // If no strategy has been configured for this operator, then candidate strategies are generated for | |||
| // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy; | |||
| // if strategy is set to load from checkpoint, it is preferred to load strategy from checkpoint. | |||
| bool is_gen_stra = (!StrategyFound(attrs) || prim->name() == CAST) && (!load_strategy_from_ckpt) && (!is_last_nodes); | |||
| if (is_gen_stra) { | |||
| // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy. | |||
| // if strategy is set to load from checkpoint, it is preferred to load the strategy from the checkpoint. | |||
| if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) { | |||
| // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for | |||
| // BatchParallelInfo operator | |||
| operator_info->ComputeBatchSplitFlagList(); | |||
| @@ -371,11 +368,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | |||
| } | |||
| } | |||
| std::vector<std::string> last_forward_node_ids; | |||
| if (!root->has_flag(TRAINING)) { | |||
| FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); | |||
| MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; | |||
| } | |||
| for (auto &node : all_nodes) { | |||
| // NOTE: we only care about splittable Primitive operators | |||
| @@ -421,8 +413,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr)); | |||
| continue; | |||
| } | |||
| bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != | |||
| last_forward_node_ids.end(); | |||
| bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput); | |||
| auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map); | |||
| if (operator_info == nullptr) { | |||
| return FAILED; | |||
| @@ -496,11 +487,6 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | |||
| } | |||
| std::vector<std::string> last_forward_node_ids; | |||
| if (!root->has_flag(TRAINING)) { | |||
| FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); | |||
| MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; | |||
| } | |||
| for (auto &node : all_nodes) { | |||
| // NOTE: we only care about splittable Primitive operators | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -546,8 +532,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| continue; | |||
| } | |||
| // In this case, the corresponding OperatorInfo is not created, create the new one. | |||
| bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != | |||
| last_forward_node_ids.end(); | |||
| bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput); | |||
| auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map); | |||
| MS_EXCEPTION_IF_NULL(operator_info); | |||
| @@ -625,7 +625,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { | |||
| return false; | |||
| } | |||
| // get_next is not in the forward graph, we need mark the get_next as the forward node | |||
| if (prim->name() == GET_NEXT) { | |||
| if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) { | |||
| return true; | |||
| } | |||
| if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) { | |||
| @@ -1004,6 +1004,55 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node | |||
| } | |||
| } | |||
| void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) { | |||
| vector<std::string> last_forward_node_ids; | |||
| vector<size_t> last_indexs; | |||
| FindLastNodesUniqueId(root, &last_forward_node_ids, &last_indexs); | |||
| MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; | |||
| for (auto &node : all_nodes) { | |||
| // insert the VirtualOutput node here | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| continue; | |||
| } | |||
| auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()); | |||
| if (last_node_iter == last_forward_node_ids.end()) { | |||
| continue; | |||
| } | |||
| for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) { | |||
| if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) { | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: " | |||
| << cnode->input(last_indexs[last_node_index])->fullname_with_scope(); | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { | |||
| FuncGraphManagerPtr manager = cnode->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto node_pair = manager->node_users()[cnode].front(); | |||
| if (!node_pair.first->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode"; | |||
| } | |||
| cnode = node_pair.first->cast<CNodePtr>(); | |||
| last_indexs[last_node_index] = size_t(node_pair.second); | |||
| } | |||
| FuncGraphPtr func_graph = node->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| OperatorParams params; | |||
| OperatorAttrs attrs; | |||
| OperatorArgs args = std::make_pair(attrs, params); | |||
| Operator op = std::make_pair(VIRTUAL_OUTPUT, args); | |||
| auto pre_node = cnode->input(last_indexs[last_node_index]); | |||
| Shapes shape_outputs = GetNodeShape(pre_node); | |||
| InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT); | |||
| auto virtual_output_node = cnode->input(last_indexs[last_node_index]); | |||
| AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone(); | |||
| std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]); | |||
| virtual_output_abstract->set_shape(virtual_output_shape); | |||
| virtual_output_node->set_abstract(virtual_output_abstract); | |||
| } | |||
| } | |||
| } | |||
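For eval/predict graphs the function above wraps each output edge with a `_VirtualOutput` node that copies the shape and abstract of the node it wraps. A much-simplified sketch of that wrapping step (tuples standing in for CNodes; the real code also walks through TupleGetItem users and clones abstracts/shapes):

```python
# Simplified sketch: wrap every argument of the graph's Return node with "_VirtualOutput".
# The real InsertVirtualOutput works on CNodes, handles TupleGetItem, and fixes abstracts.
def insert_virtual_output(return_node):
    op, *outputs = return_node
    assert op == "Return"
    return (op, *[("_VirtualOutput", out) for out in outputs])

graph_return = ("Return", ("Mul", "x", "w"), ("Neg", ("Mul", "x", "w")))
print(insert_virtual_output(graph_return))
# ('Return', ('_VirtualOutput', ('Mul', 'x', 'w')), ('_VirtualOutput', ('Neg', ('Mul', 'x', 'w'))))
```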
| static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { | |||
| if (IsValueNode<RefKey>(node)) { | |||
| std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph); | |||
| @@ -1826,7 +1875,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->name() == VIRTUAL_DATA_SET) { | |||
| if (prim->name() == VIRTUAL_DATA_SET || prim->name() == VIRTUAL_OUTPUT) { | |||
| CheckGlobalDeviceManager(); | |||
| int64_t dev_num; | |||
| if (full_batch) { | |||
| @@ -1856,32 +1905,36 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||
| } | |||
| } | |||
| // find previous parallel care node. | |||
| bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) { | |||
| // find previous parallel care node's next node. | |||
| bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes) { | |||
| MS_EXCEPTION_IF_NULL(unique_ids); | |||
| // if previous node is a parameter, handle it outside. | |||
| if (node->isa<Parameter>()) { | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(indexes); | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| CNodePtr pre_cnode = node->cast<CNodePtr>(); | |||
| if (!IsValueNode<Primitive>(pre_cnode->input(0))) { | |||
| return false; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) { | |||
| unique_ids->push_back(cnode->UniqueId()); | |||
| return true; | |||
| } | |||
| bool find = false; | |||
| for (size_t index = 0; index < cnode->inputs().size(); ++index) { | |||
| if (prim->name() == DEPEND && index != 1) { | |||
| for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) { | |||
| auto next_node = pre_cnode->inputs()[index]; | |||
| if (!next_node->isa<CNode>() || next_node->isa<Parameter>()) { | |||
| return false; | |||
| } | |||
| CNodePtr cnode = next_node->cast<CNodePtr>(); | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return false; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) { | |||
| unique_ids->push_back(pre_cnode->UniqueId()); | |||
| indexes->push_back(index); | |||
| find = true; | |||
| continue; | |||
| } | |||
| if (FindPreNodes(cnode->inputs()[index], unique_ids)) { | |||
| if (FindPreNodes(cnode, unique_ids, indexes)) { | |||
| find = true; | |||
| continue; | |||
| } | |||
| @@ -1889,20 +1942,12 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) { | |||
| return find; | |||
| } | |||
| void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids) { | |||
| void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids, | |||
| std::vector<size_t> *indexes) { | |||
| MS_EXCEPTION_IF_NULL(unique_ids); | |||
| for (auto &node : all_nodes) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| continue; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| if (prim->name() == RETURN) { | |||
| if (!FindPreNodes(cnode, unique_ids)) { | |||
| MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph"; | |||
| } | |||
| } | |||
| CNodePtr cnode = root->get_return(); | |||
| if (!FindPreNodes(cnode, unique_ids, indexes)) { | |||
| MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph"; | |||
| } | |||
| } | |||
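The reworked FindPreNodes records both the unique id of the consumer node and the input index that leads to the last parallel-care operator, so InsertVirtualOutput knows exactly which edge to split. A toy sketch of that search (tuples as nodes, with a hard-coded stand-in for IsParallelCareNode):

```python
# Toy sketch of the search: starting from Return, record (consumer, input_index) pairs for
# every input edge whose producer is a "parallel care" op; recurse through glue ops otherwise.
PARALLEL_CARE = {"MatMul", "Mul", "Neg"}   # stand-in for IsParallelCareNode

def find_pre_nodes(node, found):
    _, *inputs = node
    for index, producer in enumerate(inputs, start=1):
        if not isinstance(producer, tuple):
            continue                        # parameters / constants are handled elsewhere
        if producer[0] in PARALLEL_CARE:
            found.append((node[0], index))  # remember the consumer op and the edge index
        else:
            find_pre_nodes(producer, found)

found = []
find_pre_nodes(("Return", ("MakeTuple", ("Mul", "x", "w"), ("Neg", "y"))), found)
print(found)   # [('MakeTuple', 1), ('MakeTuple', 2)]
```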
| @@ -1926,16 +1971,6 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const | |||
| return strategyPtr; | |||
| } | |||
| void SetLastNodeStrategy(const StrategyPtr strategyPtr) { | |||
| auto strategys = strategyPtr->GetInputDim(); | |||
| for (size_t i = 0; i < strategys.size(); ++i) { | |||
| for (size_t j = 0; j < strategys[i].size(); ++j) { | |||
| strategys[i][j] = 1; | |||
| } | |||
| } | |||
| strategyPtr->ResetInputs(strategys); | |||
| } | |||
| static bool CheckExtractInfomation(const CNodePtr &cnode) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| return false; | |||
| @@ -1960,11 +1995,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini | |||
| (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) { | |||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | |||
| } | |||
| vector<std::string> last_forward_node_ids; | |||
| if (!is_training) { | |||
| FindLastNodesUniqueId(all_nodes, &last_forward_node_ids); | |||
| MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict"; | |||
| } | |||
| for (auto &node : all_nodes) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -2012,10 +2042,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini | |||
| } | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); | |||
| bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) != | |||
| last_forward_node_ids.end(); | |||
| bool full_batch = ParallelContext::GetInstance()->full_batch(); | |||
| if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) { | |||
| if ((!StrategyFound(attrs) && !load_strategy_from_ckpt)) { | |||
| MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() | |||
| << " is empty, using batch parallel"; | |||
| strategyPtr = GenerateBatchParallelStrategy(operator_, prim); | |||
| @@ -2026,9 +2053,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini | |||
| } | |||
| MS_EXCEPTION_IF_NULL(strategyPtr); | |||
| if (is_last_nodes && full_batch) { | |||
| SetLastNodeStrategy(strategyPtr); | |||
| } | |||
| if (operator_->Init(strategyPtr) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; | |||
| } | |||
| @@ -3537,6 +3561,14 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| MS_LOG(EXCEPTION) << "The graph contain communication op"; | |||
| } | |||
| if (!root->has_flag(TRAINING)) { | |||
| InsertVirtualOutput(root, all_nodes); | |||
| AnfNodePtr ret_after = root->get_return(); | |||
| MS_EXCEPTION_IF_NULL(ret_after); | |||
| all_nodes = DeepScopedGraphSearch(ret_after); | |||
| std::reverse(all_nodes.begin(), all_nodes.end()); | |||
| } | |||
| // extract shape and strategy, set operator_info | |||
| ExtractInformation(all_nodes, root->has_flag(TRAINING)); | |||
| ReshapeInit(all_nodes); | |||
| @@ -172,7 +172,10 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr); | |||
| bool CreateGroupsByCkptFile(const std::string &file); | |||
| void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids); | |||
| void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids, | |||
| std::vector<size_t> *indexes); | |||
| void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -195,6 +195,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| {"parallel", opt::OptPassConfig(parallel::StepParallel)}, | |||
| {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, | |||
| {"virtual_dataset", virtual_dataset}, | |||
| {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})}, | |||
| {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())}, | |||
| {"resolve", resolve_pass}, | |||
| {"a_after_grad", a_after_grad}, | |||
| @@ -317,6 +317,7 @@ inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>(" | |||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd"); | |||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_VirtualOutput"); | |||
| inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send"); | |||
| inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive"); | |||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| @@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta | |||
| Unique, GatherD, Identity, Range) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, _VirtualAdd, | |||
| _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, | |||
| _HostAllGather, _HostReduceScatter) | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| TensorSummary, HistogramSummary, Print, Assert) | |||
| @@ -670,7 +670,7 @@ class _VirtualDataset(PrimitiveWithInfer): | |||
| """ | |||
| Auto parallel virtual dataset operator. | |||
| It would insert Broadcast operator in forward computation and be deleted before backward computation. | |||
| It would insert VirtualDataset operator in forward computation and be deleted before backward computation. | |||
| """ | |||
| @prim_attr_register | |||
| @@ -686,6 +686,22 @@ class _VirtualDataset(PrimitiveWithInfer): | |||
| virtual_dataset = _VirtualDataset() | |||
| class _VirtualOutput(PrimitiveWithInfer): | |||
| """ | |||
| Auto parallel virtual output operator. | |||
| It would insert VirtualOutput operator in forward computation and be deleted before backward computation. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init""" | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| return x_dtype | |||
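Since `_VirtualOutput` is a pure marker, its shape and dtype inference are identity functions, so inserting it cannot change the graph's output signature. A minimal check of that behaviour, assuming this PR is applied; the import path below follows the comm_ops change in this diff:

```python
# Minimal sketch: the inserted marker is transparent for shape/dtype inference.
# Import path assumes this PR's addition to mindspore/ops/operations/comm_ops.py.
import mindspore as ms
from mindspore.ops.operations.comm_ops import _VirtualOutput

virtual_output = _VirtualOutput()
assert virtual_output.infer_shape([32, 128]) == [32, 128]    # shape passes through unchanged
assert virtual_output.infer_dtype(ms.float32) == ms.float32  # dtype passes through unchanged
```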
| class _GetTensorSlice(PrimitiveWithInfer): | |||
| """ | |||
| @@ -74,6 +74,7 @@ def test_two_bn(): | |||
| net = NetWithLoss(Net()) | |||
| x = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| net.set_auto_parallel() | |||
| net.set_train() | |||
| set_algo_parameters(elementwise_op_strategy_follow=True) | |||
| reset_op_id() | |||
| @@ -158,4 +158,5 @@ def test_only_one_get_next(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=0) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| net = Net() | |||
| net.set_train() | |||
| compile_net(net) | |||
| @@ -0,0 +1,252 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import re | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.parameter import Parameter | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class DenseMutMulNet(nn.Cell): | |||
| def __init__(self): | |||
| super(DenseMutMulNet, self).__init__() | |||
| self.fc1 = nn.Dense(128, 768) | |||
| self.fc2 = nn.Dense(128, 768) | |||
| self.fc3 = nn.Dense(128, 768) | |||
| self.fc4 = nn.Dense(768, 768, has_bias=False) | |||
| self.relu4 = nn.ReLU() | |||
| self.relu5 = nn.ReLU() | |||
| self.transpose = P.Transpose() | |||
| self.matmul1 = P.MatMul() | |||
| self.matmul2 = P.MatMul() | |||
| self.fc4.matmul.shard(((1, 1), (8, 1))) | |||
| def construct(self, x): | |||
| q = self.fc1(x) | |||
| k = self.fc2(x) | |||
| v = self.fc3(x) | |||
| k = self.transpose(k, (1, 0)) | |||
| c = self.relu4(self.matmul1(q, k)) | |||
| s = self.relu5(self.matmul2(c, v)) | |||
| s = self.fc4(s) | |||
| return s | |||
| class MulNegTwoOutputNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.mul = P.Mul().shard(((2, 4), (2, 4))) | |||
| self.neg = P.Neg().shard(((2, 4),)) | |||
| self.mul_weight = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight") | |||
| def construct(self, x): | |||
| out1 = self.mul(x, self.mul_weight) | |||
| out2 = self.neg(out1) | |||
| return out1, out2 | |||
| class ReshapeMatMulNet(nn.Cell): | |||
| def __init__(self, strategy1, strategy2): | |||
| super().__init__() | |||
| self.reshape = P.Reshape() | |||
| self.matmul = P.MatMul().shard(strategy2) | |||
| self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight") | |||
| # x (64, 4, 7) | |||
| def construct(self, x): | |||
| out = self.reshape(x, (64, 28)) | |||
| out = self.matmul(out, self.matmul_weight) | |||
| return out | |||
| class MatMulReshapeNet(nn.Cell): | |||
| def __init__(self, strategy1, strategy2): | |||
| super().__init__() | |||
| self.reshape = P.Reshape() | |||
| self.matmul = P.MatMul().shard(strategy1) | |||
| self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight") | |||
| # x (128, 28) | |||
| def construct(self, x): | |||
| out = self.matmul(x, self.matmul_weight) | |||
| out = self.reshape(out, (64, -1)) | |||
| return out | |||
| class ReshapeMulNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.reshape = P.Reshape() | |||
| self.mul = P.Mul().shard(((1, 2, 4), (2, 4))) | |||
| self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight") | |||
| def construct(self, x): | |||
| weight = self.reshape(self.mul_weight, (1, 128, 96)) | |||
| out = self.mul(weight, self.mul_weight) | |||
| return out | |||
| def compile_graph(x, net): | |||
| net.set_auto_parallel() | |||
| net.set_train(False) | |||
| _executor.compile(net, x, auto_parallel_mode=True) | |||
| strategies = _executor._get_shard_strategy(net) | |||
| return strategies | |||
| def test_dense_relu_semi_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False) | |||
| net = DenseMutMulNet() | |||
| x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 8 | |||
| def test_dense_relu_semi_auto_full_batch(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True) | |||
| net = DenseMutMulNet() | |||
| x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 1 | |||
| def test_dense_relu_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False) | |||
| net = DenseMutMulNet() | |||
| x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 8 | |||
| def test_dense_relu_auto_full_batch(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True) | |||
| net = DenseMutMulNet() | |||
| x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 1 | |||
| def test_mul_neg_two_output_semi_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False) | |||
| net = MulNegTwoOutputNet() | |||
| x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) | |||
| strategies = compile_graph(x, net) | |||
| count = 0 | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| count += 1 | |||
| assert v[0][0] == 8 | |||
| assert count == 2 | |||
| def test_mul_neg_two_output_semi_auto_full_batch(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True) | |||
| net = MulNegTwoOutputNet() | |||
| x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) | |||
| strategies = compile_graph(x, net) | |||
| count = 0 | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| count += 1 | |||
| assert v[0][0] == 1 | |||
| assert count == 2 | |||
| def test_mul_neg_two_output_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False) | |||
| net = MulNegTwoOutputNet() | |||
| x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) | |||
| strategies = compile_graph(x, net) | |||
| count = 0 | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| count += 1 | |||
| assert v[0][0] == 8 | |||
| assert count == 2 | |||
| def test_mul_neg_two_output_full_batch(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True) | |||
| net = MulNegTwoOutputNet() | |||
| x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) | |||
| strategies = compile_graph(x, net) | |||
| count = 0 | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| count += 1 | |||
| assert v[0][0] == 1 | |||
| assert count == 2 | |||
| def test_reshape_matmul_semi_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False) | |||
| strategy1 = None | |||
| strategy2 = ((1, 1), (1, 8)) | |||
| net = ReshapeMatMulNet(strategy1, strategy2) | |||
| x = Tensor(np.ones([64, 4, 7]), ms.float32) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 8 | |||
| def test_reshape_matmul_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False) | |||
| strategy1 = None | |||
| strategy2 = ((1, 1), (1, 8)) | |||
| net = ReshapeMatMulNet(strategy1, strategy2) | |||
| x = Tensor(np.ones([64, 4, 7]), ms.float32) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 8 | |||
| def test_matmul_reshape_semi_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False) | |||
| strategy2 = None | |||
| strategy1 = ((1, 1), (1, 8)) | |||
| net = MatMulReshapeNet(strategy1, strategy2) | |||
| x = Tensor(np.ones([128, 28]), ms.float32) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 8 | |||
| def test_matmul_reshape_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False) | |||
| strategy2 = None | |||
| strategy1 = ((1, 1), (1, 8)) | |||
| net = MatMulReshapeNet(strategy1, strategy2) | |||
| x = Tensor(np.ones([128, 28]), ms.float32) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 8 | |||
| def test_reshape_mul_semi_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True) | |||
| net = ReshapeMulNet() | |||
| x = Tensor(np.ones([64, 4]), ms.float32) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 1 | |||
| def test_reshape_mul_auto(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True) | |||
| net = ReshapeMulNet() | |||
| x = Tensor(np.ones([64, 4]), ms.float32) | |||
| strategies = compile_graph(x, net) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('VirtualOutput-op', k) is not None: | |||
| assert v[0][0] == 1 | |||