| @@ -52,7 +52,11 @@ class Primitive : public Named { | |||
| : Named(name), signatures_(), prim_type_(prim_type) {} | |||
| Primitive(const Primitive &prim) | |||
| : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {} | |||
| : Named(prim), | |||
| attrs_(prim.attrs_), | |||
| signatures_(prim.signatures_), | |||
| instance_name_(prim.instance_name_), | |||
| prim_type_(prim.prim_type_) {} | |||
| MS_DECLARE_PARENT(Primitive, Named); | |||
| @@ -56,6 +56,8 @@ void ParallelContext::Reset() { | |||
| parameter_broadcast_ = false; | |||
| parameter_broadcast_is_set_ = false; | |||
| enable_all_reduce_fusion_ = false; | |||
| strategy_ckpt_load_file_ = ""; | |||
| strategy_ckpt_save_file_ = ""; | |||
| } | |||
| void ParallelContext::set_device_num(int32_t device_num) { | |||
| @@ -103,6 +105,14 @@ void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) { | |||
| parameter_broadcast_is_set_ = true; | |||
| } | |||
| void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) { | |||
| strategy_ckpt_load_file_ = strategy_ckpt_load_file; | |||
| } | |||
| void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) { | |||
| strategy_ckpt_save_file_ = strategy_ckpt_save_file; | |||
| } | |||
| void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices) { | |||
| all_reduce_fusion_split_indices_ = indices; | |||
| } | |||
| @@ -85,6 +85,11 @@ class ParallelContext { | |||
| } | |||
| bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; } | |||
| void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file); | |||
| std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; } | |||
| void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); | |||
| std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } | |||
| void Reset(); | |||
| private: | |||
| @@ -105,6 +110,8 @@ class ParallelContext { | |||
| bool enable_all_reduce_fusion_; | |||
| std::vector<uint32_t> all_reduce_fusion_split_indices_; | |||
| std::vector<uint32_t> all_reduce_fusion_split_sizes_; | |||
| std::string strategy_ckpt_load_file_; | |||
| std::string strategy_ckpt_save_file_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -40,6 +40,7 @@ | |||
| #include "parallel/context.h" | |||
| #include "parallel/ops_info/tmp_identity_info.h" | |||
| #include "parallel/step_parallel.h" | |||
| #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| #include "pipeline/pipeline.h" | |||
| @@ -339,7 +340,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) { | |||
| return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name()); | |||
| } | |||
| OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) { | |||
| OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto attrs = prim->attrs(); | |||
| @@ -385,9 +386,15 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| operator_info->set_input_value(input_value); | |||
| operator_info->set_outputs_dtype(cnode->Type()); | |||
| operator_info->set_cnode(cnode); | |||
| // key of strategy map | |||
| std::string instance_name = prim->instance_name(); | |||
| std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name; | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); | |||
| // If no strategy has been configured for this operator, then candidate strategies are generated for | |||
| // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy | |||
| if (!StrategyFound(attrs) || prim->name() == CAST) { | |||
| // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy. | |||
| // If strategy loading from checkpoint is enabled, the checkpointed strategy takes precedence. | |||
| if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) { | |||
| // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for | |||
| // BatchParallelInfo operator | |||
| operator_info->ComputeBatchSplitFlagList(); | |||
| @@ -397,7 +404,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| } | |||
| } else { | |||
| // In this case, the configured strategy should be extracted to help setting cost | |||
| StrategyPtr strategyPtr = parallel::ExtractStrategy(attrs); | |||
| StrategyPtr strategyPtr; | |||
| if (load_strategy_from_ckpt) { | |||
| strategyPtr = (*stra_map)[strategy_key_name]; | |||
| } else { | |||
| strategyPtr = parallel::ExtractStrategy(attrs); | |||
| } | |||
| if (strategyPtr != nullptr) { | |||
| if (prim->name() == RESHAPE) { | |||
| MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; | |||
| @@ -433,7 +445,13 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||
| // The map from CNode's UniqueId to its operatorInfo | |||
| std::map<std::string, OperatorInfoPtr> from_cnode_to_info; | |||
| // extract strategy from checkpoint for multi-train | |||
| StrategyMap stra_map; | |||
| if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { | |||
| if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | |||
| } | |||
| } | |||
| // Step 1 | |||
| for (auto &node : all_nodes) { | |||
| // NOTE: we only care about splittable Primitive operators | |||
| @@ -451,7 +469,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| auto search_cnode = from_cnode_to_info.find(cnode->UniqueId()); | |||
| if (search_cnode == from_cnode_to_info.end()) { | |||
| auto operator_info = CreateTheOperatorInfo(prim, cnode); | |||
| auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); | |||
| if (operator_info == nullptr) { | |||
| return FAILED; | |||
| } | |||
| @@ -486,7 +504,13 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||
| // The map from CNode's UniqueIdThroughCopy to its operatorInfo | |||
| std::map<std::string, OperatorInfoPtr> from_cnode_to_info; | |||
| // extract strategy from checkpoint for multi-train | |||
| StrategyMap stra_map; | |||
| if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { | |||
| if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | |||
| } | |||
| } | |||
| for (auto &node : all_nodes) { | |||
| // NOTE: we only care about splittable Primitive operators | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -504,7 +528,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy()); | |||
| if (search_cnode == from_cnode_to_info.end()) { | |||
| // In this case, the corresponding OperatorInfo is not created, create the new one. | |||
| auto operator_info = CreateTheOperatorInfo(prim, cnode); | |||
| auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); | |||
| if (operator_info == nullptr) { | |||
| return FAILED; | |||
| } | |||
| @@ -1378,6 +1378,13 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||
| } | |||
| void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| // load strategy map from checkpoint | |||
| StrategyMap stra_map; | |||
| if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { | |||
| if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | |||
| } | |||
| } | |||
| for (auto &node : all_nodes) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| @@ -1414,7 +1421,14 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| (void)cnode->set_operator_info(operator_); | |||
| continue; | |||
| } | |||
| if (!StrategyFound(attrs)) { | |||
| // load strategy checkpoint | |||
| // key of strategy map | |||
| std::string instance_name = prim->instance_name(); | |||
| std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name; | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); | |||
| if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { | |||
| MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() | |||
| << " is empty, using batch parallel"; | |||
| std::shared_ptr<std::vector<Dimensions>> strategy_v_ptr = operator_->GenerateBatchStrategies(); | |||
| @@ -1432,6 +1446,8 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is " | |||
| << attrs[GEN_STRATEGY]->ToString(); | |||
| strategyPtr = NewStrategy(0, *strategy_v_ptr); | |||
| } else if (load_strategy_from_ckpt) { | |||
| strategyPtr = stra_map[strategy_key_name]; | |||
| } else { | |||
| strategyPtr = ExtractStrategy(attrs); | |||
| } | |||
| @@ -2022,53 +2038,29 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo | |||
| } | |||
| } | |||
| void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_LOG(INFO) << "Save strategy to checkpoint begin"; | |||
| StrategyMap straMap; | |||
| auto ret = func_graph->get_return(); | |||
| auto all_nodes = DeepScopedGraphSearch(ret); | |||
| for (auto &node : all_nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| continue; | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| if (operator_info) { | |||
| if (prim->instance_name().empty()) { | |||
| continue; | |||
| bool NodeWithParameter(const CNodePtr &node) { | |||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||
| for (auto input : node_inputs) { | |||
| if (input->isa<Parameter>()) { | |||
| auto input_parameter = input->cast<ParameterPtr>(); | |||
| if (input_parameter->has_default()) { | |||
| return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad")); | |||
| } | |||
| std::string instance_name = prim->instance_name(); | |||
| StrategyPtr strategyPtr = operator_info->strategy(); | |||
| MS_EXCEPTION_IF_NULL(node->scope()); | |||
| std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name; | |||
| straMap[node_name] = strategyPtr; | |||
| } | |||
| } | |||
| if (StrategyCheckpoint::GetInstance().Save(straMap) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; | |||
| } | |||
| return false; | |||
| } | |||
| void RestoreStrategy(const FuncGraphPtr &func_graph) { | |||
| void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_LOG(INFO) << "Extract strategy from checkpoint begin"; | |||
| StrategyMap straMap; | |||
| if (StrategyCheckpoint::GetInstance().Load(&straMap) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | |||
| } | |||
| if (StrategyCheckpoint::GetInstance().RemoveCheckPoint() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Remove strategy checkpoint failed"; | |||
| } | |||
| MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; | |||
| StrategyMap stra_map; | |||
| auto ret = func_graph->get_return(); | |||
| auto all_nodes = DeepScopedGraphSearch(ret); | |||
| for (auto &node : all_nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) { | |||
| continue; | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| @@ -2076,18 +2068,18 @@ void RestoreStrategy(const FuncGraphPtr &func_graph) { | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| if (operator_info) { | |||
| if (prim->instance_name().empty()) { | |||
| continue; | |||
| MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name"; | |||
| } | |||
| std::string instance_name = prim->instance_name(); | |||
| StrategyPtr strategyPtr = operator_info->strategy(); | |||
| MS_EXCEPTION_IF_NULL(node->scope()); | |||
| std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name; | |||
| MS_LOG(INFO) << "Node name is " << node_name; | |||
| if (straMap.find(node_name) != straMap.end()) { | |||
| StrategyPtr strategyPtr = straMap[node_name]; | |||
| operator_info->set_strategy(strategyPtr); | |||
| } | |||
| stra_map[node_name] = strategyPtr; | |||
| } | |||
| } | |||
| if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; | |||
| } | |||
| } | |||
| void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) { | |||
| @@ -2264,14 +2256,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| // extract shape and strategy, set operator_info | |||
| ExtractInformation(all_nodes); | |||
| ReshapeInit(all_nodes); | |||
| // extract strategy from checkpoint for multi-train | |||
| if (StrategyCheckpoint::GetInstance().CheckPointOn() && StrategyCheckpoint::GetInstance().CheckPointExit()) { | |||
| RestoreStrategy(root); | |||
| } | |||
| } | |||
| // save strategy as checkpoint for multi-train | |||
| if (StrategyCheckpoint::GetInstance().CheckPointOn() && | |||
| StrategyCheckpoint::GetInstance().GetCurrentTrainTime() < StrategyCheckpoint::GetInstance().GetTrainTimes()) { | |||
| if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { | |||
| CheckpointStrategy(root); | |||
| } | |||
| @@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes); | |||
| void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | |||
| const FuncGraphManagerPtr &manager); | |||
| void RestoreStrategy(const FuncGraphPtr &func_graph); | |||
| bool NodeWithParameter(const CNodePtr &node); | |||
| void CheckpointStrategy(const FuncGraphPtr &func_graph); | |||
| @@ -29,30 +29,32 @@ namespace mindspore { | |||
| namespace parallel { | |||
| StrategyCheckpoint &StrategyCheckpoint::GetInstance() { | |||
| static StrategyCheckpoint instance = StrategyCheckpoint(); | |||
| if (ParallelContext::GetInstance() != nullptr) { | |||
| instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file(); | |||
| instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty(); | |||
| instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file(); | |||
| instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty(); | |||
| } | |||
| return instance; | |||
| } | |||
| bool StrategyCheckpoint::CheckPointExit() const { | |||
| std::ifstream fin(path_); | |||
| bool StrategyCheckpoint::CheckPointExit(const std::string path) const { | |||
| std::ifstream fin(path); | |||
| if (fin) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| Status StrategyCheckpoint::RemoveCheckPoint() const { | |||
| if (std::remove(common::SafeCStr(path_)) == 0) { | |||
| return SUCCESS; | |||
| } | |||
| return FAILED; | |||
| } | |||
| Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { | |||
| if (strategy_map == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr"; | |||
| } | |||
| if (!CheckPointExit(load_file_)) { | |||
| MS_LOG(EXCEPTION) << "CheckPoint file is not found"; | |||
| } | |||
| straspb::ParallelStrategyMap parallel_strategy_map; | |||
| std::fstream input(path_, std::ios::in | std::ios::binary); | |||
| std::fstream input(load_file_, std::ios::in | std::ios::binary); | |||
| if (!parallel_strategy_map.ParseFromIstream(&input)) { | |||
| MS_LOG(ERROR) << "Load strategy file failed"; | |||
| return FAILED; | |||
| @@ -77,14 +79,14 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { | |||
| StrategyPtr strategy = NewStrategy(stage, strategy_inputs); | |||
| (*strategy_map)[node_name] = strategy; | |||
| current_train_time_ = (int32_t)parallel_strategy_map.train_time(); | |||
| current_stage_ = (int32_t)parallel_strategy_map.current_stage(); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { | |||
| straspb::ParallelStrategyMap parallel_strategy_map; | |||
| parallel_strategy_map.set_train_time(IntToUint(++current_train_time_)); | |||
| parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); | |||
| for (auto &node_stra : strategy_map) { | |||
| straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); | |||
| MS_EXCEPTION_IF_NULL(parallel_strategy_item); | |||
| @@ -100,7 +102,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { | |||
| } | |||
| } | |||
| } | |||
| std::fstream output(path_, std::ios::out | std::ios::trunc | std::ios::binary); | |||
| std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); | |||
| if (!parallel_strategy_map.SerializeToOstream(&output)) { | |||
| MS_LOG(ERROR) << "Save strategy file failed"; | |||
| return FAILED; | |||
| @@ -21,43 +21,37 @@ | |||
| #include <unordered_map> | |||
| #include "parallel/ops_info/ops_utils.h" | |||
| #include "parallel/strategy.h" | |||
| #include "parallel/context.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| constexpr char DEFAULT_CHECKPOINT_PATH[] = "./strategys.ckpt"; | |||
| using StrategyMap = std::unordered_map<std::string, StrategyPtr>; | |||
| class StrategyCheckpoint { | |||
| public: | |||
| StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) { | |||
| train_times_ = 1; | |||
| checkpoint_on_ = false; | |||
| const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES"); | |||
| if (train_times_str != nullptr && std::stoi(train_times_str) > 0) { | |||
| train_times_ = std::stoi(train_times_str); | |||
| } | |||
| const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON"); | |||
| if (checkpoint_on_str != nullptr) { | |||
| checkpoint_on_ = (std::string(checkpoint_on_str) == "on"); | |||
| } | |||
| StrategyCheckpoint() { | |||
| current_stage_ = 0; | |||
| load_file_ = ""; | |||
| load_checkpoint_on_ = false; | |||
| save_file_ = ""; | |||
| save_checkpoint_on_ = false; | |||
| } | |||
| ~StrategyCheckpoint() = default; | |||
| bool CheckPointExit() const; | |||
| Status RemoveCheckPoint() const; | |||
| Status Load(StrategyMap *strategy_map); | |||
| Status Save(const StrategyMap &strategy_map); | |||
| static StrategyCheckpoint &GetInstance(); | |||
| int32_t GetTrainTimes() const { return train_times_; } | |||
| int32_t GetCurrentTrainTime() const { return current_train_time_; } | |||
| bool CheckPointOn() const { return checkpoint_on_; } | |||
| bool LoadCheckPointOn() const { return load_checkpoint_on_; } | |||
| bool SaveCheckPointOn() const { return save_checkpoint_on_; } | |||
| private: | |||
| std::string path_; | |||
| bool checkpoint_on_; | |||
| // total train times for a train, get from Environmental variable:TRAIN_TIME, please export it | |||
| int32_t train_times_; | |||
| int32_t current_train_time_; | |||
| std::string load_file_; | |||
| std::string save_file_; | |||
| bool load_checkpoint_on_; | |||
| bool save_checkpoint_on_; | |||
| bool CheckPointExit(const std::string path) const; | |||
| int32_t current_stage_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -191,6 +191,12 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set, | |||
| "Get parameter broadcast is set.") | |||
| .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.") | |||
| .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file, | |||
| "Set strategy checkpoint load file.") | |||
| .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file, | |||
| "Set strategy checkpoint save file.") | |||
| .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") | |||
| .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") | |||
| .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); | |||
| (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext") | |||
| @@ -33,6 +33,6 @@ message ParallelStrategyItem { | |||
| } | |||
| message ParallelStrategyMap { | |||
| required uint32 train_time = 1; | |||
| required uint32 current_stage = 1; | |||
| repeated ParallelStrategyItem parallel_strategy_item = 2; | |||
| } | |||
| @@ -404,7 +404,7 @@ def _context(): | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, | |||
| parameter_broadcast=bool) | |||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str) | |||
| def set_auto_parallel_context(**kwargs): | |||
| """ | |||
| Set auto parallel context. | |||
| @@ -436,6 +436,8 @@ def set_auto_parallel_context(**kwargs): | |||
| parameter_broadcast (bool): Indicating whether to broadcast parameters before training. | |||
| "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter | |||
| broadcast. Default: False. | |||
| strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' | |||
| strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -447,6 +449,8 @@ def set_auto_parallel_context(**kwargs): | |||
| >>> context.set_auto_parallel_context(cast_before_mirror=False) | |||
| >>> context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| >>> context.set_auto_parallel_context(parameter_broadcast=False) | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") | |||
| """ | |||
| _set_auto_parallel_context(**kwargs) | |||
| @@ -477,6 +481,8 @@ def reset_auto_parallel_context(): | |||
| - cast_before_mirror: True. | |||
| - parallel_mode: "stand_alone". | |||
| - parameter_broadcast: False. | |||
| - strategy_ckpt_load_file: "". | |||
| - strategy_ckpt_save_file: "". | |||
| """ | |||
| _reset_auto_parallel_context() | |||
| @@ -88,6 +88,8 @@ class Primitive(Primitive_): | |||
| for name in self.attrs: | |||
| value = self.attrs[name] | |||
| cloned.add_prim_attr(name, value) | |||
| if hasattr(self, 'instance_name'): | |||
| cloned.set_prim_instance_name(self.instance_name) | |||
| return cloned | |||
| def add_prim_attr(self, name, value): | |||
| @@ -208,6 +208,36 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_parameter_broadcast() | |||
| def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file): | |||
| """ | |||
| Set strategy checkpoint load path. | |||
| Args: | |||
| strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint. | |||
| """ | |||
| self.check_context_handle() | |||
| self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file) | |||
| def get_strategy_ckpt_load_file(self): | |||
| """Get strategy checkpoint load path.""" | |||
| self.check_context_handle() | |||
| return self._context_handle.get_strategy_ckpt_load_file() | |||
| def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): | |||
| """ | |||
| Set strategy checkpoint save path. | |||
| Args: | |||
| strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint. | |||
| """ | |||
| self.check_context_handle() | |||
| self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file) | |||
| def get_strategy_ckpt_save_file(self): | |||
| """Get strategy checkpoint save path.""" | |||
| self.check_context_handle() | |||
| return self._context_handle.get_strategy_ckpt_save_file() | |||
| def get_parameter_broadcast_is_set(self): | |||
| """Get parameter broadcast is set or not.""" | |||
| self.check_context_handle() | |||
| @@ -315,7 +345,9 @@ _set_auto_parallel_context_func_map = { | |||
| "cast_before_mirror": auto_parallel_context().set_cast_before_mirror, | |||
| "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, | |||
| "parallel_mode": auto_parallel_context().set_parallel_mode, | |||
| "parameter_broadcast": auto_parallel_context().set_parameter_broadcast} | |||
| "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, | |||
| "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, | |||
| "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file} | |||
| _get_auto_parallel_context_func_map = { | |||
| @@ -325,11 +357,14 @@ _get_auto_parallel_context_func_map = { | |||
| "cast_before_mirror": auto_parallel_context().get_cast_before_mirror, | |||
| "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, | |||
| "parallel_mode": auto_parallel_context().get_parallel_mode, | |||
| "parameter_broadcast": auto_parallel_context().get_parameter_broadcast} | |||
| "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, | |||
| "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, | |||
| "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file} | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, | |||
| loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool) | |||
| loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool, | |||
| strategy_ckpt_load_file=str, strategy_ckpt_save_file=str) | |||
| def _set_auto_parallel_context(**kwargs): | |||
| """ | |||
| Set auto parallel context. | |||
| @@ -360,6 +395,8 @@ def _set_auto_parallel_context(**kwargs): | |||
| parameter_broadcast (bool): Indicating whether to broadcast parameters before training. | |||
| "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter | |||
| broadcast. Default: False. | |||
| strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' | |||
| strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -400,5 +437,7 @@ def _reset_auto_parallel_context(): | |||
| - cast_before_mirror: True. | |||
| - parallel_mode: "stand_alone". | |||
| - parameter_broadcast: False. | |||
| - strategy_ckpt_load_file: "" | |||
| - strategy_ckpt_save_file: "" | |||
| """ | |||
| auto_parallel_context().reset() | |||
| @@ -25,9 +25,7 @@ StrategyCheckpoint& StrategyCheckpoint::GetInstance() { | |||
| return instance; | |||
| } | |||
| bool StrategyCheckpoint::CheckPointExit() const { return false; } | |||
| Status StrategyCheckpoint::RemoveCheckPoint() const { return SUCCESS; } | |||
| bool StrategyCheckpoint::CheckPointExit(const std::string path) const { return false; } | |||
| Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } | |||
| @@ -14,10 +14,10 @@ | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore.context import set_auto_parallel_context | |||
| from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore import Tensor | |||
| from mindspore import Tensor, Parameter | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| import mindspore as ms | |||
| from mindspore.common.api import _executor | |||
| @@ -25,17 +25,15 @@ from mindspore.ops import composite as C | |||
| # model_parallel test | |||
| # export PARALLEL_CHECKPOINT_ON=on | |||
| # export PARALLEL_TRAIN_TIMES=4 | |||
| def test_six_matmul(): | |||
| def test_six_matmul_save(): | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x1, x2, x3, x4, x5, x6, x7): | |||
| predict = self.network(x1, x2, x3, x4, x5, x6, x7) | |||
| def construct(self, x1, x6): | |||
| predict = self.network(x1, x6) | |||
| return self.loss(predict) | |||
| @@ -44,8 +42,8 @@ def test_six_matmul(): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x1, x2, x3, x4, x5, x6, x7): | |||
| return C.grad_all(self.network)(x1, x2, x3, x4, x5, x6, x7) | |||
| def construct(self, x1, x6): | |||
| return C.grad_all(self.network)(x1, x6) | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6): | |||
| @@ -56,45 +54,46 @@ def test_six_matmul(): | |||
| self.matmul4 = P.MatMul().set_strategy(strategy4) | |||
| self.matmul5 = P.MatMul().set_strategy(strategy5) | |||
| self.matmul6 = P.MatMul().set_strategy(strategy6) | |||
| def construct(self, x1, x2, x3, x4, x5, x6, x7): | |||
| out = self.matmul1(x1, x2) | |||
| out = self.matmul2(out, x3) | |||
| out = self.matmul3(out, x4) | |||
| out = self.matmul4(out, x5) | |||
| out = self.matmul5(out, x6) | |||
| out = self.matmul6(out, x7) | |||
| self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1") | |||
| self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2") | |||
| self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") | |||
| self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") | |||
| self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") | |||
| def construct(self, x1, x6): | |||
| out = self.matmul1(x1, self.weight1) | |||
| out = self.matmul2(out, self.weight2) | |||
| out = self.matmul3(out, self.weight3) | |||
| out = self.matmul4(out, self.weight4) | |||
| out = self.matmul5(out, self.weight5) | |||
| out = self.matmul6(out, x6) | |||
| return out | |||
| set_auto_parallel_context(device_num=512, global_rank=0) | |||
| reset_auto_parallel_context() | |||
| set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt") | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| strategy2 = ((1, 8), (8, 1)) | |||
| strategy3 = ((2, 2), (2, 2)) | |||
| strategy4 = ((4, 2), (2, 4)) | |||
| strategy5 = ((2, 4), (4, 2)) | |||
| strategy6 = ((4, 4), (4, 4)) | |||
| strategy4 = ((1, 1), (1, 8)) | |||
| strategy5 = ((4, 2), (2, 1)) | |||
| strategy6 = ((4, 1), (1, 2)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| x1 = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| x2 = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||
| x3 = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| x4 = Tensor(np.ones([64, 128]), dtype=ms.float32) | |||
| x5 = Tensor(np.ones([128, 64]), dtype=ms.float32) | |||
| x6 = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||
| x7 = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| _executor.compile(net, x1, x2, x3, x4, x5, x6, x7) | |||
| x1 = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| x6 = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| _executor.compile(net, x1, x6) | |||
| # remove matmul2 | |||
| def test_six_matmul_repeated1(): | |||
| # remove matmul2, add matmul7 | |||
| def test_six_matmul_load(): | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x1, x2, x4, x5, x6, x7): | |||
| predict = self.network(x1, x2, x4, x5, x6, x7) | |||
| def construct(self, x1, x6, x7): | |||
| predict = self.network(x1, x6, x7) | |||
| return self.loss(predict) | |||
| @@ -103,53 +102,58 @@ def test_six_matmul_repeated1(): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x1, x2, x4, x5, x6, x7): | |||
| return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7) | |||
| def construct(self, x1, x6, x7): | |||
| return C.grad_all(self.network)(x1, x6, x7) | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6): | |||
| def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7): | |||
| super().__init__() | |||
| self.matmul1 = P.MatMul().set_strategy(strategy1) | |||
| self.matmul3 = P.MatMul().set_strategy(strategy3) | |||
| self.matmul4 = P.MatMul().set_strategy(strategy4) | |||
| self.matmul5 = P.MatMul().set_strategy(strategy5) | |||
| self.matmul6 = P.MatMul().set_strategy(strategy6) | |||
| def construct(self, x1, x2, x4, x5, x6, x7): | |||
| out = self.matmul1(x1, x2) | |||
| out = self.matmul3(out, x4) | |||
| out = self.matmul4(out, x5) | |||
| out = self.matmul5(out, x6) | |||
| out = self.matmul6(out, x7) | |||
| self.matmul7 = P.MatMul().set_strategy(strategy7) | |||
| self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1") | |||
| self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") | |||
| self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") | |||
| self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") | |||
| def construct(self, x1, x6, x7): | |||
| out = self.matmul1(x1, self.weight1) | |||
| out = self.matmul3(out, self.weight3) | |||
| out = self.matmul4(out, self.weight4) | |||
| out = self.matmul5(out, self.weight5) | |||
| out = self.matmul6(out, x6) | |||
| out = self.matmul7(out, x7) | |||
| return out | |||
| set_auto_parallel_context(device_num=512, global_rank=0) | |||
| reset_auto_parallel_context() | |||
| set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt") | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| strategy3 = ((8, 1), (1, 1)) | |||
| strategy4 = ((8, 1), (1, 1)) | |||
| strategy5 = ((8, 1), (1, 1)) | |||
| strategy6 = ((8, 1), (1, 1)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6))) | |||
| strategy7 = ((8, 1), (1, 1)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| x1 = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| x2 = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||
| x4 = Tensor(np.ones([64, 128]), dtype=ms.float32) | |||
| x5 = Tensor(np.ones([128, 64]), dtype=ms.float32) | |||
| x6 = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||
| x1 = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| x6 = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| x7 = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| _executor.compile(net, x1, x2, x4, x5, x6, x7) | |||
| _executor.compile(net, x1, x6, x7) | |||
| # add matmul7 | |||
| def test_six_matmul_repeated2(): | |||
| # model_parallel test | |||
| def test_six_matmul_save_auto(): | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x1, x2, x4, x5, x6, x7, x8): | |||
| predict = self.network(x1, x2, x4, x5, x6, x7, x8) | |||
| def construct(self, x1, x6): | |||
| predict = self.network(x1, x6) | |||
| return self.loss(predict) | |||
| @@ -158,60 +162,52 @@ def test_six_matmul_repeated2(): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x1, x2, x4, x5, x6, x7, x8): | |||
| return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8) | |||
| def construct(self, x1, x6): | |||
| return C.grad_all(self.network)(x1, x6) | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.matmul1 = P.MatMul().set_strategy(strategy1) | |||
| self.matmul3 = P.MatMul().set_strategy(strategy3) | |||
| self.matmul4 = P.MatMul().set_strategy(strategy4) | |||
| self.matmul5 = P.MatMul().set_strategy(strategy5) | |||
| self.matmul6 = P.MatMul().set_strategy(strategy6) | |||
| self.matmul7 = P.MatMul().set_strategy(strategy7) | |||
| def construct(self, x1, x2, x4, x5, x6, x7, x8): | |||
| out = self.matmul1(x1, x2) | |||
| out = self.matmul3(out, x4) | |||
| out = self.matmul4(out, x5) | |||
| out = self.matmul5(out, x6) | |||
| out = self.matmul6(out, x7) | |||
| out = self.matmul7(out, x8) | |||
| self.matmul1 = P.MatMul() | |||
| self.matmul2 = P.MatMul() | |||
| self.matmul3 = P.MatMul() | |||
| self.matmul4 = P.MatMul() | |||
| self.matmul5 = P.MatMul() | |||
| self.matmul6 = P.MatMul() | |||
| self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1") | |||
| self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2") | |||
| self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") | |||
| self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") | |||
| self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") | |||
| def construct(self, x1, x6): | |||
| out = self.matmul1(x1, self.weight1) | |||
| out = self.matmul2(out, self.weight2) | |||
| out = self.matmul3(out, self.weight3) | |||
| out = self.matmul4(out, self.weight4) | |||
| out = self.matmul5(out, self.weight5) | |||
| out = self.matmul6(out, x6) | |||
| return out | |||
| set_auto_parallel_context(device_num=512, global_rank=0) | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| strategy3 = ((8, 1), (1, 1)) | |||
| strategy4 = ((8, 1), (1, 1)) | |||
| strategy5 = ((8, 1), (1, 1)) | |||
| strategy6 = ((8, 1), (1, 1)) | |||
| strategy7 = ((8, 1), (1, 1)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| x1 = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| x2 = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||
| x4 = Tensor(np.ones([64, 128]), dtype=ms.float32) | |||
| x5 = Tensor(np.ones([128, 64]), dtype=ms.float32) | |||
| x6 = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||
| x7 = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| x8 = Tensor(np.ones([32, 128]), dtype=ms.float32) | |||
| _executor.compile(net, x1, x2, x4, x5, x6, x7, x8) | |||
| reset_auto_parallel_context() | |||
| set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt") | |||
| net = GradWrap(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| x1 = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| x6 = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| _executor.compile(net, x1, x6) | |||
| # add scope2 | |||
| def test_six_matmul_repeated3(): | |||
| # remove matmul2, add matmul7 | |||
| def test_six_matmul_load_auto(): | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network1, network2): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network1 | |||
| self.network2 = network2 | |||
| self.network = network | |||
| def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10): | |||
| predict = self.network(x1, x2, x4, x5, x6, x7, x8) | |||
| predict = self.network2(predict, x9, x10) | |||
| def construct(self, x1, x6, x7): | |||
| predict = self.network(x1, x6, x7) | |||
| return self.loss(predict) | |||
| @@ -220,62 +216,42 @@ def test_six_matmul_repeated3(): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x1, x2, x4, x5, x6, x7, x8, x9, x10): | |||
| return C.grad_all(self.network)(x1, x2, x4, x5, x6, x7, x8, x9, x10) | |||
| def construct(self, x1, x6, x7): | |||
| return C.grad_all(self.network)(x1, x6, x7) | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7): | |||
| def __init__(self, strategy1, strategy3, strategy4, strategy5): | |||
| super().__init__() | |||
| self.matmul1 = P.MatMul().set_strategy(strategy1) | |||
| self.matmul3 = P.MatMul().set_strategy(strategy3) | |||
| self.matmul4 = P.MatMul().set_strategy(strategy4) | |||
| self.matmul5 = P.MatMul().set_strategy(strategy5) | |||
| self.matmul6 = P.MatMul().set_strategy(strategy6) | |||
| self.matmul7 = P.MatMul().set_strategy(strategy7) | |||
| def construct(self, x1, x2, x4, x5, x6, x7, x8): | |||
| out = self.matmul1(x1, x2) | |||
| out = self.matmul3(out, x4) | |||
| out = self.matmul4(out, x5) | |||
| out = self.matmul5(out, x6) | |||
| out = self.matmul6(out, x7) | |||
| out = self.matmul7(out, x8) | |||
| return out | |||
| class Net1(nn.Cell): | |||
| def __init__(self, strategy1, strategy2): | |||
| super().__init__() | |||
| self.matmul1 = P.MatMul().set_strategy(strategy1) | |||
| self.matmul2 = P.MatMul().set_strategy(strategy2) | |||
| def construct(self, x1, x2, x3): | |||
| out = self.matmul1(x1, x2) | |||
| out = self.matmul2(out, x3) | |||
| self.matmul6 = P.MatMul() | |||
| self.matmul7 = P.MatMul() | |||
| self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1") | |||
| self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") | |||
| self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") | |||
| self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") | |||
| def construct(self, x1, x6, x7): | |||
| out = self.matmul1(x1, self.weight1) | |||
| out = self.matmul3(out, self.weight3) | |||
| out = self.matmul4(out, self.weight4) | |||
| out = self.matmul5(out, self.weight5) | |||
| out = self.matmul6(out, x6) | |||
| out = self.matmul7(out, x7) | |||
| return out | |||
| reset_auto_parallel_context() | |||
| set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt") | |||
| strategy1 = ((2, 2), (2, 2)) | |||
| strategy3 = ((2, 2), (2, 2)) | |||
| strategy4 = ((2, 2), (2, 2)) | |||
| strategy5 = ((2, 2), (2, 2)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5))) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| set_auto_parallel_context(device_num=512, global_rank=0) | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| strategy3 = ((8, 1), (1, 1)) | |||
| strategy4 = ((8, 1), (1, 1)) | |||
| strategy5 = ((8, 1), (1, 1)) | |||
| strategy6 = ((8, 1), (1, 1)) | |||
| strategy7 = ((8, 1), (1, 1)) | |||
| strategy8 = ((8, 1), (1, 1)) | |||
| strategy9 = ((8, 1), (1, 1)) | |||
| net1 = Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7) | |||
| net2 = Net1(strategy8, strategy9) | |||
| net = GradWrap(NetWithLoss(net1, net2)) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| x1 = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| x2 = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||
| x4 = Tensor(np.ones([64, 128]), dtype=ms.float32) | |||
| x5 = Tensor(np.ones([128, 64]), dtype=ms.float32) | |||
| x6 = Tensor(np.ones([64, 32]), dtype=ms.float32) | |||
| x1 = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| x6 = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| x7 = Tensor(np.ones([32, 32]), dtype=ms.float32) | |||
| x8 = Tensor(np.ones([32, 128]), dtype=ms.float32) | |||
| x9 = Tensor(np.ones([128, 64]), dtype=ms.float32) | |||
| x10 = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| _executor.compile(net, x1, x2, x4, x5, x6, x7, x8, x9, x10) | |||
| _executor.compile(net, x1, x6, x7) | |||