| @@ -124,6 +124,10 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck | |||
| strategy_ckpt_save_file_ = strategy_ckpt_save_file; | |||
| } | |||
| // Store the path of the group checkpoint save file in the parallel context. | |||
| void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) { | |||
|   group_ckpt_save_file_.assign(group_ckpt_save_file); | |||
| } | |||
| // Store the all-reduce fusion split indices for `group`, overwriting any entry already kept for it. | |||
| void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) { | |||
|   auto &group_indices = all_reduce_fusion_split_indices_[group]; | |||
|   group_indices = indices; | |||
| } | |||
| @@ -102,6 +102,8 @@ class ParallelContext { | |||
| std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; } | |||
| void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); | |||
| std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } | |||
| void set_group_ckpt_save_file(const std::string &group_ckpt_save_file); | |||
| std::string group_ckpt_save_file() const { return group_ckpt_save_file_; } | |||
| void set_enable_parallel_optimizer(bool enable_parallel_optimizer) { | |||
| enable_parallel_optimizer_ = enable_parallel_optimizer; | |||
| @@ -132,6 +134,7 @@ class ParallelContext { | |||
| std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_; | |||
| std::string strategy_ckpt_load_file_; | |||
| std::string strategy_ckpt_save_file_; | |||
| std::string group_ckpt_save_file_; | |||
| bool enable_parallel_optimizer_; | |||
| }; | |||
| @@ -83,6 +83,7 @@ class DeviceManager { | |||
| void Clear(); | |||
| std::string world_group() const { return gm_.world_group(); } | |||
| std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); } | |||
| std::string FindRankListNameByHashName(const std::string &hash_name); | |||
| private: | |||
| @@ -40,16 +40,16 @@ class DynCreator { | |||
| public: | |||
| ~DynCreator() = default; | |||
| // creat static singleton dyn_creator instance | |||
| // create static singleton dyn_creator instance | |||
| static DynCreator &Instance() { | |||
| static DynCreator fac = DynCreator(); | |||
| return fac; | |||
| } | |||
| // register | |||
| void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } | |||
| void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } | |||
| // creator | |||
| OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, | |||
| const PrimitiveAttrs &attrs, size_t count) { | |||
| OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, | |||
| const PrimitiveAttrs &attrs, size_t count) { | |||
| std::string op_name = name + std::to_string(count); | |||
| auto iter = Function_map_.find(name); | |||
| if (iter == Function_map_.end()) { | |||
| @@ -67,7 +67,7 @@ class DynCreator { | |||
| class RegisterAction { | |||
| public: | |||
| RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { | |||
| DynCreator::Instance().Regist(name, creatfn); | |||
| DynCreator::Instance().Register(name, creatfn); | |||
| } | |||
| ~RegisterAction() = default; | |||
| @@ -17,6 +17,7 @@ | |||
| #include "frontend/parallel/group_manager.h" | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "backend/session/executor_manager.h" | |||
| #include "frontend/parallel/device_manager.h" | |||
| #include "utils/comm_manager.h" | |||
| @@ -109,6 +110,9 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto | |||
| return Status::FAILED; | |||
| } | |||
| std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks); | |||
| group_info_.push_back(group_info); | |||
| MS_LOG(INFO) << "Create group success, group name is " << group_name; | |||
| return Status::SUCCESS; | |||
| } | |||
| @@ -187,5 +191,27 @@ Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Gro | |||
| } | |||
| // Clear the manager by destroying all groups; the status returned by DestroyAllGroups is deliberately ignored. | |||
| void GroupManager::Clear() { | |||
|   (void)DestroyAllGroups(); | |||
| } | |||
| // Create every communication group listed in `group_info` (pairs of group name and rank list) | |||
| // through the executor of the device configured in MsContext. | |||
| // Returns FAILED as soon as one group cannot be created, SUCCESS once all of them are created. | |||
| Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) { | |||
|   auto ms_context = MsContext::GetInstance(); | |||
|   MS_EXCEPTION_IF_NULL(ms_context); | |||
|   // The executor is looked up by the device target and device id stored in the context. | |||
|   const std::string target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
|   const uint32_t target_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
|   auto group_executor = session::ExecutorManager::Instance().GetExecutor(target, target_id); | |||
|   MS_EXCEPTION_IF_NULL(group_executor); | |||
|   for (const auto &entry : group_info) { | |||
|     const std::string &name = entry.first; | |||
|     const std::vector<uint32_t> &rank_list = entry.second; | |||
|     const bool created = group_executor->CreateCommGroup(name, rank_list); | |||
|     if (created) { | |||
|       MS_LOG(INFO) << "Create group success, group name is " << name << ", ranks is " << rank_list; | |||
|       continue; | |||
|     } | |||
|     MS_LOG(ERROR) << "Create group failed, group name is " << name << ", ranks is " << rank_list; | |||
|     return FAILED; | |||
|   } | |||
|   return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "frontend/parallel/device.h" | |||
| #include "frontend/parallel/status.h" | |||
| @@ -62,6 +63,7 @@ class GroupManager { | |||
| Status FindGroup(const std::string &name, Group **group); | |||
| std::string world_group() const { return world_group_; } | |||
| void set_world_group(const std::string &name) { world_group_ = name; } | |||
| std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return group_info_; } | |||
| void Clear(); | |||
| private: | |||
| @@ -69,7 +71,10 @@ class GroupManager { | |||
| // the key is group name (name_) | |||
| std::map<std::string, Group> groups_; | |||
| std::string world_group_; | |||
| std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info_; | |||
| }; | |||
| Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -160,7 +160,7 @@ Status ReduceMethod::InferForwardCommunication() { | |||
| Shape group_creat_map; | |||
| // if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix, | |||
| // it need to handle the first dimention of map. | |||
| // it need to handle the first dimension of map. | |||
| if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) { | |||
| group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); | |||
| } | |||
| @@ -200,12 +200,12 @@ Status ReduceMethod::InferForwardCommunication() { | |||
| } | |||
| ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) { | |||
| // Creat AllReduceSum op | |||
| // Create AllReduceSum op | |||
| Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name()); | |||
| std::string group_name = forward_group[0].name(); | |||
| MS_LOG(INFO) << "The group of forward all reduce is " << group_name; | |||
| // Creat RealDiv op | |||
| // Create RealDiv op | |||
| OperatorName operator1_name = REAL_DIV; | |||
| std::vector<Device> device_list = forward_group[0].GetDevicesList(); | |||
| auto divisor = static_cast<float>(device_list.size()); | |||
| @@ -237,7 +237,7 @@ Status ReduceMeanInfo::InferForwardCommunication() { | |||
| Shape group_creat_map; | |||
| // if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix, | |||
| // it need to handle the first dimention of map. | |||
| // it need to handle the first dimension of map. | |||
| if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) { | |||
| group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); | |||
| } | |||
| @@ -326,7 +326,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) { | |||
| std::string instance_name_base = FORWARD_OP; | |||
| std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index); | |||
| std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name); | |||
| CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to creat anfnode | |||
| CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to create anfnode | |||
| MS_EXCEPTION_IF_NULL(forward_node); | |||
| ScopePtr scope = node->scope(); | |||
| MS_EXCEPTION_IF_NULL(scope); | |||
| @@ -371,10 +371,10 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p | |||
| if (pos >= SizeToLong(node->inputs().size())) { | |||
| MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size"; | |||
| } | |||
| // Creat new node | |||
| // Create new node | |||
| AnfNodePtr target_node = node->input(LongToSize(pos)); | |||
| MS_EXCEPTION_IF_NULL(target_node); | |||
| // Creat instance_name | |||
| // Create instance_name | |||
| auto op = (redistribution_oplist_ptr->first)[index]; | |||
| std::string op_name = (redistribution_oplist_ptr->first)[index].first; | |||
| std::string instance_name_base = REDISTRIBUTION_OP; | |||
| @@ -400,7 +400,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func | |||
| MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is " | |||
| << instance_name; | |||
| } | |||
| // Creat new node | |||
| // Create new node | |||
| AnfNodePtr pre_node = node->input(LongToSize(pos)); | |||
| MS_EXCEPTION_IF_NULL(pre_node); | |||
| InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name); | |||
| @@ -595,7 +595,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||
| CNodePtr insert_node_new; | |||
| if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) { | |||
| MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node"; | |||
| MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node"; | |||
| return; | |||
| } | |||
| if (IsValueNode<Primitive>(node->input(0))) { | |||
| @@ -883,10 +883,10 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node | |||
| if (manager == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; | |||
| } | |||
| // Sovle the input order | |||
| // Solve the input order | |||
| // For example input_node:{segment_sum:1, segment_sum:2, gather:2} | |||
| // The Original code here will bind the all operations to the first inputs of theses operatos | |||
| // However, the segment_sum operation needs two inputs, To sovle this | |||
| // The Original code here will bind the all operations to the first inputs of these operators | |||
| // However, the segment_sum operation needs two inputs, To solve this | |||
| // We maintain a dict to count the times of the same operations, | |||
| // and bind the inputs according to the times of the op appears. | |||
| static std::unordered_map<AnfNodePtr, int> input_map = {}; | |||
| @@ -1241,9 +1241,9 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA | |||
| } | |||
| } | |||
| OperatorInfoPtr operator_ = | |||
| (OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); | |||
| (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); | |||
| if (operator_ == nullptr) { | |||
| MS_LOG(INFO) << "Creat " << name << " failed"; | |||
| MS_LOG(INFO) << "Create " << name << " failed"; | |||
| return nullptr; | |||
| } | |||
| std::string origin_name = operator_->name(); | |||
| @@ -1261,7 +1261,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs | |||
| if (IsInBatchParallelBlackList(prim)) { | |||
| MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode."; | |||
| } | |||
| MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel"; | |||
| MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel"; | |||
| operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list); | |||
| MS_EXCEPTION_IF_NULL(operator_); | |||
| } | |||
| @@ -1351,7 +1351,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) { | |||
| } | |||
| if (cnode->input(0)->isa<CNode>()) { | |||
| if (cnode->inputs().size() < 2) { | |||
| MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is samller than 2"; | |||
| MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2"; | |||
| } | |||
| base_shape_ptr = cnode->input(1)->Shape(); | |||
| } | |||
| @@ -2546,7 +2546,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt | |||
| bool has_backward = !sens_loss_pairs.empty(); | |||
| // split sens must before inserting the operators. | |||
| for (auto &pair : sens_loss_pairs) { | |||
| // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it. | |||
| // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it. | |||
| // If the type of sens node is not Tensor, it is unsupported now, do nothing default. | |||
| if (IsLastStage()) { | |||
| StepSplitSens(pair); | |||
| @@ -2703,7 +2703,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) { | |||
| auto param_split_shapes = gatherv2_info->param_split_shapes(); | |||
| auto index_offsets = gatherv2_info->index_offsets(); | |||
| if (param_split_shapes.size() != index_offsets.size()) { | |||
| MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same."; | |||
| MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same."; | |||
| } | |||
| std::vector<std::pair<int64_t, int64_t>> manual_shape; | |||
| for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) { | |||
| @@ -2713,6 +2713,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| } | |||
| } | |||
| if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; | |||
| } | |||
| @@ -3142,6 +3143,19 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| } | |||
| // Load the group information stored in checkpoint `file` and create the groups it lists. | |||
| // Returns false when either loading or group creation fails, true otherwise. | |||
| bool CreateGroupsByCkptFile(const std::string &file) { | |||
|   GroupInfoMap loaded_groups; | |||
|   const bool loaded = StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &loaded_groups) == SUCCESS; | |||
|   // Short-circuit: groups are only created when loading succeeded. | |||
|   if (!loaded || CreateGroups(loaded_groups) != SUCCESS) { | |||
|     return false; | |||
|   } | |||
|   MS_LOG(INFO) << "Create groups by checkpoint file success"; | |||
|   return true; | |||
| } | |||
| bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| @@ -3290,6 +3304,12 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| // ForwardCommunication BackwardCommunication TensorRedistribution | |||
| ParallelCommunication(root, all_nodes, manager); | |||
| auto group_info = g_device_manager->group_info(); | |||
| if (StrategyCheckpoint::GetInstance().group_info_save_on() && | |||
| StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Save group info failed"; | |||
| } | |||
| DumpGraph(root, std::string(STEP_PARALLEL_END)); | |||
| // step parallel only run once | |||
| @@ -109,7 +109,7 @@ void CoverSliceShape(const FuncGraphPtr &root); | |||
| void SetVirtualDatasetStrategy(const CNodePtr &node); | |||
| // Creat parallel operator for primitive node(has strategy) | |||
| // Create parallel operator for primitive node(has strategy) | |||
| void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training = true); | |||
| TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair); | |||
| @@ -163,6 +163,8 @@ void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr | |||
| void SetLastNodeStrategy(const StrategyPtr strategyPtr); | |||
| bool CreateGroupsByCkptFile(const std::string &file); | |||
| void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -34,6 +34,8 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() { | |||
| instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty(); | |||
| instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file(); | |||
| instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty(); | |||
| instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file(); | |||
| instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty(); | |||
| } | |||
| return instance; | |||
| } | |||
| @@ -46,6 +48,39 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const { | |||
| return false; | |||
| } | |||
| // Read the serialized ParallelGroupMap stored in `file` and append one (group name, ranks) pair per | |||
| // stored group to `group_info_map`. | |||
| // Raises an exception when `group_info_map` is null or the file does not exist; returns FAILED when | |||
| // the file cannot be parsed and SUCCESS otherwise. | |||
| Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { | |||
|   MS_EXCEPTION_IF_NULL(group_info_map); | |||
|   if (!CheckPointExit(file)) { | |||
|     MS_LOG(EXCEPTION) << "CheckPoint file is not found"; | |||
|   } | |||
|   straspb::ParallelGroupMap group_map_pb; | |||
|   std::fstream ckpt_stream(file, std::ios::in | std::ios::binary); | |||
|   const bool parsed = group_map_pb.ParseFromIstream(&ckpt_stream); | |||
|   if (!parsed) { | |||
|     MS_LOG(ERROR) << "Load strategy file failed"; | |||
|     return FAILED; | |||
|   } | |||
|   ckpt_stream.close(); | |||
|   // Convert every protobuf group item into a (name, rank list) pair, keeping the stored order. | |||
|   for (const auto &item_pb : group_map_pb.parallel_group_item()) { | |||
|     const auto &dims = item_pb.parallel_group_ranks().dim(); | |||
|     group_info_map->emplace_back(item_pb.group_name(), std::vector<uint32_t>(dims.begin(), dims.end())); | |||
|   } | |||
|   return SUCCESS; | |||
| } | |||
| Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { | |||
| if (strategy_map == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr"; | |||
| @@ -141,5 +176,27 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf | |||
| output.close(); | |||
| return SUCCESS; | |||
| } | |||
| // Serialize `group_info_map` (pairs of group name and rank list) as a ParallelGroupMap and write it | |||
| // to group_info_save_file_, truncating any previous content. | |||
| // Returns FAILED when the file cannot be opened, serialization fails, or closing the file reports an | |||
| // error; returns SUCCESS otherwise. | |||
| Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) { | |||
|   straspb::ParallelGroupMap parallel_group_map; | |||
|   for (const auto &group : group_info_map) { | |||
|     straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item(); | |||
|     MS_EXCEPTION_IF_NULL(parallel_group_item); | |||
|     parallel_group_item->set_group_name(group.first); | |||
|     straspb::ParallelGroupRanks *parallel_group_ranks = parallel_group_item->mutable_parallel_group_ranks(); | |||
|     MS_EXCEPTION_IF_NULL(parallel_group_ranks); | |||
|     for (const auto &rank : group.second) { | |||
|       parallel_group_ranks->add_dim(rank); | |||
|     } | |||
|   } | |||
|   std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary); | |||
|   // Report an unopenable file explicitly instead of letting it surface as a serialization failure. | |||
|   if (!output.is_open()) { | |||
|     MS_LOG(ERROR) << "Open group info file " << group_info_save_file_ << " failed"; | |||
|     return FAILED; | |||
|   } | |||
|   if (!parallel_group_map.SerializeToOstream(&output)) { | |||
|     MS_LOG(ERROR) << "Save strategy file failed"; | |||
|     return FAILED; | |||
|   } | |||
|   // Buffered data is flushed on close; close() sets failbit when that fails, | |||
|   // so check the stream afterwards instead of assuming the data reached the file. | |||
|   output.close(); | |||
|   if (output.fail()) { | |||
|     MS_LOG(ERROR) << "Close group info file " << group_info_save_file_ << " failed"; | |||
|     return FAILED; | |||
|   } | |||
|   return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -32,6 +32,7 @@ namespace parallel { | |||
| using StrategyMap = std::unordered_map<std::string, StrategyPtr>; | |||
| using TensorInfoMap = std::unordered_map<std::string, TensorInfo>; | |||
| using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>; | |||
| using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>; | |||
| class StrategyCheckpoint { | |||
| public: | |||
| StrategyCheckpoint() { | |||
| @@ -40,11 +41,16 @@ class StrategyCheckpoint { | |||
| load_checkpoint_on_ = false; | |||
| save_file_ = ""; | |||
| save_checkpoint_on_ = false; | |||
| group_info_save_file_ = ""; | |||
| group_info_save_on_ = false; | |||
| } | |||
| ~StrategyCheckpoint() = default; | |||
| Status Load(StrategyMap *strategy_map); | |||
| Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map); | |||
| Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map); | |||
| Status SaveGroupInfo(const GroupInfoMap &group_info_map); | |||
| bool group_info_save_on() const { return group_info_save_on_; } | |||
| static StrategyCheckpoint &GetInstance(); | |||
| bool LoadCheckPointOn() const { return load_checkpoint_on_; } | |||
| @@ -57,6 +63,8 @@ class StrategyCheckpoint { | |||
| bool save_checkpoint_on_; | |||
| bool CheckPointExit(const std::string path) const; | |||
| int64_t current_stage_; | |||
| std::string group_info_save_file_; | |||
| bool group_info_save_on_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -157,6 +157,7 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| "Set strategy checkpoint save file.") | |||
| .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") | |||
| .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") | |||
| .def("set_group_ckpt_save_file", &ParallelContext::set_group_ckpt_save_file, "Set group checkpoint save file.") | |||
| .def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num, | |||
| "Set pipeline stage split num.") | |||
| .def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.") | |||
| @@ -61,6 +61,19 @@ message ParallelLayoutItem { | |||
| required ParallelLayouts parallel_layouts = 2; | |||
| } | |||
| // Rank list of one communication group; each `dim` entry holds one rank of the group. | |||
| message ParallelGroupRanks { | |||
| repeated uint32 dim = 1; | |||
| } | |||
| // One communication group: its name together with the ranks that belong to it. | |||
| message ParallelGroupItem { | |||
| required string group_name = 1; | |||
| required ParallelGroupRanks parallel_group_ranks = 2; | |||
| } | |||
| // Top-level message of the group checkpoint file: the list of all saved groups. | |||
| message ParallelGroupMap { | |||
| repeated ParallelGroupItem parallel_group_item = 1; | |||
| } | |||
| message ParallelStrategyMap { | |||
| required uint32 current_stage = 1; | |||
| repeated ParallelStrategyItem parallel_strategy_item = 2; | |||
| @@ -283,6 +283,15 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_strategy_ckpt_save_file() | |||
| def set_group_ckpt_save_file(self, group_ckpt_save_file): | |||
| """Set group checkpoint save path.""" | |||
| self.check_context_handle() | |||
| import os | |||
| dir_path = os.path.dirname(group_ckpt_save_file) | |||
| if dir_path and not os.path.exists(dir_path): | |||
| os.makedirs(dir_path) | |||
| self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file) | |||
| def get_parameter_broadcast_is_set(self): | |||
| """Get parameter broadcast is set or not.""" | |||
| self.check_context_handle() | |||
| @@ -505,6 +514,7 @@ _set_auto_parallel_context_func_map = { | |||
| "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, | |||
| "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, | |||
| "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, | |||
| "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file, | |||
| "full_batch": auto_parallel_context().set_full_batch, | |||
| "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, | |||
| "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step, | |||
| @@ -533,7 +543,7 @@ _get_auto_parallel_context_func_map = { | |||
| loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, | |||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, | |||
| strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, | |||
| grad_accumulation_step=int, all_reduce_fusion_config=list) | |||
| grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str) | |||
| def _set_auto_parallel_context(**kwargs): | |||
| """ | |||
| @@ -574,6 +584,7 @@ def _set_auto_parallel_context(**kwargs): | |||
| broadcast. Default: False. | |||
| strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' | |||
| strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' | |||
| group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: '' | |||
| full_batch (bool): Whether to load the whole batch on each device. Default: False. | |||
| enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False. | |||
| all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. | |||
| @@ -31,5 +31,9 @@ Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } | |||
| // Stub implementation: ignores its arguments and always returns SUCCESS. | |||
| Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, | |||
|                                 ManualShapeMap *manual_shape_map) { | |||
|   return SUCCESS; | |||
| } | |||
| // Stub implementation: ignores its arguments and always returns SUCCESS. | |||
| Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { | |||
|   return SUCCESS; | |||
| } | |||
| // Stub implementation: ignores its argument and always returns SUCCESS. | |||
| Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) { | |||
|   return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -75,7 +75,8 @@ def test_six_matmul_save(): | |||
| return out | |||
| reset_auto_parallel_context() | |||
| set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt") | |||
| set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt", | |||
| group_ckpt_save_file="./group_stage1.ckpt") | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| strategy2 = ((1, 8), (8, 1)) | |||
| strategy3 = ((2, 2), (2, 2)) | |||
| @@ -137,7 +138,8 @@ def test_six_matmul_load(): | |||
| return out | |||
| reset_auto_parallel_context() | |||
| set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt") | |||
| set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt", | |||
| group_ckpt_save_file="./group_stage1.ckpt") | |||
| strategy1 = ((8, 1), (1, 1)) | |||
| strategy3 = ((8, 1), (1, 1)) | |||
| strategy4 = ((8, 1), (1, 1)) | |||