From: @yao_yf
Reviewed-by: @yangzhenzhang, @stsuteng
Signed-off-by: @stsuteng
pull/15396/MERGE
@@ -201,6 +201,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Virtual Dataset
   virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
                                                 "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
+  // Virtual Output
+  virtual_output_eliminate_ =
+    MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);
   // Receive
   receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
@@ -118,6 +118,8 @@ class OptimizeIRPassLib {
   // virtual dataset
   SubstitutionPtr virtual_dataset_eliminate_;
+  // virtual output
+  SubstitutionPtr virtual_output_eliminate_;
   // Receive
   SubstitutionPtr receive_eliminate_;
@@ -99,6 +99,24 @@ class VirtualDatasetEliminater : public AnfVisitor {
   void Visit(const AnfNodePtr &) override {}
 };
+
+// {prim::kPrimVirtualOutput, X} -> X
+class VirtualOutputEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimVirtualOutput) || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (cnode->inputs().size() <= 1) {
+      return nullptr;
+    }
+    return cnode->input(1);
+  }
+
+  void Visit(const AnfNodePtr &) override {}
+};
+
 // {prim::kPrimReceive, X} -> prim::kPrimReceive
 class ReceiveEliminater : public AnfVisitor {
  public:
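The new VirtualOutputEliminater implements exactly the rewrite noted in its comment, {prim::kPrimVirtualOutput, X} -> X: once parallel strategies have been attached, the marker node is replaced by its single data input. A toy sketch of the same substitution rule on a tuple-based stand-in for cnodes (illustration only, not MindSpore code):

# Toy model: ('_VirtualOutput', x) -> x, mirroring the C++ eliminater above.
def eliminate_virtual_output(node):
    """Return the wrapped input if node is a _VirtualOutput cnode, else None (no rewrite)."""
    if isinstance(node, tuple) and len(node) >= 2 and node[0] == '_VirtualOutput':
        return node[1]  # mirrors cnode->input(1)
    return None  # mirrors returning nullptr: the pattern does not match

assert eliminate_virtual_output(('_VirtualOutput', 'relu_out')) == 'relu_out'
assert eliminate_virtual_output(('Receive', 'x')) is None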
@@ -127,7 +127,7 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, co
 // this operator uses
 double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                              const std::vector<TensorInfo> &outputs, int64_t) const {
-  // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
+  // In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
   double result = 0.0;
   TensorInfo output0 = outputs[0];
   Shape input0_slice_shape = inputs[0].slice_shape();
@@ -368,7 +368,7 @@ void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_outpu
 // Taking account of input
 void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
-  // When calulating 'dx', taking account of 'y'
+  // When calculating 'dx', taking account of 'y'
   if (is_parameter_[0]) {
     is_inputs_should_in_memory_[0] = true;
     if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
@@ -195,6 +195,7 @@ REGISTER(SelectInfo);
 REGISTER(GatherNdInfo);
 REGISTER(TopKInfo);
 REGISTER(ScatterUpdateInfo);
+REGISTER(VirtualOutputInfo);
 }  // namespace parallel
 }  // namespace mindspore
@@ -54,5 +54,6 @@
 #include "frontend/parallel/ops_info/gathernd_info.h"
 #include "frontend/parallel/ops_info/topk_info.h"
 #include "frontend/parallel/ops_info/scatter_update_info.h"
+#include "frontend/parallel/ops_info/virtual_output_info.h"
 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
@@ -218,6 +218,7 @@ constexpr char ASSIGN_SUB[] = "AssignSub";
 constexpr char GREATER[] = "Greater";
 constexpr char UNIFORM_CANDIDATE_SAMPLER[] = "UniformCandidateSampler";
 constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset";
+constexpr char VIRTUAL_OUTPUT[] = "_VirtualOutput";
 constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo";
 constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits";
 constexpr char RELU[] = "ReLU";
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/virtual_output_info.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/step_parallel.h"
+#include "frontend/parallel/context.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy.";
+    return FAILED;
+  }
+  Strategys stra = strategy->GetInputDim();
+  if (stra.size() != 1) {
+    MS_LOG(ERROR) << name_ << ": Strategys size must be 1.";
+    return FAILED;
+  }
+  Dimensions strategy_first = stra.at(0);
+  for (auto dim = strategy_first.begin() + 1; dim != strategy_first.end(); ++dim) {
+    if (*dim != 1) {
+      MS_LOG(ERROR) << name_ << ": All dimensions except the first dimension of the strategy must be 1.";
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
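CheckStrategy above accepts exactly one input strategy and rejects any sharding outside the first dimension, so a _VirtualOutput can only be batch-sliced. A minimal Python sketch of that constraint (illustration only, not the C++ implementation):

# A _VirtualOutput strategy has exactly one input, and only its first (batch)
# dimension may be sharded; every other dimension must be 1.
def check_virtual_output_strategy(strategy):
    if len(strategy) != 1:
        return False  # corresponds to "Strategys size must be 1"
    first_input = strategy[0]
    return all(dim == 1 for dim in first_input[1:])

assert check_virtual_output_strategy([(8, 1, 1)])    # data-parallel output, full_batch=False
assert check_virtual_output_strategy([(1, 1)])       # full_batch=True
assert not check_virtual_output_strategy([(2, 4)])   # non-batch dimension sharded -> rejected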
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_
+#define PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/ops_info/virtual_dataset_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class VirtualOutputInfo : public VirtualDatasetInfo {
+ public:
+  VirtualOutputInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                    const PrimitiveAttrs &attrs)
+      : VirtualDatasetInfo(name, inputs_shape, outputs_shape, attrs) {}
+  ~VirtualOutputInfo() override = default;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // PARALLEL_OPS_INFO_VIRTUAL_OUTPUT_INFO_H_
@@ -69,6 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
       root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
     return changes;
   }
+
   // check whether strategy_search_mode is valid
   std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
   if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
@@ -87,14 +88,17 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   TOTAL_OPS = 0;
   AnfNodePtr ret = root->get_return();
   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
   if (ParallelInit() != SUCCESS) {
     MS_LOG(EXCEPTION) << "Parallel init failed";
   }
   // mark the forward cnodes, parallel only care these nodes
   MarkForwardCNode(root);
+  if (!root->has_flag(TRAINING)) {
+    InsertVirtualOutput(root, all_nodes);
+    AnfNodePtr ret_after = root->get_return();
+    MS_EXCEPTION_IF_NULL(ret_after);
+    all_nodes = DeepScopedGraphSearch(ret_after);
+  }
   if (FindCommunicationOp(all_nodes)) {
     MS_LOG(EXCEPTION) << "The graph contain communication op";
   }
@@ -163,7 +167,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT,
-    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE};
+    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT};
   // clang-format on
   auto iter = splittable_op.find(op_name);
@@ -239,13 +243,7 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
                            StrategyMap *stra_map, const std::string &strategy_key_name) {
   // In this case, the configured strategy should be extracted to help setting cost
   StrategyPtr strategyPtr;
-  if (is_last_nodes) {
-    bool full_batch = ParallelContext::GetInstance()->full_batch();
-    strategyPtr = GenerateBatchParallelStrategy(operator_info, prim);
-    if (full_batch) {
-      SetLastNodeStrategy(strategyPtr);
-    }
-  } else if (StrategyFound(attrs)) {
+  if (StrategyFound(attrs)) {
     strategyPtr = parallel::ExtractStrategy(attrs);
   } else {
     strategyPtr = (*stra_map)[strategy_key_name];
@@ -332,10 +330,9 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   bool load_strategy_from_ckpt =
     StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
   // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy;
-  // if strategy is set to load from checkpoint, it is preferred to load strategy from checkpoint.
-  bool is_gen_stra = (!StrategyFound(attrs) || prim->name() == CAST) && (!load_strategy_from_ckpt) && (!is_last_nodes);
-  if (is_gen_stra) {
+  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
+  // If the strategy is set to load from checkpoint, it is preferred to load the strategy from checkpoint.
+  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
     // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for
     // BatchParallelInfo operator
     operator_info->ComputeBatchSplitFlagList();
@@ -371,11 +368,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
       MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
     }
   }
-  std::vector<std::string> last_forward_node_ids;
-  if (!root->has_flag(TRAINING)) {
-    FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
-    MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
-  }
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
@@ -421,8 +413,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
       (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), current_op_ptr));
       continue;
     }
-    bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
-                         last_forward_node_ids.end();
+    bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
     auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
     if (operator_info == nullptr) {
       return FAILED;
@@ -496,11 +487,6 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
       StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
     MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
   }
-  std::vector<std::string> last_forward_node_ids;
-  if (!root->has_flag(TRAINING)) {
-    FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
-    MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
-  }
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
     auto cnode = node->cast<CNodePtr>();
@@ -546,8 +532,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
       continue;
     }
     // In this case, the corresponding OperatorInfo is not created, create the new one.
-    bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
-                         last_forward_node_ids.end();
+    bool is_last_nodes = IsPrimitiveCNode(cnode, prim::kPrimVirtualOutput);
     auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
     MS_EXCEPTION_IF_NULL(operator_info);
@@ -625,7 +625,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
     return false;
   }
   // get_next is not in the forward graph, we need mark the get_next as the forward node
-  if (prim->name() == GET_NEXT) {
+  if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
     return true;
   }
   if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
@@ -1004,6 +1004,55 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
   }
 }
+
+void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
+  vector<std::string> last_forward_node_ids;
+  vector<size_t> last_indexs;
+  FindLastNodesUniqueId(root, &last_forward_node_ids, &last_indexs);
+  MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
+  for (auto &node : all_nodes) {
+    // insert the VirtualOutput node here
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) {
+      continue;
+    }
+    auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId());
+    if (last_node_iter == last_forward_node_ids.end()) {
+      continue;
+    }
+    for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) {
+      if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) {
+        continue;
+      }
+      MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: "
+                   << cnode->input(last_indexs[last_node_index])->fullname_with_scope();
+      if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
+        FuncGraphManagerPtr manager = cnode->func_graph()->manager();
+        MS_EXCEPTION_IF_NULL(manager);
+        auto node_pair = manager->node_users()[cnode].front();
+        if (!node_pair.first->isa<CNode>()) {
+          MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode";
+        }
+        cnode = node_pair.first->cast<CNodePtr>();
+        last_indexs[last_node_index] = size_t(node_pair.second);
+      }
+      FuncGraphPtr func_graph = node->func_graph();
+      MS_EXCEPTION_IF_NULL(func_graph);
+      OperatorParams params;
+      OperatorAttrs attrs;
+      OperatorArgs args = std::make_pair(attrs, params);
+      Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
+      auto pre_node = cnode->input(last_indexs[last_node_index]);
+      Shapes shape_outputs = GetNodeShape(pre_node);
+      InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
+      auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
+      AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
+      std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
+      virtual_output_abstract->set_shape(virtual_output_shape);
+      virtual_output_node->set_abstract(virtual_output_abstract);
+    }
+  }
+}
+
 static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (IsValueNode<RefKey>(node)) {
     std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
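InsertVirtualOutput walks the nodes returned by FindLastNodesUniqueId, re-targets tuple_getitem users where necessary, and places a _VirtualOutput node on each output edge while cloning the producer's abstract and shape onto it. A toy illustration of the intended graph effect (plain Python, not MindSpore code):

# Toy model: every output edge of the forward graph gets wrapped in a '_VirtualOutput'
# marker that keeps the producer's shape, so a parallel strategy can later be attached
# to the output itself.
def insert_virtual_output(graph_outputs):
    """graph_outputs: list of dicts like {'name': 'MatMul-op3', 'shape': (32, 768)}."""
    wrapped = []
    for out in graph_outputs:
        wrapped.append({'name': '_VirtualOutput', 'shape': out['shape'], 'input': out})
    return wrapped

outs = insert_virtual_output([{'name': 'Mul-op0', 'shape': (32, 128)},
                              {'name': 'Neg-op1', 'shape': (32, 128)}])
assert all(o['name'] == '_VirtualOutput' and o['shape'] == o['input']['shape'] for o in outs)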
@@ -1826,7 +1875,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
   MS_EXCEPTION_IF_NULL(prim);
-  if (prim->name() == VIRTUAL_DATA_SET) {
+  if (prim->name() == VIRTUAL_DATA_SET || prim->name() == VIRTUAL_OUTPUT) {
     CheckGlobalDeviceManager();
     int64_t dev_num;
     if (full_batch) {
@@ -1856,32 +1905,36 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
   }
 }
-// find previous parallel care node.
-bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
+// find the user of the previous parallel care node, and record the corresponding input index.
+bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes) {
   MS_EXCEPTION_IF_NULL(unique_ids);
-  // if previous node is a parameter, handle it in the outsize.
-  if (node->isa<Parameter>()) {
-    return false;
-  }
+  MS_EXCEPTION_IF_NULL(indexes);
   if (!node->isa<CNode>()) {
     return false;
   }
-  CNodePtr cnode = node->cast<CNodePtr>();
-  if (!IsValueNode<Primitive>(cnode->input(0))) {
+  CNodePtr pre_cnode = node->cast<CNodePtr>();
+  if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
     return false;
   }
-  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
-  if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
-    unique_ids->push_back(cnode->UniqueId());
-    return true;
-  }
   bool find = false;
-  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
-    if (prim->name() == DEPEND && index != 1) {
+  for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) {
+    auto next_node = pre_cnode->inputs()[index];
+    if (!next_node->isa<CNode>() || next_node->isa<Parameter>()) {
+      return false;
+    }
+    CNodePtr cnode = next_node->cast<CNodePtr>();
+    if (!IsValueNode<Primitive>(cnode->input(0))) {
+      return false;
+    }
+    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+    PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
+      unique_ids->push_back(pre_cnode->UniqueId());
+      indexes->push_back(index);
+      find = true;
       continue;
     }
-    if (FindPreNodes(cnode->inputs()[index], unique_ids)) {
+    if (FindPreNodes(cnode, unique_ids, indexes)) {
      find = true;
      continue;
    }
@@ -1889,20 +1942,12 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
   return find;
 }
-void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids) {
+void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
+                           std::vector<size_t> *indexes) {
   MS_EXCEPTION_IF_NULL(unique_ids);
-  for (auto &node : all_nodes) {
-    auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
-      continue;
-    }
-    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-    if (prim->name() == RETURN) {
-      if (!FindPreNodes(cnode, unique_ids)) {
-        MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
-      }
-    }
+  CNodePtr cnode = root->get_return();
+  if (!FindPreNodes(cnode, unique_ids, indexes)) {
+    MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
   }
 }
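FindLastNodesUniqueId now starts from the graph's Return node and relies on the reworked FindPreNodes, which records the user of each last parallel-care node together with the input index it occupies, so InsertVirtualOutput knows exactly which edge to rewrite. A toy walk with the same shape (plain Python, not MindSpore code):

# Toy model: nodes are ('OpName', input0, input1, ...) tuples; descend from Return through
# wrapper ops and record (consumer node, input index) for each parallel-care producer.
SKIP = {'MakeTuple', 'MakeList', 'Return', 'Depend'}

def find_last_nodes(node, results, care=lambda op: op not in SKIP):
    op, *inputs = node
    found = False
    for index, inp in enumerate(inputs, start=1):
        if not isinstance(inp, tuple):        # parameters/constants are handled elsewhere
            continue
        if care(inp[0]):
            results.append((node, index))     # record the user node and the input slot
            found = True
        elif find_last_nodes(inp, results, care):
            found = True
    return found

results = []
find_last_nodes(('Return', ('MakeTuple', ('Mul', 'x', 'w'), ('Neg', ('Mul', 'x', 'w')))), results)
assert [(n[0], i) for n, i in results] == [('MakeTuple', 1), ('MakeTuple', 2)]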
@@ -1926,16 +1971,6 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
   return strategyPtr;
 }
-void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
-  auto strategys = strategyPtr->GetInputDim();
-  for (size_t i = 0; i < strategys.size(); ++i) {
-    for (size_t j = 0; j < strategys[i].size(); ++j) {
-      strategys[i][j] = 1;
-    }
-  }
-  strategyPtr->ResetInputs(strategys);
-}
 static bool CheckExtractInfomation(const CNodePtr &cnode) {
   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
     return false;
@@ -1960,11 +1995,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
       (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
     MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
   }
-  vector<std::string> last_forward_node_ids;
-  if (!is_training) {
-    FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
-    MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
-  }
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
@@ -2012,10 +2042,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
     }
     bool load_strategy_from_ckpt =
       StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
-    bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
-                         last_forward_node_ids.end();
-    bool full_batch = ParallelContext::GetInstance()->full_batch();
-    if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
+    if ((!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
       MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
                    << " is empty, using batch parallel";
       strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
@@ -2026,9 +2053,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
     }
     MS_EXCEPTION_IF_NULL(strategyPtr);
-    if (is_last_nodes && full_batch) {
-      SetLastNodeStrategy(strategyPtr);
-    }
     if (operator_->Init(strategyPtr) == FAILED) {
       MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
     }
@@ -3537,6 +3561,14 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
     MS_LOG(EXCEPTION) << "The graph contain communication op";
   }
+
+  if (!root->has_flag(TRAINING)) {
+    InsertVirtualOutput(root, all_nodes);
+    AnfNodePtr ret_after = root->get_return();
+    MS_EXCEPTION_IF_NULL(ret_after);
+    all_nodes = DeepScopedGraphSearch(ret_after);
+    std::reverse(all_nodes.begin(), all_nodes.end());
+  }
+
   // extract shape and strategy, set operator_info
   ExtractInformation(all_nodes, root->has_flag(TRAINING));
   ReshapeInit(all_nodes);
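Both entry points (StepAutoParallel and StepParallel) take this new branch only when the graph is not flagged as training, which is why the tests later in this PR call net.set_train(False) before compiling. A hedged sketch of the user-level flow that exercises it, mirroring the compile_graph helper in the new test file:

# Sketch only: compile a cell for eval/predict under semi_auto_parallel so that
# StepParallel inserts _VirtualOutput before extracting strategies
# (same calls as the compile_graph helper in the new test file).
from mindspore import context
from mindspore.common.api import _executor

context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")

def compile_for_predict(net, x):
    net.set_auto_parallel()
    net.set_train(False)  # no TRAINING flag, so the new eval/predict branch runs
    _executor.compile(net, x, auto_parallel_mode=True)
    return _executor._get_shard_strategy(net)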
@@ -172,7 +172,10 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr);
 bool CreateGroupsByCkptFile(const std::string &file);
-void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
+void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
+                           std::vector<size_t> *indexes);
+
+void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
 }  // namespace parallel
 }  // namespace mindspore
@@ -195,6 +195,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     {"parallel", opt::OptPassConfig(parallel::StepParallel)},
     {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
     {"virtual_dataset", virtual_dataset},
+    {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
     {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
     {"resolve", resolve_pass},
     {"a_after_grad", a_after_grad},
@@ -317,6 +317,7 @@ inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("
 inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
 inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
+inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_VirtualOutput");
 inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
 inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
 inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                         Unique, GatherD, Identity, Range)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
-                       _VirtualDiv, _GetTensorSlice, _VirtualAdd,
+                       _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
@@ -670,7 +670,7 @@ class _VirtualDataset(PrimitiveWithInfer):
     """
     Auto parallel virtual dataset operator.

-    It would insert Broadcast operator in forward computation and be deleted before backward computation.
+    It is inserted into the forward computation and deleted before the backward computation.
     """

     @prim_attr_register
@@ -686,6 +686,22 @@ class _VirtualDataset(PrimitiveWithInfer):
 virtual_dataset = _VirtualDataset()
+
+
+class _VirtualOutput(PrimitiveWithInfer):
+    """
+    Auto parallel virtual output operator.
+
+    It is inserted at the outputs of the forward graph in eval/predict and eliminated
+    afterwards by the virtual_output_eliminate pass.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init"""
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return x_dtype
+
+
 class _GetTensorSlice(PrimitiveWithInfer):
     """
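As defined above, _VirtualOutput is a pure pass-through marker: shape and dtype inference return their inputs unchanged, and the operator never reaches execution because the virtual_output_eliminate pass strips it. A quick sketch of that identity behaviour (assumes the class above and the __init__.py export added by this PR):

import mindspore as ms
from mindspore.ops.operations import _VirtualOutput  # name exported by the __init__.py change above

virtual_output = _VirtualOutput()
assert virtual_output.infer_shape([32, 768]) == [32, 768]    # shapes pass through unchanged
assert virtual_output.infer_dtype(ms.float32) == ms.float32  # dtypes pass through unchanged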
@@ -74,6 +74,7 @@ def test_two_bn():
     net = NetWithLoss(Net())
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
     net.set_auto_parallel()
+    net.set_train()
     set_algo_parameters(elementwise_op_strategy_follow=True)
     reset_op_id()
@@ -158,4 +158,5 @@ def test_only_one_get_next():
     context.set_auto_parallel_context(device_num=4, global_rank=0)
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
     net = Net()
+    net.set_train()
     compile_net(net)
@@ -0,0 +1,252 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import re
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class DenseMutMulNet(nn.Cell):
+    def __init__(self):
+        super(DenseMutMulNet, self).__init__()
+        self.fc1 = nn.Dense(128, 768)
+        self.fc2 = nn.Dense(128, 768)
+        self.fc3 = nn.Dense(128, 768)
+        self.fc4 = nn.Dense(768, 768, has_bias=False)
+        self.relu4 = nn.ReLU()
+        self.relu5 = nn.ReLU()
+        self.transpose = P.Transpose()
+        self.matmul1 = P.MatMul()
+        self.matmul2 = P.MatMul()
+        self.fc4.matmul.shard(((1, 1), (8, 1)))
+
+    def construct(self, x):
+        q = self.fc1(x)
+        k = self.fc2(x)
+        v = self.fc3(x)
+        k = self.transpose(k, (1, 0))
+        c = self.relu4(self.matmul1(q, k))
+        s = self.relu5(self.matmul2(c, v))
+        s = self.fc4(s)
+        return s
+
+
+class MulNegTwoOutputNet(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.mul = P.Mul().shard(((2, 4), (2, 4)))
+        self.neg = P.Neg().shard(((2, 4),))
+        self.mul_weight = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight")
+
+    def construct(self, x):
+        out1 = self.mul(x, self.mul_weight)
+        out2 = self.neg(out1)
+        return out1, out2
+
+
+class ReshapeMatMulNet(nn.Cell):
+    def __init__(self, strategy1, strategy2):
+        super().__init__()
+        self.reshape = P.Reshape()
+        self.matmul = P.MatMul().shard(strategy2)
+        self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
+
+    # x (64, 4, 7)
+    def construct(self, x):
+        out = self.reshape(x, (64, 28))
+        out = self.matmul(out, self.matmul_weight)
+        return out
+
+
+class MatMulReshapeNet(nn.Cell):
+    def __init__(self, strategy1, strategy2):
+        super().__init__()
+        self.reshape = P.Reshape()
+        self.matmul = P.MatMul().shard(strategy1)
+        self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
+
+    # x (128, 28)
+    def construct(self, x):
+        out = self.matmul(x, self.matmul_weight)
+        out = self.reshape(out, (64, -1))
+        return out
+
+
+class ReshapeMulNet(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.reshape = P.Reshape()
+        self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
+        self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
+
+    def construct(self, x):
+        weight = self.reshape(self.mul_weight, (1, 128, 96))
+        out = self.mul(weight, self.mul_weight)
+        return out
+
+
+def compile_graph(x, net):
+    net.set_auto_parallel()
+    net.set_train(False)
+    _executor.compile(net, x, auto_parallel_mode=True)
+    strategies = _executor._get_shard_strategy(net)
+    return strategies
+
+
+def test_dense_relu_semi_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
+    net = DenseMutMulNet()
+    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+
+
+def test_dense_relu_semi_auto_full_batch():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
+    net = DenseMutMulNet()
+    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 1
+
+
+def test_dense_relu_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
+    net = DenseMutMulNet()
+    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+
+
+def test_dense_relu_auto_full_batch():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
+    net = DenseMutMulNet()
+    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 1
+
+
+def test_mul_neg_two_output_semi_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
+    net = MulNegTwoOutputNet()
+    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
+    strategies = compile_graph(x, net)
+    count = 0
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            count += 1
+            assert v[0][0] == 8
+    assert count == 2
+
+
+def test_mul_neg_two_output_semi_auto_full_batch():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
+    net = MulNegTwoOutputNet()
+    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
+    strategies = compile_graph(x, net)
+    count = 0
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            count += 1
+            assert v[0][0] == 1
+    assert count == 2
+
+
+def test_mul_neg_two_output_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
+    net = MulNegTwoOutputNet()
+    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
+    strategies = compile_graph(x, net)
+    count = 0
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            count += 1
+            assert v[0][0] == 8
+    assert count == 2
+
+
+def test_mul_neg_two_output_full_batch():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
+    net = MulNegTwoOutputNet()
+    x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
+    strategies = compile_graph(x, net)
+    count = 0
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            count += 1
+            assert v[0][0] == 1
+    assert count == 2
+
+
+def test_reshape_matmul_semi_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
+    strategy1 = None
+    strategy2 = ((1, 1), (1, 8))
+    net = ReshapeMatMulNet(strategy1, strategy2)
+    x = Tensor(np.ones([64, 4, 7]), ms.float32)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+
+
+def test_reshape_matmul_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
+    strategy1 = None
+    strategy2 = ((1, 1), (1, 8))
+    net = ReshapeMatMulNet(strategy1, strategy2)
+    x = Tensor(np.ones([64, 4, 7]), ms.float32)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+
+
+def test_matmul_reshape_semi_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
+    strategy2 = None
+    strategy1 = ((1, 1), (1, 8))
+    net = MatMulReshapeNet(strategy1, strategy2)
+    x = Tensor(np.ones([128, 28]), ms.float32)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+
+
+def test_matmul_reshape_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
+    strategy2 = None
+    strategy1 = ((1, 1), (1, 8))
+    net = MatMulReshapeNet(strategy1, strategy2)
+    x = Tensor(np.ones([128, 28]), ms.float32)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+
+
+def test_reshape_mul_semi_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
+    net = ReshapeMulNet()
+    x = Tensor(np.ones([64, 4]), ms.float32)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 1
+
+
+def test_reshape_mul_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
+    net = ReshapeMulNet()
+    x = Tensor(np.ones([64, 4]), ms.float32)
+    strategies = compile_graph(x, net)
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 1
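Taken together, the assertions above pin down the intended behaviour: in eval/predict the strategy attached to every VirtualOutput op shards only the first (batch) dimension, by device_num when full_batch=False and not at all when full_batch=True. A compact restatement of that expectation (sketch, not part of the test suite):

def expected_virtual_output_strategy(output_shape, device_num, full_batch):
    # Only the batch dimension is sliced; all other dimensions stay whole.
    first = 1 if full_batch else device_num
    return (first,) + (1,) * (len(output_shape) - 1)

assert expected_virtual_output_strategy((32, 128), 8, full_batch=False) == (8, 1)
assert expected_virtual_output_strategy((32, 128), 8, full_batch=True) == (1, 1)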