@@ -113,7 +113,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
                                    size_t type_length, TypePtr type, CostPtr *cost) {
   MS_EXCEPTION_IF_NULL(prev_op_);
   MS_EXCEPTION_IF_NULL(cost);
-  RankList dev_list = prev_op_->global_device_list();
+  RankList dev_list = prev_op_->stage_device_list();
   TensorRedistribution tensor_redistribution(false);
   // Init TensorRedistribution
@@ -140,7 +140,12 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
                            const std::string &backend) {
   if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
     MS_LOG(ERROR) << "Invalid backend: " << backend;
-    return Status::FAILED;
+    return FAILED;
   }
+  if (stage_map.empty() || devices.empty()) {
+    MS_LOG(ERROR) << "The size of stage_map and devices must be positive";
+    return FAILED;
+  }
   for (auto &dev : devices) {
@@ -153,11 +158,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
     int64_t num_device = stage;
     if (num_device > MAX_DEVICE_NUM) {
       MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
-      return Status::FAILED;
+      return FAILED;
     }
     if (num_device <= 0) {
       MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
-      return Status::FAILED;
+      return FAILED;
     }
     RankList curr_dev_list;
     for (int64_t i = 0; i < num_device; ++i) {
@@ -170,10 +175,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
   std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
   device_ = dev;
-  set_global_rank(global_device_rank);
-  set_stage_num(static_cast<const int64_t>(stage_map.size()));
-  int64_t stage_id = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
-  set_stage_id(stage_id);
+  global_rank_ = global_device_rank;
+  stage_num_ = static_cast<const int64_t>(stage_map.size());
+  stage_id_ = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
+  rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<const int64_t>(devices.size()) / stage_num_);
+  stage_device_num_ = static_cast<const int64_t>(devices.size()) / stage_num_;
   backend_ = backend;
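Note: the new fields assigned above are all integer arithmetic on the global rank, assuming the device list divides evenly into stages. A minimal standalone sketch of that mapping (hypothetical names, not part of the patch):

    #include <cstdint>

    struct StageInfo {
      int64_t stage_device_num;     // devices per stage: device_num / stage_num
      int64_t stage_id;             // the stage a global rank falls into
      int64_t rank_index_in_stage;  // the rank's position inside its stage
    };

    // Mirrors the assignments in DeviceManager::Init above; note that
    // global_rank - stage_id * stage_device_num == global_rank % stage_device_num.
    StageInfo ComputeStageInfo(int64_t global_rank, int64_t device_num, int64_t stage_num) {
      StageInfo info;
      info.stage_device_num = device_num / stage_num;
      info.stage_id = global_rank / info.stage_device_num;
      info.rank_index_in_stage = global_rank % info.stage_device_num;
      return info;
    }

For example, 8 devices in 2 stages give a stage_device_num of 4; global rank 5 then maps to stage_id 1 and rank_index_in_stage 1.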
@@ -185,10 +191,13 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
     gm_.set_world_group(UNDEFINED_WORLD_GROUP);
   }
   MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
-               << ", the backend: " << backend << ", the stage num: " << stage_num() << ", the stage id: " << stage_id;
-  return Status::SUCCESS;
+               << ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
+               << ", the rank index in stage is: " << rank_index_in_stage_;
+  return SUCCESS;
 }
+RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }
 RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
   if (LongToSize(stage_id) >= stage_devices_.size())
     MS_LOG(ERROR) << "the 'stage_id': " << stage_id
@@ -204,49 +213,6 @@ RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
   return res;
 }
-RankList DeviceManager::global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const {
-  RankList res;
-  if (split_num <= 0) {
-    return res;
-  }
-  if (LongToSize(stage_id) >= stage_devices_.size()) {
-    MS_LOG(ERROR) << "the 'stage_id': " << stage_id
-                  << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
-    return res;
-  }
-  RankList global_list = GetDeviceListByStageId(stage_id);
-  if (global_list.size() % LongToSize(split_num)) {
-    MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") can not be divisible by split num: " << stage_id;
-    return res;
-  }
-  std::vector<int64_t> dev_list;
-  (void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list));
-  size_t index = 0;
-  size_t slice_size = dev_list.size() / LongToSize(split_num);
-  for (int64_t i = 0; i < split_num; ++i) {
-    bool found = false;
-    index = slice_size * LongToSize(i);
-    for (size_t j = 0; j < slice_size; ++j) {
-      if (dev_list[index + j] == rank) {
-        found = true;
-        break;
-      }
-    }
-    if (found) {
-      break;
-    }
-  }
-  for (size_t k = 0; k < slice_size; ++k) {
-    res.push_back(dev_list[index + k]);
-  }
-  return res;
-}
 Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }
 std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) {
@@ -57,14 +57,14 @@ std::string HashName(const std::string &rank_list_name);
 class DeviceManager {
   // This class is used to manage the abstract devices, including group-related and stage-related management.
  public:
-  DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(1), stage_id_(0) { gm_ = GroupManager(); }
+  DeviceManager() { gm_ = GroupManager(); }
   ~DeviceManager() = default;
   Status Init(const RankList &devices, int64_t local_device, const RankList &stage_map, const std::string &backend);
   static DeviceManager &GetInstance();
   RankList GetDeviceListByStageId(int64_t stage_id) const;
-  RankList global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const;
+  RankList GetDeviceListInThisStage() const;
   Device CreateNewDeviceByRank(int64_t rank) const;
   std::vector<Device> CreateDeviceListByRankList(RankList ranks);
@@ -74,17 +74,11 @@ class DeviceManager {
   Group CreateGroup(const RankList &dev_ranks);
   size_t DeviceNum() const { return devices_.size(); }
   int64_t stage_num() const { return stage_num_; }
-  void set_stage_num(int64_t num) { stage_num_ = num; }
   int64_t stage_id() const { return stage_id_; }
-  void set_stage_id(int64_t id) { stage_id_ = id; }
-  std::string backend() const { return backend_; }
+  int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
   int64_t global_rank() const { return global_rank_; }
-  void set_global_rank(int64_t global_rank) { global_rank_ = global_rank; }
+  std::string backend() const { return backend_; }
   void Clear();
   std::string world_group() const { return gm_.world_group(); }
@@ -102,10 +96,11 @@ class DeviceManager {
   std::map<std::string, std::string> rank_to_group_;  // the key is rank list, value is hash name
   std::map<std::string, std::string> group_to_rank_;  // the key is hash name, value is rank list
-  int64_t local_rank_;
-  int64_t global_rank_;
-  int64_t stage_num_;
-  int64_t stage_id_;
+  int64_t global_rank_ = 0;          // the real rank in all devices
+  int64_t stage_num_ = 0;            // the stage num
+  int64_t stage_id_ = 0;             // the stage id of the global_rank_
+  int64_t rank_index_in_stage_ = 0;  // the index of this rank in its stage
+  int64_t stage_device_num_ = 0;     // the device num of one stage
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -232,7 +232,7 @@ Status GatherV2Info::InferTensorSubOps() {
     MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
   }
   int64_t mod_p = mod_n * dev_matrix_shape_.at(axis_);
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   int64_t mod_rank = rank % mod_p;
   mod_rank = static_cast<int64_t>(mod_rank / mod_n);
   if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
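Note: the mod arithmetic above reads as extracting the rank's coordinate along dimension axis_ of a row-major device matrix, assuming mod_n is the product of the dimensions after axis_ (the stride), so that mod_p = mod_n * dev_matrix_shape_[axis_]. A minimal sketch under that assumption (hypothetical helper, not in the patch):

    #include <cstdint>
    #include <vector>

    // Coordinate of `rank` along `axis` in a row-major device matrix `shape`.
    // Equivalent to the (rank % mod_p) / mod_n computation above.
    int64_t CoordAlongAxis(int64_t rank, const std::vector<int64_t> &shape, size_t axis) {
      int64_t stride = 1;  // plays the role of mod_n
      for (size_t i = axis + 1; i < shape.size(); ++i) {
        stride *= shape[i];
      }
      return (rank % (stride * shape[axis])) / stride;  // (rank % mod_p) / mod_n
    }

For shape {2, 4} and axis 0, rank 5 sits at coordinate (5 % 8) / 4 = 1.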
@@ -451,7 +451,7 @@ Status GatherV2PInfo::InferTensorInfo() {
   Shape input_shape = inputs_shape_.at(0);
   Shape input_index_shape = inputs_shape_.at(1);
   Shape output_shape = outputs_shape_.at(0);
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   // infer tensor layout
   TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
   if (manual_split_) {
@@ -481,7 +481,7 @@ Status GatherV2PInfo::InferTensorInfo() {
 Status GatherV2PInfo::InferBias() {
   CheckGlobalDeviceManager();
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   auto input_shape = inputs_shape_.at(0);
   auto params_strategy = strategy_->GetInputDim().at(0);
   // axis don't split
@@ -513,7 +513,7 @@ Status GatherV2PInfo::InferBias() {
 Status GatherV2PInfo::InferOffset() {
   CheckGlobalDeviceManager();
-  size_t rank = g_device_manager->global_rank();
+  size_t rank = g_device_manager->rank_index_in_stage();
   MS_EXCEPTION_IF_NULL(strategy_);
   auto param_strategy = strategy_->GetInputDim()[0];
@@ -134,7 +134,7 @@ Status OneHotInfo::InferTensorInfo() {
 Status OneHotInfo::ExtractInputInfo() {
   CheckGlobalDeviceManager();
-  rank_ = g_device_manager->global_rank();
+  rank_ = g_device_manager->rank_index_in_stage();
   mod_rank_ = rank_ % old_dev_matrix_back_;
   if (!cnode_) {
     MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr";
@@ -116,7 +116,6 @@ void OperatorInfo::ResetQueueMember() {
   replace_op_.clear();
   replace_op_info_.clear();
   virtual_div_op_.clear();
-  global_device_list_.clear();
 }
 Status OperatorInfo::InferAttrs() {
@@ -131,14 +130,8 @@ Status OperatorInfo::InferAttrs() {
   return SUCCESS;
 }
-void OperatorInfo::SetDeviceListByStrategy() {
-  int64_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
-}
 Status OperatorInfo::InferRepeatedCalcInfo() {
-  int64_t g_dev_list_size = SizeToLong(global_device_list_.size());
+  int64_t g_dev_list_size = stage_device_size_;
   int64_t dev_matrix_size =
       std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
   if (dev_matrix_size == 0) {
@@ -155,12 +148,6 @@ Status OperatorInfo::InferRepeatedCalcInfo() {
                   << dev_matrix_size;
     return FAILED;
   }
-  CheckGlobalDeviceManager();
-  int64_t rank = g_device_manager->global_rank();
-  int64_t stage = strategy_->GetInputStage();
-  local_device_list_ = g_device_manager->global_device_list(stage, rank, repeated_calc_num_);
   return SUCCESS;
 }
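Note: the error branch above (its first line is elided here) appears to guard divisibility, since it prints dev_matrix_size; on that reading, repeated_calc_num_ works out to stage_device_size_ / dev_matrix_size, i.e. how many times the stage's devices over-cover the device matrix. For example, 8 devices in a stage against a 2 x 2 device matrix would give a repeated calculation number of 8 / 4 = 2.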
@@ -331,7 +318,7 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
   }
   CheckGlobalDeviceManager();
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
     return FAILED;
@@ -354,7 +341,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
   }
   CheckGlobalDeviceManager();
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
     return FAILED;
@@ -469,7 +456,6 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
   ResetQueueMember();
   strategy_ = strategy;
-  SetDeviceListByStrategy();
   if (InferDevMatrixShape() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
@@ -526,7 +512,6 @@ Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &str
   ResetQueueMember();
   strategy_ = strategy;
-  SetDeviceListByStrategy();
   if (InferDevMatrixShape() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
@@ -1325,7 +1310,7 @@ Status OperatorInfo::InferAsLossDivisor() {
   }
   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToLong(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
@@ -64,6 +64,8 @@ class OperatorInfo {
     std::vector<bool> not_parameteter(inputs_shape_.size(), false);
     is_parameter_ = not_parameteter;
     refkey_parameter_name_ = "";
+    stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
+    stage_device_size_ = SizeToLong(stage_device_list_.size());
   }
   virtual ~OperatorInfo() = default;
@@ -119,7 +121,7 @@ class OperatorInfo {
   std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
   const std::string &name() const { return name_; }
   void set_name(const std::string &name) { name_ = name; }
-  RankList global_device_list() const { return global_device_list_; }
+  RankList stage_device_list() const { return stage_device_list_; }
   void AddSuccEdge(const std::shared_ptr<Edge> &e) { succ_edges_.push_back(e); }
   void AddPrevEdge(const std::shared_ptr<Edge> &e) { prev_edges_.push_back(e); }
@@ -187,7 +189,6 @@ class OperatorInfo {
   virtual Status InferTensorInfo() = 0;
   virtual Status InferDevMatrixShape() = 0;
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
-  void SetDeviceListByStrategy();
   void SetRepeatedCalcDevMatrix();
   void ResetTensorMapIfRepeatedCalc();
   Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
@@ -231,8 +232,8 @@ class OperatorInfo {
   ReplaceGraphPtr replace_graph_;
   MirrorOps mirror_ops_;
   VirtualDivOp virtual_div_op_;
-  RankList global_device_list_;  // the size of global_device_list equal to the size of stageID
-  RankList local_device_list_;   // the size equal to global_device_list_.size() / repeated_calc_num_
+  RankList stage_device_list_;  // the device list in this stage
+  int64_t stage_device_size_ = 0;
   bool infer_attrs_completed_ = false;
   bool is_auto_parallel_ = false;  // false: semi_auto_parallel; true: auto_parallel
@@ -136,7 +136,7 @@ Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) {
 Status RangeInfo::InferNewAttr() {
   CheckGlobalDeviceManager();
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   // If repeated calculation and repeated num as the last dimension of dev-matrix,
   // the dev-matrix is [split_num_, repeated_calc_num_], so from rank 0 to rank repeated_calc_num_
@@ -531,7 +531,7 @@ Status ArgMaxWithValueInfo::InferAsLossDivisor() {
   MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer";
   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToLong(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor.";
     return SUCCESS;
   }
@@ -172,7 +172,7 @@ Status ReLUV2Info::InferAsLossDivisor() {
   }
   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToInt(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
@@ -113,7 +113,7 @@ Status ReshapeInfo::GetParameterInput() {
 }
 Status ReshapeInfo::ComputeReplaceOp() {
-  RankList dev_list = global_device_list();
+  RankList dev_list = stage_device_list();
   TensorRedistribution tensor_redistribution(!is_generating_costs_, true);
   if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) {
     if (is_generating_costs_) {
@@ -289,13 +289,7 @@ void ReshapeInfo::InferTensorInfoByLayout() {
 Status ReshapeInfo::GetAttrs() { return GetParameterInput(); }
 void ReshapeInfo::device_number(const StrategyPtr &strategy) {
-  int64_t stage = 0;
-  if (strategy != nullptr) {
-    stage = strategy->GetInputStage();
-  }
-  CheckGlobalDeviceManager();
-  global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
-  dev_num_ = SizeToLong(global_device_list_.size());
+  dev_num_ = stage_device_size_;
   MS_ASSERT(dev_num_ > 0);
 }
@@ -260,7 +260,7 @@ Status SplitInfo::InferAsLossDivisor() {
   }
   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToInt(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
@@ -325,7 +325,7 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
   if (next_distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed";
   }
-  RankList dev_list = distribute_operator->global_device_list();
+  RankList dev_list = distribute_operator->stage_device_list();
   std::string next_prim_name = GetValueNode<PrimitivePtr>(next_node->input(0))->name();
   MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name;
   MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
@@ -161,6 +161,8 @@ class EmbeddingLookup(Cell):
     Examples:
         >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
        >>> out = nn.EmbeddingLookup(4,2)(input_indices)
+        >>> out.shape
+        (2, 2, 2)
     """
     BATCH_SLICE = "batch_slice"
     FIELD_SLICE = "field_slice"
@@ -135,6 +135,8 @@ TEST_F(TestDeviceManager, test_StageID) {
   ASSERT_EQ(dm_.DeviceNum(), 4);
   ASSERT_EQ(dm_.stage_num(), 2);
   ASSERT_EQ(dm_.stage_id(), 1);
+  ASSERT_EQ(dm_.rank_index_in_stage(), 0);
+  ASSERT_EQ(dm_.GetDeviceListInThisStage().back(), 3);
   RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
   RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
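Note: the two new assertions are consistent with the existing ones: 4 devices in 2 stages means 2 devices per stage, so stage 1 owns ranks {2, 3}; a global rank of 2 therefore has rank index 2 - 1 * 2 = 0 within its stage, and the last entry of GetDeviceListInThisStage() is 3.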
@@ -171,7 +171,7 @@ TEST_F(TestLogSoftmaxInfo, GetDeviceList1) {
   StrategyPtr strategy = NewStrategy(0, inputs);
   log_softmax->Init(strategy);
-  RankList dev_list = log_softmax->global_device_list();
+  RankList dev_list = log_softmax->stage_device_list();
   ASSERT_EQ(dev_list.size(), 128);
 }