@@ -113,7 +113,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
                                    size_t type_length, TypePtr type, CostPtr *cost) {
   MS_EXCEPTION_IF_NULL(prev_op_);
   MS_EXCEPTION_IF_NULL(cost);
-  RankList dev_list = prev_op_->global_device_list();
+  RankList dev_list = prev_op_->stage_device_list();
   TensorRedistribution tensor_redistribution(false);
   // Init TensorRedistribution
@@ -140,7 +140,12 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
                            const std::string &backend) {
   if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
     MS_LOG(ERROR) << "Invalid backend: " << backend;
-    return Status::FAILED;
+    return FAILED;
+  }
+  if (stage_map.empty() || devices.empty()) {
+    MS_LOG(ERROR) << "The size of stage_map and devices must be positive";
+    return FAILED;
   }

   for (auto &dev : devices) {
@@ -153,11 +158,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
   int64_t num_device = stage;
   if (num_device > MAX_DEVICE_NUM) {
     MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
-    return Status::FAILED;
+    return FAILED;
   }
   if (num_device <= 0) {
     MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
-    return Status::FAILED;
+    return FAILED;
   }
   RankList curr_dev_list;
   for (int64_t i = 0; i < num_device; ++i) {
@@ -170,10 +175,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
   std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
   device_ = dev;

-  set_global_rank(global_device_rank);
-  set_stage_num(static_cast<const int64_t>(stage_map.size()));
-  int64_t stage_id = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
-  set_stage_id(stage_id);
+  global_rank_ = global_device_rank;
+  stage_num_ = static_cast<const int64_t>(stage_map.size());
+  stage_id_ = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
+  rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<const int64_t>(devices.size()) / stage_num_);
+  stage_device_num_ = static_cast<const int64_t>(devices.size()) / stage_num_;

   backend_ = backend;
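Note on the new member assignments above: they encode an even partition of the global rank space into stages. A minimal standalone sketch of the same arithmetic, using illustrative values (8 devices, 2 stages, global rank 5) rather than anything taken from the patch:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Illustrative values (not from the patch): 8 devices split evenly into 2 stages.
      const int64_t device_num = 8;
      const int64_t stage_num = 2;
      const int64_t global_rank = 5;

      const int64_t stage_device_num = device_num / stage_num;  // 4 devices per stage
      const int64_t stage_id = global_rank / stage_device_num;  // rank 5 lands in stage 1
      const int64_t rank_index_in_stage = global_rank - stage_id * stage_device_num;

      assert(stage_device_num == 4);
      assert(stage_id == 1);
      assert(rank_index_in_stage == 1);  // rank 5 is the second device of stage 1
      return 0;
    }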
@@ -185,10 +191,13 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
     gm_.set_world_group(UNDEFINED_WORLD_GROUP);
   }
   MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
-               << ", the backend: " << backend << ", the stage num: " << stage_num() << ", the stage id: " << stage_id;
-  return Status::SUCCESS;
+               << ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
+               << ", the rank index in stage is: " << rank_index_in_stage_;
+  return SUCCESS;
 }
+
+RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }

 RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
   if (LongToSize(stage_id) >= stage_devices_.size())
     MS_LOG(ERROR) << "the 'stage_id': " << stage_id
@@ -204,49 +213,6 @@ RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
   return res;
 }

-RankList DeviceManager::global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const {
-  RankList res;
-  if (split_num <= 0) {
-    return res;
-  }
-  if (LongToSize(stage_id) >= stage_devices_.size()) {
-    MS_LOG(ERROR) << "the 'stage_id': " << stage_id
-                  << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
-    return res;
-  }
-
-  RankList global_list = GetDeviceListByStageId(stage_id);
-  if (global_list.size() % LongToSize(split_num)) {
-    MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") can not be divisible by split num: " << stage_id;
-    return res;
-  }
-
-  std::vector<int64_t> dev_list;
-  (void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list));
-
-  size_t index = 0;
-  size_t slice_size = dev_list.size() / LongToSize(split_num);
-  for (int64_t i = 0; i < split_num; ++i) {
-    bool found = false;
-    index = slice_size * LongToSize(i);
-    for (size_t j = 0; j < slice_size; ++j) {
-      if (dev_list[index + j] == rank) {
-        found = true;
-        break;
-      }
-    }
-    if (found) {
-      break;
-    }
-  }
-
-  for (size_t k = 0; k < slice_size; ++k) {
-    res.push_back(dev_list[index + k]);
-  }
-  return res;
-}
-
 Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }

 std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) {
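For readers tracking what the deletion above drops: the removed helper partitioned a stage's device list into split_num equal contiguous slices and returned the slice containing rank. A condensed, hypothetical free-function restatement of that behavior (SliceContainingRank and its parameter names are illustrative, not part of the codebase):

    #include <cstdint>
    #include <vector>

    // Condensed restatement of the removed logic: split stage_list into
    // split_num equal contiguous slices and return the slice containing rank.
    // (The original fell back to the last slice when rank was not found;
    // this sketch returns an empty list instead.)
    std::vector<int64_t> SliceContainingRank(const std::vector<int64_t> &stage_list, int64_t rank, int64_t split_num) {
      if (split_num <= 0 || stage_list.size() % static_cast<size_t>(split_num) != 0) {
        return {};
      }
      const size_t slice_size = stage_list.size() / static_cast<size_t>(split_num);
      for (size_t index = 0; index < stage_list.size(); index += slice_size) {
        for (size_t j = 0; j < slice_size; ++j) {
          if (stage_list[index + j] == rank) {
            return std::vector<int64_t>(stage_list.begin() + index, stage_list.begin() + index + slice_size);
          }
        }
      }
      return {};
    }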
@@ -57,14 +57,14 @@ std::string HashName(const std::string &rank_list_name);
 class DeviceManager {
   // This class is used to manage the abstract devices, including group-related and stage-related management.
  public:
-  DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(1), stage_id_(0) { gm_ = GroupManager(); }
+  DeviceManager() { gm_ = GroupManager(); }
   ~DeviceManager() = default;

   Status Init(const RankList &devices, int64_t local_device, const RankList &stage_map, const std::string &backend);

   static DeviceManager &GetInstance();
   RankList GetDeviceListByStageId(int64_t stage_id) const;
-  RankList global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const;
+  RankList GetDeviceListInThisStage() const;
   Device CreateNewDeviceByRank(int64_t rank) const;
   std::vector<Device> CreateDeviceListByRankList(RankList ranks);
@@ -74,17 +74,11 @@ class DeviceManager {
   Group CreateGroup(const RankList &dev_ranks);
   size_t DeviceNum() const { return devices_.size(); }

   int64_t stage_num() const { return stage_num_; }
-  void set_stage_num(int64_t num) { stage_num_ = num; }
   int64_t stage_id() const { return stage_id_; }
-  void set_stage_id(int64_t id) { stage_id_ = id; }
-  std::string backend() const { return backend_; }
+  int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
   int64_t global_rank() const { return global_rank_; }
-  void set_global_rank(int64_t global_rank) { global_rank_ = global_rank; }
+  std::string backend() const { return backend_; }

   void Clear();
   std::string world_group() const { return gm_.world_group(); }
@@ -102,10 +96,11 @@ class DeviceManager {
   std::map<std::string, std::string> rank_to_group_;  // the key is rank list, value is hash name
   std::map<std::string, std::string> group_to_rank_;  // the key is hash name, value is rank list

-  int64_t local_rank_;
-  int64_t global_rank_;
-  int64_t stage_num_;
-  int64_t stage_id_;
+  int64_t global_rank_ = 0;          // the real rank in all devices
+  int64_t stage_num_ = 0;            // the stage num
+  int64_t stage_id_ = 0;             // the stage id of the global_rank_
+  int64_t rank_index_in_stage_ = 0;  // the index of this rank in its stage
+  int64_t stage_device_num_ = 0;     // the device num of one stage
 };
 }  // namespace parallel
 }  // namespace mindspore
@@ -232,7 +232,7 @@ Status GatherV2Info::InferTensorSubOps() {
     MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
   }
   int64_t mod_p = mod_n * dev_matrix_shape_.at(axis_);
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   int64_t mod_rank = rank % mod_p;
   mod_rank = static_cast<int64_t>(mod_rank / mod_n);
   if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
@@ -451,7 +451,7 @@ Status GatherV2PInfo::InferTensorInfo() {
   Shape input_shape = inputs_shape_.at(0);
   Shape input_index_shape = inputs_shape_.at(1);
   Shape output_shape = outputs_shape_.at(0);
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   // infer tensor layout
   TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
   if (manual_split_) {
@@ -481,7 +481,7 @@ Status GatherV2PInfo::InferTensorInfo() {
 Status GatherV2PInfo::InferBias() {
   CheckGlobalDeviceManager();
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();
   auto input_shape = inputs_shape_.at(0);
   auto params_strategy = strategy_->GetInputDim().at(0);
   // axis don't split
@@ -513,7 +513,7 @@ Status GatherV2PInfo::InferBias() {
 Status GatherV2PInfo::InferOffset() {
   CheckGlobalDeviceManager();
-  size_t rank = g_device_manager->global_rank();
+  size_t rank = g_device_manager->rank_index_in_stage();
   MS_EXCEPTION_IF_NULL(strategy_);
   auto param_strategy = strategy_->GetInputDim()[0];
@@ -134,7 +134,7 @@ Status OneHotInfo::InferTensorInfo() {
 Status OneHotInfo::ExtractInputInfo() {
   CheckGlobalDeviceManager();
-  rank_ = g_device_manager->global_rank();
+  rank_ = g_device_manager->rank_index_in_stage();
   mod_rank_ = rank_ % old_dev_matrix_back_;
   if (!cnode_) {
     MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr";
@@ -116,7 +116,6 @@ void OperatorInfo::ResetQueueMember() {
   replace_op_.clear();
   replace_op_info_.clear();
   virtual_div_op_.clear();
-  global_device_list_.clear();
 }

 Status OperatorInfo::InferAttrs() {
@@ -131,14 +130,8 @@ Status OperatorInfo::InferAttrs() {
   return SUCCESS;
 }

-void OperatorInfo::SetDeviceListByStrategy() {
-  int64_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
-}
-
 Status OperatorInfo::InferRepeatedCalcInfo() {
-  int64_t g_dev_list_size = SizeToLong(global_device_list_.size());
+  int64_t g_dev_list_size = stage_device_size_;
   int64_t dev_matrix_size =
       std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
   if (dev_matrix_size == 0) {
| @@ -155,12 +148,6 @@ Status OperatorInfo::InferRepeatedCalcInfo() { | |||||
| << dev_matrix_size; | << dev_matrix_size; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| CheckGlobalDeviceManager(); | |||||
| int64_t rank = g_device_manager->global_rank(); | |||||
| int64_t stage = strategy_->GetInputStage(); | |||||
| local_device_list_ = g_device_manager->global_device_list(stage, rank, repeated_calc_num_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
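With SetDeviceListByStrategy() and local_device_list_ gone, InferRepeatedCalcInfo() derives everything from stage_device_size_. Assuming repeated_calc_num_ is the quotient guarded by the divisibility check above (implied by the surrounding code, not shown in this hunk), the arithmetic on illustrative numbers:

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
      // Illustrative numbers (not from the patch): a stage of 8 devices
      // and a device matrix of shape [2, 2].
      const int64_t stage_device_size = 8;
      const std::vector<int64_t> dev_matrix_shape = {2, 2};
      const int64_t dev_matrix_size = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(),
                                                      static_cast<int64_t>(1), std::multiplies<int64_t>());
      assert(dev_matrix_size != 0 && stage_device_size % dev_matrix_size == 0);
      const int64_t repeated_calc_num = stage_device_size / dev_matrix_size;  // 8 / (2 * 2) = 2
      assert(repeated_calc_num == 2);
      return 0;
    }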
@@ -331,7 +318,7 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
   }
   CheckGlobalDeviceManager();
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
     return FAILED;
@@ -354,7 +341,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
   }
   CheckGlobalDeviceManager();
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
     return FAILED;
@@ -469,7 +456,6 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
   ResetQueueMember();

   strategy_ = strategy;
-  SetDeviceListByStrategy();

   if (InferDevMatrixShape() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
@@ -526,7 +512,6 @@ Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &str
   ResetQueueMember();

   strategy_ = strategy;
-  SetDeviceListByStrategy();

   if (InferDevMatrixShape() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
@@ -1325,7 +1310,7 @@ Status OperatorInfo::InferAsLossDivisor() {
   }

   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToLong(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
@@ -64,6 +64,8 @@ class OperatorInfo {
     std::vector<bool> not_parameteter(inputs_shape_.size(), false);
     is_parameter_ = not_parameteter;
     refkey_parameter_name_ = "";
+    stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
+    stage_device_size_ = SizeToLong(stage_device_list_.size());
   }
   virtual ~OperatorInfo() = default;
@@ -119,7 +121,7 @@ class OperatorInfo {
   std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
   const std::string &name() const { return name_; }
   void set_name(const std::string &name) { name_ = name; }
-  RankList global_device_list() const { return global_device_list_; }
+  RankList stage_device_list() const { return stage_device_list_; }

   void AddSuccEdge(const std::shared_ptr<Edge> &e) { succ_edges_.push_back(e); }
   void AddPrevEdge(const std::shared_ptr<Edge> &e) { prev_edges_.push_back(e); }
@@ -187,7 +189,6 @@ class OperatorInfo {
   virtual Status InferTensorInfo() = 0;
   virtual Status InferDevMatrixShape() = 0;
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
-  void SetDeviceListByStrategy();
   void SetRepeatedCalcDevMatrix();
   void ResetTensorMapIfRepeatedCalc();
   Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
@@ -231,8 +232,8 @@ class OperatorInfo {
   ReplaceGraphPtr replace_graph_;
   MirrorOps mirror_ops_;
   VirtualDivOp virtual_div_op_;
-  RankList global_device_list_;  // the size of global_device_list equal to the size of stageID
-  RankList local_device_list_;   // the size equal to global_device_list_.size() / repeated_calc_num_
+  RankList stage_device_list_;  // the device list in this stage
+  int64_t stage_device_size_ = 0;
   bool infer_attrs_completed_ = false;

   bool is_auto_parallel_ = false;  // false: semi_auto_parallel; true: auto_parallel
@@ -136,7 +136,7 @@ Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) {
 Status RangeInfo::InferNewAttr() {
   CheckGlobalDeviceManager();
-  int64_t rank = g_device_manager->global_rank();
+  int64_t rank = g_device_manager->rank_index_in_stage();

   // If repeated calculation and repeated num as the last dimension of dev-matrix,
   // the dev-matrix is [split_num_, repeated_calc_num_], so from rank 0 to rank repeated_calc_num_
@@ -531,7 +531,7 @@ Status ArgMaxWithValueInfo::InferAsLossDivisor() {
   MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer";
   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToLong(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor.";
     return SUCCESS;
   }
@@ -172,7 +172,7 @@ Status ReLUV2Info::InferAsLossDivisor() {
   }

   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToInt(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
@@ -113,7 +113,7 @@ Status ReshapeInfo::GetParameterInput() {
 }

 Status ReshapeInfo::ComputeReplaceOp() {
-  RankList dev_list = global_device_list();
+  RankList dev_list = stage_device_list();
   TensorRedistribution tensor_redistribution(!is_generating_costs_, true);
   if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) {
     if (is_generating_costs_) {
@@ -289,13 +289,7 @@ void ReshapeInfo::InferTensorInfoByLayout() {
 Status ReshapeInfo::GetAttrs() { return GetParameterInput(); }

 void ReshapeInfo::device_number(const StrategyPtr &strategy) {
-  int64_t stage = 0;
-  if (strategy != nullptr) {
-    stage = strategy->GetInputStage();
-  }
-  CheckGlobalDeviceManager();
-  global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
-  dev_num_ = SizeToLong(global_device_list_.size());
+  dev_num_ = stage_device_size_;
   MS_ASSERT(dev_num_ > 0);
 }
@@ -260,7 +260,7 @@ Status SplitInfo::InferAsLossDivisor() {
   }

   if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToInt(global_device_list_.size());
+    as_loss_divisor_ = stage_device_size_;
     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
     return SUCCESS;
   }
@@ -325,7 +325,7 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
   if (next_distribute_operator == nullptr) {
     MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed";
   }
-  RankList dev_list = distribute_operator->global_device_list();
+  RankList dev_list = distribute_operator->stage_device_list();
   std::string next_prim_name = GetValueNode<PrimitivePtr>(next_node->input(0))->name();
   MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name;
   MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
@@ -161,6 +161,8 @@ class EmbeddingLookup(Cell):
     Examples:
         >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
         >>> out = nn.EmbeddingLookup(4,2)(input_indices)
+        >>> out.shape
+        (2, 2, 2)
     """
     BATCH_SLICE = "batch_slice"
     FIELD_SLICE = "field_slice"
@@ -135,6 +135,8 @@ TEST_F(TestDeviceManager, test_StageID) {
   ASSERT_EQ(dm_.DeviceNum(), 4);
   ASSERT_EQ(dm_.stage_num(), 2);
   ASSERT_EQ(dm_.stage_id(), 1);
+  ASSERT_EQ(dm_.rank_index_in_stage(), 0);
+  ASSERT_EQ(dm_.GetDeviceListInThisStage().back(), 3);

   RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
   RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
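The two new assertions follow from the stage arithmetic in DeviceManager::Init, assuming the fixture initializes four devices {0, 1, 2, 3} in two stages with global rank 2 (the fixture setup is outside this hunk):

    // Assumed fixture: devices = {0, 1, 2, 3}, stage_map = {2, 2}, global rank = 2.
    // stage_device_num    = 4 / 2 = 2
    // stage_id            = 2 / 2 = 1      -> stage 1 owns ranks {2, 3}
    // rank_index_in_stage = 2 - 1 * 2 = 0
    // GetDeviceListInThisStage() = {2, 3}, so .back() == 3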
@@ -171,7 +171,7 @@ TEST_F(TestLogSoftmaxInfo, GetDeviceList1) {
   StrategyPtr strategy = NewStrategy(0, inputs);
   log_softmax->Init(strategy);

-  RankList dev_list = log_softmax->global_device_list();
+  RankList dev_list = log_softmax->stage_device_list();
   ASSERT_EQ(dev_list.size(), 128);
 }