Merge pull request !6673 from huangxinjing/stage_strategy (tag: v1.1.0)
| @@ -63,6 +63,8 @@ void ParallelContext::Reset() { | |||
| all_reduce_fusion_split_indices_.clear(); | |||
| all_reduce_fusion_split_sizes_.clear(); | |||
| strategy_search_mode_ = DYNAMIC_PROGRAMMING; | |||
| stages_.clear(); | |||
| pipeline_stage_split_num_ = 0; | |||
| } | |||
| void ParallelContext::set_device_num(int32_t device_num) { | |||
| @@ -83,6 +85,10 @@ void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient | |||
| void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } | |||
| void ParallelContext::set_pipeline_stage_split_num(const int32_t stage_num) { pipeline_stage_split_num_ = stage_num; } | |||
| void ParallelContext::set_stage(const std::vector<int32_t> &stages) { stages_ = stages; } | |||
| bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) { | |||
| auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); | |||
| if (iter == PARALLEL_MODE_LIST.end()) { | |||
| @@ -67,6 +67,12 @@ class ParallelContext { | |||
| void set_device_num(int32_t device_num); | |||
| int32_t device_num() const { return device_num_; } | |||
| void set_pipeline_stage_split_num(const int32_t stage_num); | |||
| int32_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; } | |||
| void set_stage(const std::vector<int32_t> &stages); | |||
| std::vector<int32_t> stage() const { return stages_; } | |||
| void set_global_rank(int32_t global_rank); | |||
| int32_t global_rank() const { return global_rank_; } | |||
| @@ -115,6 +121,8 @@ class ParallelContext { | |||
| int32_t global_rank_; | |||
| std::string parallel_mode_; | |||
| std::string strategy_search_mode_; | |||
| std::vector<int32_t> stages_; | |||
| int32_t pipeline_stage_split_num_; | |||
| bool parameter_broadcast_; | |||
| bool device_num_is_set_; | |||
| bool global_rank_is_set_; | |||
| @@ -36,7 +36,8 @@ Stage::Stage(const std::vector<mindspore::parallel::Device> &devices, int num, i | |||
| // NOTE: '-1' indicates ERROR | |||
| int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); } | |||
| bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) { | |||
| bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend, | |||
| const std::vector<int32_t> &stage) { | |||
| if (device_num <= 0) { | |||
| MS_LOG(ERROR) << "'device_num' must be positive."; | |||
| return false; | |||
| @@ -68,7 +69,30 @@ bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &back | |||
| devices.push_back(i); | |||
| } | |||
| stage_map.push_back(device_num); | |||
| if (stage.size()) { | |||
| int32_t summed_value = 0; | |||
| for (auto begin = stage.begin(); begin != stage.end(); ++begin) { | |||
| if (*begin <= 0) { | |||
| MS_LOG(ERROR) << "The value in the pipeline stages should be positive value"; | |||
| return false; | |||
| } | |||
| summed_value += *begin; | |||
| stage_map.push_back(*begin); | |||
| } | |||
| if (summed_value != device_num) { | |||
| MS_LOG(ERROR) << "The sum of the pipeline stage :" << summed_value << " is not equal to the device_num " | |||
| << device_num; | |||
| return false; | |||
| } | |||
| } else { | |||
| stage_map.push_back(device_num); | |||
| } | |||
| for (auto &y : stage_map) { | |||
| MS_LOG(DEBUG) << "Obtained stage id :" << y; | |||
| } | |||
| g_device_manager = std::make_shared<DeviceManager>(); | |||
| if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) { | |||
| MS_LOG(INFO) << "Device initialization succeeds."; | |||
| @@ -70,7 +70,7 @@ class Stage { | |||
| // This method is used for initializing the global DeviceManager 'g_device_manager', | |||
| // arguments including 'device_num' and 'global_rank' | |||
| bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend); | |||
| bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend, const std::vector<int32_t> &stage); | |||
| void CheckGlobalDeviceManager(); | |||
| @@ -126,9 +126,22 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *ra | |||
| } | |||
| } | |||
| Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_); | |||
| // Convert the global rank to the local rank (the index within the device list) to compute the coordinate | |||
| uint32_t local_rank = 0; | |||
| for (auto &tmp_rank : dev_list_) { | |||
| if (tmp_rank == rank_) { | |||
| break; | |||
| } | |||
| ++local_rank; | |||
| } | |||
| if (local_rank == dev_list_.size()) { | |||
| MS_LOG(ERROR) << "Rank id: " << local_rank << "is not in the device list."; | |||
| return FAILED; | |||
| } | |||
| Shape current_rank_coordinate = ConvertRankToCoordinate((int32_t)local_rank, dev_shape_); | |||
| for (uint32_t loop_local_rank = 0; loop_local_rank < dev_list_.size(); ++loop_local_rank) { | |||
| Shape tmp_rank_coordinate = ConvertRankToCoordinate(loop_local_rank, dev_shape_); | |||
| bool matched = true; | |||
| for (auto &map : tensor_map) { | |||
| if (map == MAP_NONE) { | |||
| @@ -141,7 +154,7 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *ra | |||
| } | |||
| } | |||
| if (matched) { | |||
| rank_list->push_back(tmp_rank); | |||
| rank_list->push_back(dev_list_[loop_local_rank]); | |||
| } | |||
| } | |||
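Reviewer note: the rewrite above computes coordinates from the rank's index within dev_list_ instead of from the global rank, which matters when the device list is not simply 0..n-1. A Python sketch of the intended behavior, assuming row-major coordinates and tensor_map values indexing device dimensions from the right (a non-negative map value m refers to device dimension len(dev_shape) - 1 - m):

    def convert_rank_to_coordinate(rank, dev_shape):
        # Row-major conversion, e.g. rank 7 in shape (4, 2) -> (3, 1).
        coord = []
        for dim in reversed(dev_shape):
            coord.append(rank % dim)
            rank //= dim
        return list(reversed(coord))

    def get_devices_by_tensor_map(rank, dev_list, dev_shape, tensor_map):
        local_rank = dev_list.index(rank)  # global rank -> local index
        current = convert_rank_to_coordinate(local_rank, dev_shape)
        matched = []
        for local, dev in enumerate(dev_list):
            coord = convert_rank_to_coordinate(local, dev_shape)
            if all(m == -1 or coord[len(dev_shape) - 1 - m] == current[len(dev_shape) - 1 - m]
                   for m in tensor_map):
                matched.append(dev)
        return matched

    # Reproduces the new RandomOrderSliceOne unit test added below:
    assert get_devices_by_tensor_map(0, [10, 3, 2, 9, 11, 100, 1, 0], [4, 2], [-1, 0]) == [3, 9, 100, 0]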
| @@ -43,7 +43,7 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| // dropout doesn't support repeated calculation | |||
| CheckGlobalDeviceManager(); | |||
| auto input_strategy = strategy->GetInputDim().at(0); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size(); | |||
| auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int>()); | |||
| if (IntToSize(product_p) != dev_num) { | |||
| MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc."; | |||
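Reviewer note: this hunk, and the matching GatherV2PInfo and ReduceMethod hunks below, replace the hard-coded stage 0 with the operator's stage_id_, so strategy checks count only the devices in the operator's own stage. A toy sketch, assuming stages [4, 4] over 8 devices:

    # Hypothetical per-stage device lists for stages = [4, 4]
    device_lists = {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}

    def dev_num_for(stage_id):
        # Before this change the check always used stage 0's list.
        return len(device_lists[stage_id])

    assert dev_num_for(1) == 4  # an op in stage 1 validates against its own 4 devices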
| @@ -196,7 +196,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) { | |||
| // Don't support repeated calc | |||
| CheckGlobalDeviceManager(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size(); | |||
| auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>()); | |||
| if (IntToSize(product_p) < dev_num) { | |||
| MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc"; | |||
| @@ -269,7 +269,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| // param_strategy(axis) != 1, Don't support repeated calc | |||
| CheckGlobalDeviceManager(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size(); | |||
| auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>()); | |||
| if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) { | |||
| MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc."; | |||
| @@ -346,7 +346,7 @@ Status GatherV2PInfo::InferDevMatrixShape() { | |||
| out_dev_matrix_shape_ = dev_matrix_shape_; | |||
| } | |||
| CheckGlobalDeviceManager(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size(); | |||
| auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>()); | |||
| auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>()); | |||
| if (param_product * index_product < SizeToInt(dev_num)) { | |||
| @@ -516,10 +516,11 @@ Status GatherV2PInfo::InferGroup() { | |||
| if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) { | |||
| dim = (axis_ + 1) % 2; | |||
| } | |||
| CheckGlobalDeviceManager(); | |||
| MS_EXCEPTION_IF_NULL(g_device_manager); | |||
| RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_); | |||
| int32_t rank = g_device_manager->global_rank(); | |||
| RankList dev_list = g_device_manager->GetDeviceListByStageId(0); | |||
| DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_); | |||
| RankList group_devices; | |||
| if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) { | |||
| @@ -162,7 +162,8 @@ class OperatorInfo { | |||
| void set_type(const std::string &type) { type_ = type; } | |||
| const std::string &type() const { return type_; } | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; } | |||
| int32_t stage_id() const { return stage_id_; } | |||
| // Key for user data. | |||
| constexpr static char key[] = "OpInfo"; | |||
| @@ -205,6 +206,7 @@ class OperatorInfo { | |||
| std::vector<ValuePtr> input_value_; | |||
| TypePtr outputs_dtype_; | |||
| int32_t stage_id_ = 0; | |||
| StrategyPtr strategy_; | |||
| std::vector<TensorInfo> inputs_tensor_info_; | |||
| std::vector<TensorInfo> outputs_tensor_info_; | |||
| @@ -55,6 +55,7 @@ constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only"; | |||
| constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only"; | |||
| constexpr char CHECK_SET_STRATEGY_VALID_ONCE_ONLY[] = "check_set_strategy_valid_once_only"; | |||
| constexpr char STRATEGY[] = "strategy"; | |||
| constexpr char STAGE_ATTR[] = "stage"; | |||
| constexpr char GEN_STRATEGY[] = "gen_strategy"; | |||
| constexpr char REDUCE_OP_SUM[] = "sum"; | |||
| constexpr char REDUCE_OP_MAX[] = "max"; | |||
| @@ -133,9 +133,9 @@ Status ReduceMethod::InferTensorMap() { | |||
| return SUCCESS; | |||
| } | |||
| bool IsDataParallelStrategy(const Dimensions &strategy) { | |||
| bool IsDataParallelStrategy(const Dimensions &strategy, int32_t stage_id) { | |||
| CheckGlobalDeviceManager(); | |||
| size_t total_dev_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||
| size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); | |||
| if (strategy.empty()) { | |||
| MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty"; | |||
| } | |||
| @@ -145,7 +145,7 @@ bool IsDataParallelStrategy(const Dimensions &strategy) { | |||
| Status ReduceMethod::InferForwardCommunication() { | |||
| Dimensions stra = strategy_->GetInputDim().at(0); | |||
| if (cross_batch_ && IsDataParallelStrategy(stra)) { | |||
| if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) { | |||
| MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; | |||
| return SUCCESS; | |||
| } | |||
| @@ -211,7 +211,7 @@ ForwardOp CreatReduceMeanForwardOp(const std::vector<Group> &forward_group, cons | |||
| Status ReduceMeanInfo::InferForwardCommunication() { | |||
| Dimensions stra = strategy_->GetInputDim().at(0); | |||
| if (cross_batch_ && IsDataParallelStrategy(stra)) { | |||
| if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) { | |||
| MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; | |||
| return SUCCESS; | |||
| } | |||
| @@ -998,6 +998,17 @@ OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAtt | |||
| StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) { | |||
| ValueTuplePtr var = attrs[STRATEGY]->cast<ValueTuplePtr>(); | |||
| StrategyPtr strategyPtr; | |||
| std::vector<int32_t> stages = ParallelContext::GetInstance()->stage(); | |||
| auto res = attrs.find(STAGE_ATTR); | |||
| int32_t stage_id = 0; | |||
| if (res != attrs.end()) { | |||
| stage_id = GetValue<int>(res->second); | |||
| } | |||
| if (stage_id && stages.empty()) { | |||
| MS_LOG(ERROR) << "Find stage id:" << stage_id << " but the pipeline_stages is 0."; | |||
| return nullptr; | |||
| } | |||
| MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString(); | |||
| if (var == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Strategy value is nullptr"; | |||
| @@ -1016,13 +1027,13 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) { | |||
| }); | |||
| strategy.push_back(dim); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue"; | |||
| MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequence"; | |||
| } | |||
| } | |||
| if (strategy.empty()) { | |||
| MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy"; | |||
| } | |||
| strategyPtr = NewStrategy(0, strategy); | |||
| strategyPtr = NewStrategy(stage_id, strategy); | |||
| } | |||
| return strategyPtr; | |||
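Reviewer note: a sketch of the stage extraction added to ExtractStrategy above. The optional 'stage' primitive attribute selects the stage id baked into the strategy, and a nonzero stage id without configured pipeline stages is rejected. Names are illustrative:

    def extract_stage_id(attrs, stages):
        stage_id = attrs.get("stage", 0)
        if stage_id and not stages:
            # Matches the error path: a stage attr was set but pipelining is off.
            raise ValueError("Found stage id %d but pipeline_stages is 0." % stage_id)
        return stage_id

    assert extract_stage_id({"stage": 1}, [4, 4]) == 1
    assert extract_stage_id({}, []) == 0  # default stage when no attr is set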
| @@ -1420,6 +1431,30 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||
| (void)prim->SetAttrs(attrs_temp); | |||
| } | |||
| } | |||
| // This function checks whether the rank and stage are valid for the given operators | |||
| // If the rank is not valid for the given stage, we choose not to init the strategy of the operator | |||
| // For example, if stages is [4, 4], the group list is [[0,1,2,3],[4,5,6,7]] | |||
| // For stage 0, we require the rank_id to be in [0,1,2,3] | |||
| Status ValidRankCheck(int32_t global_rank, int32_t strategy_stage) { | |||
| RankList local_group_list = g_device_manager->GetDeviceListByStageId(strategy_stage); | |||
| int32_t target = global_rank; | |||
| if (std::any_of(local_group_list.begin(), local_group_list.end(), [target](int32_t a) { return a == target; })) { | |||
| return Status::SUCCESS; | |||
| } | |||
| return Status::FAILED; | |||
| } | |||
| Status ValidStageCheck(const std::vector<int32_t> &stages, int32_t strategy_stage) { | |||
| if (stages.size() > 0) { | |||
| if (strategy_stage >= 0 && strategy_stage < (int32_t)stages.size()) { | |||
| return Status::SUCCESS; | |||
| } | |||
| return Status::FAILED; | |||
| } else { | |||
| return Status::SUCCESS; | |||
| } | |||
| } | |||
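Reviewer note: a Python sketch of the two checks above, using the groups from the comment (stages [4, 4] giving the group list [[0,1,2,3],[4,5,6,7]]):

    def valid_stage_check(stages, strategy_stage):
        # With no stages configured, pipelining is off and any stage id passes.
        return not stages or 0 <= strategy_stage < len(stages)

    def valid_rank_check(global_rank, stage_group):
        return global_rank in stage_group

    groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
    assert valid_stage_check([4, 4], 1)
    assert not valid_stage_check([4, 4], 2)    # only stages 0 and 1 exist
    assert valid_rank_check(5, groups[1])
    assert not valid_rank_check(5, groups[0])  # rank 5 skips Init for stage-0 ops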
| void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| // load strategy map from checkpoint | |||
| @@ -1429,6 +1464,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | |||
| } | |||
| } | |||
| // Get the global rank and stage list after loading the strategy checkpoint | |||
| int32_t global_rank = ParallelContext::GetInstance()->global_rank(); | |||
| std::vector<int32_t> stages = ParallelContext::GetInstance()->stage(); | |||
| for (auto &node : all_nodes) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| @@ -1501,7 +1541,18 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| strategyPtr = ExtractStrategy(attrs); | |||
| } | |||
| if (strategyPtr != nullptr) { | |||
| if (operator_->Init(strategyPtr) == FAILED) { | |||
| (*operator_).set_stage_id(strategyPtr->GetInputStage()); | |||
| MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id(); | |||
| if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) { | |||
| MS_LOG(ERROR) << "Find stage " << strategyPtr->GetInputStage() << " for operator " << prim->name() | |||
| << " exceeds the global stage size " << stages.size() << '.'; | |||
| return; | |||
| } | |||
| // If the strategy is not valid for the given global rank, then we skip the Init of the strategy | |||
| if (ValidRankCheck(global_rank, (*operator_).stage_id()) == FAILED) { | |||
| MS_LOG(INFO) << "Find global exceeds the range of the stage, skip the strategy init for operator " | |||
| << prim->name(); | |||
| } else if (operator_->Init(strategyPtr) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; | |||
| } | |||
| cnode->set_user_data<OperatorInfo>(operator_); | |||
| @@ -2416,6 +2467,9 @@ Status ParallelInit() { | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| int32_t device_num = ParallelContext::GetInstance()->device_num(); | |||
| int32_t global_rank = ParallelContext::GetInstance()->global_rank(); | |||
| int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); | |||
| std::vector<int32_t> stages = ParallelContext::GetInstance()->stage(); | |||
| std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| @@ -2431,6 +2485,26 @@ Status ParallelInit() { | |||
| MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; | |||
| } | |||
| if (device_num <= 0) { | |||
| MS_LOG(ERROR) << "Invalid device num " << device_num << " , expected a positive device number"; | |||
| return FAILED; | |||
| } | |||
| if (split_stage_num > 0) { | |||
| if (device_num % split_stage_num != 0) { | |||
| MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num | |||
| << " , as we support only extract devision now"; | |||
| return FAILED; | |||
| } | |||
| for (int i = 0; i < split_stage_num; i++) { | |||
| stages.push_back(device_num / split_stage_num); | |||
| } | |||
| } else if (split_stage_num < 0) { | |||
| MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << " , expected a positive stage number"; | |||
| return FAILED; | |||
| } | |||
| ParallelContext::GetInstance()->set_stage(stages); | |||
| uint32_t world_rank_size = 0; | |||
| if (!ParallelContext::GetInstance()->device_num_is_set()) { | |||
| if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) { | |||
| @@ -2449,7 +2523,12 @@ Status ParallelInit() { | |||
| MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; | |||
| } | |||
| if (!InitDevice(device_num, global_rank, communication_backend)) { | |||
| if (!stages.empty() && parallel_mode != SEMI_AUTO_PARALLEL) { | |||
| MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL; | |||
| return FAILED; | |||
| } | |||
| if (!InitDevice(device_num, global_rank, communication_backend, stages)) { | |||
| MS_LOG(ERROR) << "Init device failed"; | |||
| return FAILED; | |||
| } | |||
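Reviewer note: pipeline_stages splits the devices evenly, so the ParallelInit changes above accept only exact division when deriving the per-stage sizes. A sketch of that derivation:

    def derive_stages(device_num, split_stage_num):
        if split_stage_num < 0:
            raise ValueError("Invalid stage num %d, expected a non-negative stage number" % split_stage_num)
        if split_stage_num == 0:
            return []  # pipelining disabled
        if device_num % split_stage_num != 0:
            raise ValueError("Device num %d can't be divided by stage num %d" % (device_num, split_stage_num))
        return [device_num // split_stage_num] * split_stage_num

    assert derive_stages(8, 2) == [4, 4]
    assert derive_stages(8, 0) == []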
| @@ -2457,6 +2536,7 @@ Status ParallelInit() { | |||
| MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank | |||
| << ", backend: " << backend << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean() | |||
| << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync(); | |||
| return SUCCESS; | |||
| } | |||
| @@ -152,6 +152,9 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| "Set strategy checkpoint save file.") | |||
| .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") | |||
| .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") | |||
| .def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num, | |||
| "Set pipeline stage split num.") | |||
| .def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.") | |||
| .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") | |||
| .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") | |||
| .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, | |||
| @@ -331,7 +331,7 @@ def _context(): | |||
| @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str, | |||
| auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, | |||
| strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, | |||
| all_reduce_fusion_config=list) | |||
| all_reduce_fusion_config=list, pipeline_stages=int) | |||
| def set_auto_parallel_context(**kwargs): | |||
| """ | |||
| Set auto parallel context. | |||
| @@ -357,6 +357,7 @@ def set_auto_parallel_context(**kwargs): | |||
| parallel_mode strategy_ckpt_load_file | |||
| all_reduce_fusion_config strategy_ckpt_save_file | |||
| full_batch | |||
| pipeline_stages | |||
| =========================== =========================== ================= | |||
| Args: | |||
| @@ -399,6 +400,10 @@ def set_auto_parallel_context(**kwargs): | |||
| the fusion is closed. | |||
| all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM | |||
| and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed. | |||
| pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how | |||
| the devices are distributed along the pipeline. The total number of devices will be divided into | |||
| 'pipeline_stages' stages. This currently can only be used when | |||
| the parallel mode semi_auto_parallel is enabled. | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -416,10 +421,10 @@ def set_auto_parallel_context(**kwargs): | |||
| >>> context.set_auto_parallel_context(full_batch=True) | |||
| >>> context.set_auto_parallel_context(enable_parallel_optimizer=False) | |||
| >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160]) | |||
| >>> context.set_auto_parallel_context(pipeline_stages=2) | |||
| """ | |||
| _set_auto_parallel_context(**kwargs) | |||
| def get_auto_parallel_context(attr_key): | |||
| """ | |||
| Gets auto parallel context attribute value according to the key. | |||
| @@ -102,6 +102,20 @@ class Primitive(Primitive_): | |||
| self.add_attr(name, value) | |||
| return self | |||
| def set_stage(self, stage): | |||
| """ | |||
| Add stage id to primitive attribute. | |||
| Note: | |||
| It is valid only in semi auto parallel. | |||
| In other parallel modes, please set it to be 0. | |||
| Args: | |||
| stage (int): The stage id for the current operation. | |||
| """ | |||
| self.add_prim_attr("stage", stage) | |||
| return self | |||
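Reviewer note: a usage sketch combining the new set_stage with shard under the pipeline_stages context flag (the operator choice and strategies are illustrative only):

    from mindspore import context
    from mindspore.ops import operations as P

    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel",
                                      pipeline_stages=2)
    matmul = P.MatMul().shard(((2, 1), (1, 2)))
    matmul.set_stage(1)  # place this operator in pipeline stage 1 (ranks 4-7)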
| def shard(self, strategy): | |||
| """ | |||
| Add strategies to primitive attribute. | |||
| @@ -95,6 +95,16 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_global_rank() | |||
| def set_pipeline_stages(self, stages): | |||
| """Set the stages of the pipeline""" | |||
| self.check_context_handle() | |||
| self._context_handle.set_pipeline_stage_split_num(stages) | |||
| def get_pipeline_stages(self): | |||
| """Get the stages of the pipeline""" | |||
| self.check_context_handle() | |||
| return self._context_handle.get_pipeline_stage_split_num() | |||
| def set_gradients_mean(self, gradients_mean): | |||
| """ | |||
| Set gradients_mean flag. | |||
| @@ -466,6 +476,7 @@ _set_auto_parallel_context_func_map = { | |||
| "gradients_mean": auto_parallel_context().set_gradients_mean, | |||
| "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync, | |||
| "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, | |||
| "pipeline_stages": auto_parallel_context().set_pipeline_stages, | |||
| "parallel_mode": auto_parallel_context().set_parallel_mode, | |||
| "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode, | |||
| "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, | |||
| @@ -482,6 +493,7 @@ _get_auto_parallel_context_func_map = { | |||
| "gradients_mean": auto_parallel_context().get_gradients_mean, | |||
| "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync, | |||
| "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, | |||
| "pipeline_stages": auto_parallel_context().get_pipeline_stages, | |||
| "parallel_mode": auto_parallel_context().get_parallel_mode, | |||
| "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode, | |||
| "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, | |||
| @@ -569,7 +581,6 @@ def _get_auto_parallel_context(attr_key): | |||
| get_func = _get_auto_parallel_context_func_map[attr_key] | |||
| return get_func() | |||
| def _reset_auto_parallel_context(): | |||
| """ | |||
| Reset auto parallel context attributes to the default values: | |||
| @@ -584,5 +595,6 @@ def _reset_auto_parallel_context(): | |||
| - strategy_ckpt_save_file: "" | |||
| - enable_parallel_optimizer: False | |||
| - auto_parallel_search_mode: dynamic_programming | |||
| - pipeline_stages: 0 | |||
| """ | |||
| auto_parallel_context().reset() | |||
| @@ -83,6 +83,39 @@ TEST_F(TestDeviceMatrix, TestCornerCaseGetAlongDim) { | |||
| EXPECT_THROW({ DeviceMatrix arr(3, dev_list, shape); }, std::runtime_error); | |||
| } | |||
| TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceOne) { | |||
| RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0}; | |||
| Shape tensor_map = {-1, 0}; | |||
| RankList rank_list; | |||
| Shape shape = {4, 2}; | |||
| DeviceMatrix arr(0, dev_list, shape); | |||
| arr.GetDevicesByTensorMap(tensor_map, &rank_list); | |||
| RankList rank_list_except = {3, 9, 100, 0}; | |||
| ASSERT_EQ(rank_list, rank_list_except); | |||
| } | |||
| TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceTwo) { | |||
| RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0}; | |||
| Shape tensor_map = {1, 0}; | |||
| RankList rank_list; | |||
| Shape shape = {4, 2}; | |||
| DeviceMatrix arr(0, dev_list, shape); | |||
| arr.GetDevicesByTensorMap(tensor_map, &rank_list); | |||
| RankList rank_list_except = {0}; | |||
| ASSERT_EQ(rank_list, rank_list_except); | |||
| } | |||
| TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapNormalOrder2D) { | |||
| RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7}; | |||
| Shape tensor_map = {-1, 0}; | |||
| RankList rank_list; | |||
| Shape shape = {4, 2}; | |||
| DeviceMatrix arr(6, dev_list, shape); | |||
| arr.GetDevicesByTensorMap(tensor_map, &rank_list); | |||
| RankList rank_list_except = {0, 2, 4, 6}; | |||
| ASSERT_EQ(rank_list, rank_list_except); | |||
| } | |||
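Reviewer note: a quick check of the normal-order expectation in the test above. With dev_list 0..7 laid out row-major as a (4, 2) matrix, tensor_map (-1, 0) groups the devices sharing rank 6's last coordinate, i.e. column 0:

    rows, cols = 4, 2
    rank = 6
    expected = [r * cols + rank % cols for r in range(rows)]
    assert expected == [0, 2, 4, 6]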
| TEST_F(TestDeviceMatrix, TestCornerCase2GetAlongDim) { | |||
| // Rank is out of range | |||
| RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7}; | |||
| @@ -0,0 +1,89 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| grad_all = C.GradOperation(get_all=True) | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x, y): | |||
| predict = self.network(x, y) | |||
| return self.loss(predict) | |||
| class GradWrap(nn.Cell): | |||
| def __init__(self, network): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x, y): | |||
| return grad_all(self.network)(x, y) | |||
| class Net(nn.Cell): | |||
| def __init__(self, axis=0, stage1=0, stage2=0, strategy1=None, strategy2=None, shape=None, target=""): | |||
| super().__init__() | |||
| if shape is None: | |||
| shape = [64, 64] | |||
| self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target) | |||
| self.mul = P.Mul().shard(strategy2) | |||
| self.index = Tensor(np.ones(shape), dtype=ms.int32) | |||
| self.gatherv2.set_stage(stage1) | |||
| self.mul.set_stage(stage2) | |||
| self.axis = axis | |||
| def construct(self, x, y): | |||
| out = self.gatherv2(x, self.index, self.axis) | |||
| out = self.mul(out, y) | |||
| return out | |||
| def test_gatherv2_semi_samestage1(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, \ | |||
| parallel_mode="semi_auto_parallel", pipeline_stages=2) | |||
| strategy1 = ((1, 2), (1, 1)) | |||
| strategy2 = ((2, 1, 1), (2, 1, 1)) | |||
| net = GradWrap(NetWithLoss(Net(0, 0, 0, strategy1, strategy2))) | |||
| net.set_auto_parallel() | |||
| x = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) | |||
| _executor.compile(net, x, y) | |||
| def test_gatherv2_semi_samestage2(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=5, \ | |||
| parallel_mode="semi_auto_parallel", pipeline_stages=2) | |||
| strategy1 = ((1, 2), (1, 1)) | |||
| strategy2 = ((2, 1, 1), (2, 1, 1)) | |||
| net = GradWrap(NetWithLoss(Net(0, 1, 1, strategy1, strategy2))) | |||
| net.set_auto_parallel() | |||
| x = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) | |||
| _executor.compile(net, x, y) | |||
| @@ -81,6 +81,11 @@ def test_set_auto_parallel_context(): | |||
| assert context.get_auto_parallel_context("enable_parallel_optimizer") | |||
| assert not auto_parallel_context().get_all_reduce_fusion_split_indices() | |||
| def test_pipeline_parallel_context(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=4, | |||
| parallel_mode="semi_auto_parallel", pipeline_stages=2) | |||
| stage = auto_parallel_context().get_pipeline_stages() | |||
| assert stage == 2 | |||
| def test_reset_auto_parallel_context(): | |||
| context.reset_auto_parallel_context() | |||
| @@ -92,6 +97,8 @@ def test_reset_auto_parallel_context(): | |||
| parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") | |||
| device_num_is_set = auto_parallel_context().get_device_num_is_set() | |||
| parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set() | |||
| stage = auto_parallel_context().get_pipeline_stages() | |||
| assert device_num == 1 | |||
| assert global_rank == 0 | |||
| assert not gradients_mean | |||
| @@ -100,3 +107,4 @@ def test_reset_auto_parallel_context(): | |||
| assert not parameter_broadcast | |||
| assert not device_num_is_set | |||
| assert not parameter_broadcast_is_set | |||
| assert not stage | |||