@@ -75,6 +75,7 @@ class DeviceManager {
   size_t DeviceNum() const { return devices_.size(); }
   int64_t stage_num() const { return stage_num_; }
+  int64_t stage_device_num() const { return stage_device_num_; }
   int64_t stage_id() const { return stage_id_; }
   int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
   int64_t global_rank() const { return global_rank_; }
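// Sketch (not the actual DeviceManager code): how the cached stage_device_num_
// behind the new accessor above could be derived once at init time, assuming
// equal-sized stages, instead of recomputing GetDeviceListByStageId(...).size()
// at every call site.
#include <cstdint>
#include <vector>

int64_t ComputeStageDeviceNum(const std::vector<int64_t> &devices, int64_t stage_num) {
  // Equal-sized stages: each stage owns total_devices / stage_num ranks.
  return static_cast<int64_t>(devices.size()) / stage_num;
}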
@@ -41,11 +41,9 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
   }

   // dropout doesn't support repeated calculation
-  CheckGlobalDeviceManager();
   auto input_strategy = strategy->GetInputDim().at(0);
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) != dev_num) {
+  if (product_p != stage_device_size_) {
     MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
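// Standalone sketch of the check above, with hypothetical values: the product
// of a strategy's split factors must exactly equal the stage's device count,
// otherwise the strategy implies repeated calculation, which dropout rejects.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

bool UsesAllStageDevices(const std::vector<int64_t> &input_strategy, int64_t stage_device_size) {
  auto product = std::accumulate(input_strategy.begin(), input_strategy.end(), int64_t{1},
                                 std::multiplies<int64_t>());
  // e.g. strategy {4, 2} on an 8-device stage passes; {4, 1} does not.
  return product == stage_device_size;
}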
@@ -32,11 +32,6 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
-  int64_t stage = strategy->GetInputStage();
-  CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   size_t strategy_size = strategy->GetInputNumber();
   Strategys stra = strategy->GetInputDim();
   for (size_t i = 0; i < strategy_size; ++i) {
@@ -46,7 +41,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
     for (size_t j = 0; j < strategy_len; ++j) {
       int64_t strategy_value = sub_strategy.at(j);
       if (strategy_value > 1) {
-        if (flag || strategy_value != dev_num_) {
+        if (flag || strategy_value != stage_device_size_) {
           MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
           return FAILED;
         }
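// Sketch of the rule enforced above: a valid data-parallel strategy splits at
// most one dimension, and that split factor must equal the stage's device
// count. Hypothetical standalone check, simplified to a single input.
#include <cstdint>
#include <vector>

bool IsValidDataParallelStrategy(const std::vector<int64_t> &dims, int64_t stage_device_size) {
  bool seen_split = false;  // plays the role of `flag` above
  for (int64_t v : dims) {
    if (v > 1) {
      if (seen_split || v != stage_device_size) return false;
      seen_split = true;
    }
  }
  return true;
}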
@@ -58,7 +53,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status BatchParallelInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   return SUCCESS;
 }
@@ -81,14 +76,14 @@ Status BatchParallelInfo::InferMirrorOps() {
 Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }

 Status BatchParallelInfo::InferTensorMap() {
-  if (strategy_->GetInputDim()[0][0] != dev_num_) {
+  if (strategy_->GetInputDim()[0][0] != stage_device_size_) {
     MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
     return FAILED;
   }
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape tensor_map_index;
     for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
-      if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) {
+      if (strategy_->GetInputDim()[i][j] == stage_device_size_ && j == 0) {
         tensor_map_index.push_back(0);
       } else {
         tensor_map_index.push_back(MAP_NONE);
@@ -117,7 +112,7 @@ Strategys BatchParallelInfo::GetOutputsStrategy() {
     Dimensions strategy;
     for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
       if (i == 0 && j == 0) {
-        strategy.push_back(dev_num_);
+        strategy.push_back(stage_device_size_);
       } else {
         strategy.push_back(1);
       }
@@ -176,14 +171,12 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }

 Status BatchParallelInfo::GenerateStrategies(int64_t stage_id) {
-  CheckGlobalDeviceManager();
-  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
   StrategyPtr sp;
   Strategys strategy;
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape temp(inputs_shape_[i].size(), 1);
     if (split_flag_list_[i]) {
-      temp[0] = SizeToLong(total_dev_num);
+      temp[0] = stage_device_size_;
     }
     strategy.push_back(temp);
   }
@@ -151,10 +151,8 @@ Status DropoutDoMaskInfo::GenerateStrategies(int64_t stage_id) {
 }

 std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions strategy(inputs_shape_[0].size() - 1, 1);
-  (void)strategy.insert(strategy.begin(), SizeToLong(dev_num));
+  (void)strategy.insert(strategy.begin(), stage_device_size_);
   Strategys strategy_v = {strategy};
   return std::make_shared<Strategys>(strategy_v);
 }
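// Sketch of the batch strategy built above: only the leading (batch) dimension
// is split across the stage's devices; every other dimension stays whole.
// Hypothetical standalone helper, not the DropoutDoMaskInfo implementation.
#include <cstdint>
#include <vector>

std::vector<int64_t> MakeBatchStrategy(size_t tensor_rank, int64_t stage_device_size) {
  std::vector<int64_t> strategy(tensor_rank - 1, 1);     // non-batch dims unsplit
  strategy.insert(strategy.begin(), stage_device_size);  // batch dim gets the device count
  return strategy;                                       // rank 3 on 8 devices -> {8, 1, 1}
}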
@@ -308,8 +308,6 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
     MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                       << inputs_shape_.size();
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << "GetAttrs failed!";
   }
@@ -318,7 +316,7 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
   if (index_size_ != 1) {
     strategy.push_back(1);
   } else {
-    strategy.push_back(SizeToLong(dev_num));
+    strategy.push_back(stage_device_size_);
   }
   for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
     strategy.push_back(1);
@@ -199,10 +199,8 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
   }

   // Don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) < dev_num) {
+  if (product_p < stage_device_size_) {
     MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
     return FAILED;
   }
@@ -272,10 +270,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   }

   // param_strategy(axis) != 1, Don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
+  if (product_p != stage_device_size_ && param_strategy.at(IntToSize(axis_)) != 1) {
     MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
@@ -349,13 +345,11 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   } else {
     out_dev_matrix_shape_ = dev_matrix_shape_;
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
   auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
-  if (param_product * index_product < SizeToInt(dev_num)) {
+  if (param_product * index_product < stage_device_size_) {
     // add the repeated calculation num to the last dimension of dev matrix
-    out_dev_matrix_shape_.push_back(SizeToInt(dev_num / (param_product * index_product)));
+    out_dev_matrix_shape_.push_back(stage_device_size_ / (param_product * index_product));
   }

   return SUCCESS;
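// Sketch of the repeated-calculation handling above: when the strategy's split
// product covers fewer ranks than the stage owns, the leftover factor becomes
// a trailing "repeat" dimension of the device matrix. Illustrative only.
#include <cstdint>
#include <vector>

void AppendRepeatDim(std::vector<int64_t> *dev_matrix, int64_t split_product,
                     int64_t stage_device_size) {
  if (split_product < stage_device_size) {
    // e.g. split product 4 on an 8-device stage -> repeat factor 2
    dev_matrix->push_back(stage_device_size / split_product);
  }
}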
@@ -539,11 +533,8 @@ Status GatherV2PInfo::InferGroup() {
     dim = (axis_ + 1) % 2;
   }

-  CheckGlobalDeviceManager();
-  MS_EXCEPTION_IF_NULL(g_device_manager);
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_);
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Create group failed.";
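// Conceptual sketch of what GetDevicesAlongDim computes above, assuming a
// row-major device matrix over consecutive ranks; this is an illustration,
// not the DeviceMatrix implementation.
#include <cstdint>
#include <vector>

// For dev_matrix {2, 4}, rank 5: dim 1 -> {4, 5, 6, 7}; dim 0 -> {1, 5}.
std::vector<int64_t> DevicesAlongDim(const std::vector<int64_t> &dev_matrix, int64_t rank,
                                     size_t dim) {
  int64_t stride = 1;
  for (size_t i = dim + 1; i < dev_matrix.size(); ++i) stride *= dev_matrix[i];
  // drop this rank's coordinate along `dim`, then enumerate all values of it
  int64_t base = rank - ((rank / stride) % dev_matrix[dim]) * stride;
  std::vector<int64_t> group;
  for (int64_t i = 0; i < dev_matrix[dim]; ++i) group.push_back(base + i * stride);
  return group;
}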
@@ -777,11 +768,10 @@ std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
   if (manual_split_) {
     MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions param_strategy(inputs_shape_[0].size(), 1);
   Dimensions index_strategy;
-  index_strategy.push_back(SizeToLong(dev_num));
+  index_strategy.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
     index_strategy.push_back(1);
   }
@@ -66,7 +66,7 @@ Strategys GetNextInfo::GetOutputStrategy() {
   Strategys outputs_strategy;
   for (auto shp : shapes_) {
     Dimensions out_strategy;
-    out_strategy.push_back(dev_num_);
+    out_strategy.push_back(stage_device_size_);
     for (size_t i = 1; i < shp.size(); ++i) {
       out_strategy.push_back(1);
     }
@@ -97,7 +97,7 @@ Status GetNextInfo::InferDevMatrixShape() {
   if (max_shape_length == 0) {
     MS_LOG(ERROR) << name_ << " : shape is 0";
   }
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   for (size_t i = 1; i < max_shape_length; ++i) {
     dev_matrix_shape_.push_back(1);
   }
@@ -125,9 +125,6 @@ Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
       return FAILED;
     }
   }
-  int64_t stage = strategy->GetInputStage();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   return SUCCESS;
 }
@@ -199,16 +196,16 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
   Shapes out_shapes = outputs_shape_;
   for (size_t i = 0; i < out_shapes.size(); ++i) {
-    if (dev_num_ <= 0) {
+    if (stage_device_size_ <= 0) {
       MS_LOG(ERROR) << name_ << " : The dev num is 0.";
       return FAILED;
     }
     if (!full_batch) {
-      if (out_shapes[i][0] % dev_num_ != 0) {
+      if (out_shapes[i][0] % stage_device_size_ != 0) {
         MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
         return FAILED;
       }
-      out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+      out_shapes[i][0] = out_shapes[i][0] / stage_device_size_;
     }
   }
   ValuePtr new_shapes = MakeValue(out_shapes);
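// Sketch of the shape rewrite above: when full_batch is off, each device
// consumes batch / stage_device_size samples, so GetNext's output shapes are
// shrunk along the batch dimension. Hypothetical standalone helper.
#include <cstdint>
#include <vector>

bool ShardBatchDim(std::vector<int64_t> *shape, int64_t stage_device_size) {
  if (stage_device_size <= 0 || shape->empty() || (*shape)[0] % stage_device_size != 0) {
    return false;  // the batch must divide evenly across the stage's devices
  }
  (*shape)[0] /= stage_device_size;  // {256, 224, 224} on 8 devices -> {32, 224, 224}
  return true;
}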
@@ -601,10 +601,8 @@ Status MatMulBase::CheckForTensorSliceValid() const {
 }

 std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
-  batch_strategy.insert(batch_strategy.begin(), SizeToLong(dev_num));
+  batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
   Strategys strategy_v = {batch_strategy, batch_strategy};
   return std::make_shared<Strategys>(strategy_v);
 }
@@ -268,9 +268,7 @@ Status OneHotInfo::GenerateStrategies(int64_t stage_id) {
 Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

 std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  Dimensions strategy = {SizeToLong(dev_num), 1};
+  Dimensions strategy = {stage_device_size_, 1};
   Dimensions empty_strategy;
   Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
   return std::make_shared<Strategys>(strategy_v);
@@ -688,7 +688,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
     return nullptr;
   }
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   Strategys strategy_v;
   for (size_t i = 0; i != shapes.size(); i++) {
     if (shapes[i].empty()) {
@@ -1393,9 +1393,7 @@ Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type)
 void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) {
   if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
-    CheckGlobalDeviceManager();
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
-    if (LongToSize(stra->GetInputDim()[0][0]) == total_device_num) {
+    if (stra->GetInputDim()[0][0] == stage_device_size_) {
       if (cost->computation_cost_ > 1.0) {
         cost->computation_cost_ -= 1.0;
       }
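// Sketch of the tie-breaking above: a strategy whose leading split factor
// equals the stage's device count is pure data parallelism, and its
// computation cost gets a small discount so ties resolve in its favor.
// Illustrative only; the Cost struct here is a stand-in for CostPtr's target.
#include <cstdint>

struct Cost {
  double computation_cost_;
};

void PreferDataParallel(int64_t first_split_factor, int64_t stage_device_size, Cost *cost) {
  if (first_split_factor == stage_device_size && cost->computation_cost_ > 1.0) {
    cost->computation_cost_ -= 1.0;  // a nudge for tie-breaking, not a real cost change
  }
}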
@@ -233,15 +233,13 @@ std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions input_strategy(inputs_shape_[0].size(), 1);
   // the split axis itself can't be split
   if (inputs_shape_[0].size() > 1) {
     if (axis_ == 0) {
-      input_strategy[1] = dev_num;
+      input_strategy[1] = stage_device_size_;
     } else {
-      input_strategy[0] = dev_num;
+      input_strategy[0] = stage_device_size_;
     }
   }
   Strategys strategy_v = {input_strategy};
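// Sketch of the rule above: Split's own axis cannot be sharded, so the batch
// strategy shards the first dimension that is not the split axis. Hypothetical
// standalone helper.
#include <cstdint>
#include <vector>

std::vector<int64_t> SplitBatchStrategy(size_t tensor_rank, int64_t axis,
                                        int64_t stage_device_size) {
  std::vector<int64_t> strategy(tensor_rank, 1);
  if (tensor_rank > 1) {
    // avoid the split axis: shard dim 1 when axis is 0, otherwise dim 0
    strategy[(axis == 0) ? 1 : 0] = stage_device_size;
  }
  return strategy;  // e.g. rank 2, axis 0, 8 devices -> {1, 8}
}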
@@ -408,17 +408,14 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions input_a_strategy(inputs_shape_[0].size(), 1);
   Dimensions input_b_strategy(inputs_shape_[1].size(), 1);
-  input_a_strategy[0] = SizeToInt(dev_num);
+  input_a_strategy[0] = stage_device_size_;
   if (axes_type_ == INT_TYPE) {
     if (IntToSize(axes_int_) == inputs_shape_[0].size()) {
-      input_b_strategy[0] = SizeToInt(dev_num);  // find the relavent dimension for input_b
+      input_b_strategy[0] = stage_device_size_;  // find the relevant dimension for input_b
     }
   } else if (axes_type_ == TUPLE_TUPLE_TYPE) {
     // if input_a's axes contain 0, input_b has the dimension relevant to the batch dimension
@@ -434,7 +431,7 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
     if (found) {
       // find the relevant dimension of input_b
-      input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = dev_num;
+      input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = stage_device_size_;
     }
   } else {
     MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE";
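// Sketch of the TUPLE_TUPLE_TYPE branch above: when input_a's batch dimension
// (0) appears among its contracted axes, the axis of input_b paired with it
// must carry the same split factor so the contraction stays aligned across
// devices. Hypothetical helper; the axes vectors are assumed pairwise matched.
#include <cstdint>
#include <vector>

void AlignContractedBatchSplit(const std::vector<int64_t> &axes_a,
                               const std::vector<int64_t> &axes_b,
                               std::vector<int64_t> *strategy_b, int64_t stage_device_size) {
  for (size_t i = 0; i < axes_a.size(); ++i) {
    if (axes_a[i] == 0) {                            // input_a's batch dim is contracted
      (*strategy_b)[axes_b[i]] = stage_device_size;  // mirror the split on input_b
      return;
    }
  }
}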
@@ -85,7 +85,7 @@ Status UniqueInfo::InferTensorInfo() {
 }

 Status UniqueInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   return SUCCESS;
 }
@@ -110,9 +110,7 @@ Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
       return FAILED;
     }
   }
-  int64_t stage = strategy->GetInputStage();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   if (stras[0][0] != 1) {
     MS_LOG(ERROR) << "Currently, unique only support repeat calculate in all devices";
     return FAILED;
@@ -277,20 +277,17 @@ std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
     MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
                       << inputs_shape_.size();
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << "GetAttrs failed!";
   }
-  Dimensions strategy_a;
-  Dimensions strategy_b;
-  strategy_a.push_back(SizeToInt(dev_num));
+  Dimensions strategy_a, strategy_b;
+  strategy_a.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
     strategy_a.push_back(1);
   }
-  strategy_b.push_back(SizeToInt(dev_num));
+  strategy_b.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
     strategy_b.push_back(1);
   }
@@ -66,13 +66,10 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
 Status VirtualDatasetInfo::InferDevMatrixShape() {
   Strategys stra = strategy_->GetInputDim();
   Dimensions strategy_first = stra.at(0);
-  int64_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
   int64_t batch_split_num = ((int64_t)(strategy_first.at(0)));
   dev_matrix_shape_.push_back(batch_split_num);
-  if (dev_num > batch_split_num) {
-    dev_matrix_shape_.push_back(dev_num / batch_split_num);
+  if (stage_device_size_ > batch_split_num) {
+    dev_matrix_shape_.push_back(stage_device_size_ / batch_split_num);
   }

   return SUCCESS;
@@ -156,11 +153,10 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
     return FAILED;
   }

-  CheckGlobalDeviceManager();
   if (full_batch) {
     total_dev_num = 1;
   } else {
-    total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+    total_dev_num = stage_device_size_;
   }
   StrategyPtr sp;
   Strategys strategy;
@@ -1640,7 +1640,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
   if (full_batch) {
     dev_num = 1;
   } else {
-    dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+    dev_num = SizeToLong(g_device_manager->stage_device_num());
   }
   auto attrs_temp = prim->attrs();
   std::vector<Shapes> shape_list = ExtractShape(node);
@@ -1984,7 +1984,7 @@ std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
     return next_layout;
   }
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   TensorLayout input_tensor_layout;
   // create input_shape
   Shapes inputs_shape = GetNodeShape(node);
@@ -2009,7 +2009,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
   TensorRedistribution tensor_redistribution;
   // create a stand-alone layout: TensorMap: [all -1], dev_matrix: [dev_num]
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   TensorLayout stand_alone_layout;
   Shapes inputs_shape = GetNodeShape(node);
   if (inputs_shape.empty()) {
@@ -2029,7 +2029,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
   }

   // Infer the redistribution op list for the stand-alone and loss layouts.
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
+  RankList dev_list = g_device_manager->GetDeviceListInThisStage();
   if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
     MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
   }
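// Sketch of the "stand-alone" layout referenced above: every tensor dimension
// maps to -1 (unsplit/replicated) over a one-dimensional device matrix
// [dev_num]. Illustrative struct, not the TensorLayout API.
#include <cstdint>
#include <vector>

struct LayoutSketch {
  std::vector<int64_t> dev_matrix;  // e.g. {8} for an 8-device stage
  std::vector<int64_t> tensor_map;  // all -1: no dimension is sharded
};

LayoutSketch MakeStandAloneLayout(size_t tensor_rank, int64_t dev_num) {
  return {{dev_num}, std::vector<int64_t>(tensor_rank, -1)};
}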
@@ -3093,7 +3093,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
   if (full_batch) {
     return;
   }
-  auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  auto dev_num = g_device_manager->stage_device_num();
   auto parameters = root->parameters();
   for (auto &parameter : parameters) {
     if (IsUsedParameter(root, parameter)) {