@@ -75,6 +75,7 @@ class DeviceManager {
   size_t DeviceNum() const { return devices_.size(); }
   int64_t stage_num() const { return stage_num_; }
   int64_t stage_device_num() const { return stage_device_num_; }
+  int64_t stage_id() const { return stage_id_; }
   int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
   int64_t global_rank() const { return global_rank_; }
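
The hunks below all apply one pattern: per-call lookups of g_device_manager->GetDeviceListByStageId(...).size() are replaced by values cached on the operator (stage_device_size_, stage_device_list_) or by the DeviceManager::stage_device_num() accessor shown above. As a minimal sketch of how such members could be populated once, assuming a hypothetical OperatorInfo::InitStageDeviceInfo() helper (the real initialization site is not part of this diff):

    // Hypothetical helper, not in this diff: caches the per-stage device view
    // so CheckStrategy()/Infer*() never re-query the device manager.
    void OperatorInfo::InitStageDeviceInfo() {
      CheckGlobalDeviceManager();  // validate g_device_manager before use
      stage_device_list_ = g_device_manager->GetDeviceListInThisStage();
      stage_device_size_ = SizeToLong(stage_device_list_.size());
    }

With the size cached as an int64_t, the IntToSize()/SizeToLong() conversions at the old call sites also become unnecessary, which is why they disappear in the hunks below.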
@@ -41,11 +41,9 @@ Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
   }

   // dropout don't support repeated calculation
-  CheckGlobalDeviceManager();
   auto input_strategy = strategy->GetInputDim().at(0);
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) != dev_num) {
+  if (product_p != stage_device_size_) {
     MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
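
The guard above enforces "no repeated calculation": the product of every split factor in the strategy must equal the number of devices in the stage, so each device owns a distinct slice. A standalone illustration with assumed numbers (8 stage devices, strategy [4, 2]):

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      const int64_t stage_device_size = 8;           // assumed stage device count
      const std::vector<int64_t> strategy = {4, 2};  // split dim0 x4, dim1 x2
      const auto product = std::accumulate(strategy.begin(), strategy.end(),
                                           int64_t{1}, std::multiplies<int64_t>());
      // 4 * 2 == 8: exactly one slice per device, no repetition.
      std::cout << (product == stage_device_size ? "valid" : "repeated calc") << "\n";
      return 0;
    }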
@@ -32,11 +32,6 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }

-  int64_t stage = strategy->GetInputStage();
-  CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   size_t strategy_size = strategy->GetInputNumber();
   Strategys stra = strategy->GetInputDim();
   for (size_t i = 0; i < strategy_size; ++i) {
@@ -46,7 +41,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
     for (size_t j = 0; j < strategy_len; ++j) {
       int64_t strategy_value = sub_strategy.at(j);
       if (strategy_value > 1) {
-        if (flag || strategy_value != dev_num_) {
+        if (flag || strategy_value != stage_device_size_) {
           MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
           return FAILED;
         }
@@ -58,7 +53,7 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
 }

 Status BatchParallelInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   return SUCCESS;
 }
@@ -81,14 +76,14 @@ Status BatchParallelInfo::InferMirrorOps() {
 Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }

 Status BatchParallelInfo::InferTensorMap() {
-  if (strategy_->GetInputDim()[0][0] != dev_num_) {
+  if (strategy_->GetInputDim()[0][0] != stage_device_size_) {
     MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
     return FAILED;
   }
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape tensor_map_index;
     for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
-      if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) {
+      if (strategy_->GetInputDim()[i][j] == stage_device_size_ && j == 0) {
         tensor_map_index.push_back(0);
       } else {
         tensor_map_index.push_back(MAP_NONE);
@@ -117,7 +112,7 @@ Strategys BatchParallelInfo::GetOutputsStrategy() {
     Dimensions strategy;
     for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
       if (i == 0 && j == 0) {
-        strategy.push_back(dev_num_);
+        strategy.push_back(stage_device_size_);
       } else {
         strategy.push_back(1);
       }
@@ -176,14 +171,12 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }

 Status BatchParallelInfo::GenerateStrategies(int64_t stage_id) {
-  CheckGlobalDeviceManager();
-  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
   StrategyPtr sp;
   Strategys strategy;
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape temp(inputs_shape_[i].size(), 1);
     if (split_flag_list_[i]) {
-      temp[0] = SizeToLong(total_dev_num);
+      temp[0] = stage_device_size_;
     }
     strategy.push_back(temp);
   }
@@ -151,10 +151,8 @@ Status DropoutDoMaskInfo::GenerateStrategies(int64_t stage_id) {
 }

 std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions strategy(inputs_shape_[0].size() - 1, 1);
-  (void)strategy.insert(strategy.begin(), SizeToLong(dev_num));
+  (void)strategy.insert(strategy.begin(), stage_device_size_);
   Strategys strategy_v = {strategy};
   return std::make_shared<Strategys>(strategy_v);
 }
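
GenerateBatchStrategies() above and its siblings below (GatherV2, BatchMatMul, OneHot, Split, TensorDot, UnsortedSegmentOp) all build the same shape of strategy: split the batch dimension across every device in the stage and leave the other dimensions whole. A sketch of the shared pattern as a hypothetical free function (the per-operator methods differ only in which input and dimension receive the split):

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int64_t>;

    // For an input of rank `rank` (assumed >= 1), produce
    // {stage_device_size, 1, 1, ...}.
    Dimensions MakeBatchStrategy(size_t rank, int64_t stage_device_size) {
      Dimensions strategy(rank - 1, 1);                      // non-batch dims unsplit
      strategy.insert(strategy.begin(), stage_device_size);  // batch dim fully split
      return strategy;
    }
    // e.g. rank 3 with 8 stage devices -> {8, 1, 1}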
@@ -308,8 +308,6 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
     MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                       << inputs_shape_.size();
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << "GetAttrs failed!";
   }
@@ -318,7 +316,7 @@ std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
   if (index_size_ != 1) {
     strategy.push_back(1);
   } else {
-    strategy.push_back(SizeToLong(dev_num));
+    strategy.push_back(stage_device_size_);
   }
   for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
     strategy.push_back(1);
@@ -199,10 +199,8 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
   }

   // Don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) < dev_num) {
+  if (product_p < stage_device_size_) {
     MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
     return FAILED;
   }
@@ -272,10 +270,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   }

   // param_strategy(axis) != 1, Don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
-  if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
+  if (product_p != stage_device_size_ && param_strategy.at(IntToSize(axis_)) != 1) {
     MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
@@ -349,13 +345,11 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   } else {
     out_dev_matrix_shape_ = dev_matrix_shape_;
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
   auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
   auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
-  if (param_product * index_product < SizeToInt(dev_num)) {
+  if (param_product * index_product < stage_device_size_) {
     // add the repeated calculation num to the last dimension of dev matrix
-    out_dev_matrix_shape_.push_back(SizeToInt(dev_num / (param_product * index_product)));
+    out_dev_matrix_shape_.push_back(stage_device_size_ / (param_product * index_product));
   }

   return SUCCESS;
@@ -539,11 +533,8 @@ Status GatherV2PInfo::InferGroup() {
     dim = (axis_ + 1) % 2;
   }

-  CheckGlobalDeviceManager();
-  MS_EXCEPTION_IF_NULL(g_device_manager);
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_);
   int64_t rank = g_device_manager->global_rank();
-  DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
   RankList group_devices;
   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Create group failed.";
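
For intuition about GetDevicesAlongDim above: the device matrix arranges the stage's ranks row-major, and the group for a dimension is the set of ranks whose coordinates agree everywhere except along that dimension. A standalone sketch of those assumed semantics (it mirrors how the call is used here, not MindSpore's actual implementation):

    #include <cstdint>
    #include <vector>

    // Linear indices of all devices that differ from `rank` only along `dim`,
    // for a row-major device matrix of the given `shape`.
    std::vector<int64_t> DevicesAlongDim(const std::vector<int64_t> &shape,
                                         int64_t rank, size_t dim) {
      int64_t stride = 1;
      for (size_t i = dim + 1; i < shape.size(); ++i) stride *= shape[i];
      const int64_t coord = (rank / stride) % shape[dim];
      const int64_t base = rank - coord * stride;  // zero out dim's coordinate
      std::vector<int64_t> group;
      for (int64_t c = 0; c < shape[dim]; ++c) group.push_back(base + c * stride);
      return group;
    }
    // shape {2, 4}, rank 1: dim 1 -> {0, 1, 2, 3}; dim 0 -> {1, 5}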
@@ -777,11 +768,10 @@ std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
   if (manual_split_) {
     MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions param_strategy(inputs_shape_[0].size(), 1);
   Dimensions index_strategy;
-  index_strategy.push_back(SizeToLong(dev_num));
+  index_strategy.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
     index_strategy.push_back(1);
   }
@@ -66,7 +66,7 @@ Strategys GetNextInfo::GetOutputStrategy() {
   Strategys outputs_strategy;
   for (auto shp : shapes_) {
     Dimensions out_strategy;
-    out_strategy.push_back(dev_num_);
+    out_strategy.push_back(stage_device_size_);
     for (size_t i = 1; i < shp.size(); ++i) {
       out_strategy.push_back(1);
     }
@@ -97,7 +97,7 @@ Status GetNextInfo::InferDevMatrixShape() {
   if (max_shape_length == 0) {
     MS_LOG(ERROR) << name_ << " : shape is 0";
   }
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   for (size_t i = 1; i < max_shape_length; ++i) {
     dev_matrix_shape_.push_back(1);
   }
@@ -125,9 +125,6 @@ Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
       return FAILED;
     }
   }
-  int64_t stage = strategy->GetInputStage();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   return SUCCESS;
 }
@@ -199,16 +196,16 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr &) {
   Shapes out_shapes = outputs_shape_;
   for (size_t i = 0; i < out_shapes.size(); ++i) {
-    if (dev_num_ <= 0) {
+    if (stage_device_size_ <= 0) {
       MS_LOG(ERROR) << name_ << " : The dev num is 0.";
       return FAILED;
     }
     if (!full_batch) {
-      if (out_shapes[i][0] % dev_num_ != 0) {
+      if (out_shapes[i][0] % stage_device_size_ != 0) {
         MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num.";
         return FAILED;
       }
-      out_shapes[i][0] = out_shapes[i][0] / dev_num_;
+      out_shapes[i][0] = out_shapes[i][0] / stage_device_size_;
     }
   }
   ValuePtr new_shapes = MakeValue(out_shapes);
@@ -601,10 +601,8 @@ Status MatMulBase::CheckForTensorSliceValid() const {
 }

 std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
-  batch_strategy.insert(batch_strategy.begin(), SizeToLong(dev_num));
+  batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
   Strategys strategy_v = {batch_strategy, batch_strategy};
   return std::make_shared<Strategys>(strategy_v);
 }
@@ -268,9 +268,7 @@ Status OneHotInfo::GenerateStrategies(int64_t stage_id) {
 Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

 std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  Dimensions strategy = {SizeToLong(dev_num), 1};
+  Dimensions strategy = {stage_device_size_, 1};
   Dimensions empty_strategy;
   Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
   return std::make_shared<Strategys>(strategy_v);
@@ -688,7 +688,7 @@ std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shap
     return nullptr;
   }
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   Strategys strategy_v;
   for (size_t i = 0; i != shapes.size(); i++) {
     if (shapes[i].empty()) {
@@ -1393,9 +1393,7 @@ Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type)
 void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) {
   if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
-    CheckGlobalDeviceManager();
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
-    if (LongToSize(stra->GetInputDim()[0][0]) == total_device_num) {
+    if (stra->GetInputDim()[0][0] == stage_device_size_) {
       if (cost->computation_cost_ > 1.0) {
         cost->computation_cost_ -= 1.0;
       }
@@ -233,15 +233,13 @@ std::shared_ptr<Strategys> SplitInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions input_strategy(inputs_shape_[0].size(), 1);
   // axis can't split
   if (inputs_shape_[0].size() > 1) {
     if (axis_ == 0) {
-      input_strategy[1] = dev_num;
+      input_strategy[1] = stage_device_size_;
     } else {
-      input_strategy[0] = dev_num;
+      input_strategy[0] = stage_device_size_;
     }
   }
   Strategys strategy_v = {input_strategy};
@@ -408,17 +408,14 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions input_a_strategy(inputs_shape_[0].size(), 1);
   Dimensions input_b_strategy(inputs_shape_[1].size(), 1);
-  input_a_strategy[0] = SizeToInt(dev_num);
+  input_a_strategy[0] = stage_device_size_;
   if (axes_type_ == INT_TYPE) {
     if (IntToSize(axes_int_) == inputs_shape_[0].size()) {
-      input_b_strategy[0] = SizeToInt(dev_num);  // find the relavent dimension for input_b
+      input_b_strategy[0] = stage_device_size_;  // find the relavent dimension for input_b
     }
   } else if (axes_type_ == TUPLE_TUPLE_TYPE) {
     // if the input_a's axes contain 0, the input_b has the relavent dimension with batch dimension
@@ -434,7 +431,7 @@ std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
     if (found) {
       // find the relavant
-      input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = dev_num;
+      input_b_strategy[axes_tuple_tuple_[1][relavant_index]] = stage_device_size_;
     }
   } else {
     MS_LOG(EXCEPTION) << name_ << ": Now do not support TUPLE_TYPE";
@@ -85,7 +85,7 @@ Status UniqueInfo::InferTensorInfo() {
 }

 Status UniqueInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.push_back(dev_num_);
+  dev_matrix_shape_.push_back(stage_device_size_);
   return SUCCESS;
 }
@@ -110,9 +110,7 @@ Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
       return FAILED;
     }
   }
-  int64_t stage = strategy->GetInputStage();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
-  dev_num_ = dev_num;
   if (stras[0][0] != 1) {
     MS_LOG(ERROR) << "Currently, unique only support repeat calculate in all devices";
     return FAILED;
@@ -277,20 +277,17 @@ std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
     MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
                       << inputs_shape_.size();
   }
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << "GetAttrs failed!";
   }
-  Dimensions strategy_a;
-  Dimensions strategy_b;
-  strategy_a.push_back(SizeToInt(dev_num));
+  Dimensions strategy_a, strategy_b;
+  strategy_a.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
     strategy_a.push_back(1);
   }
-  strategy_b.push_back(SizeToInt(dev_num));
+  strategy_b.push_back(stage_device_size_);
   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
     strategy_b.push_back(1);
   }
@@ -66,13 +66,10 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
 Status VirtualDatasetInfo::InferDevMatrixShape() {
   Strategys stra = strategy_->GetInputDim();
   Dimensions strategy_first = stra.at(0);
-  int64_t stage = strategy_->GetInputStage();
-  CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(stage).size());
   int64_t batch_split_num = ((int64_t)(strategy_first.at(0)));
   dev_matrix_shape_.push_back(batch_split_num);
-  if (dev_num > batch_split_num) {
-    dev_matrix_shape_.push_back(dev_num / batch_split_num);
+  if (stage_device_size_ > batch_split_num) {
+    dev_matrix_shape_.push_back(stage_device_size_ / batch_split_num);
   }

   return SUCCESS;
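
When the batch split does not use every device in the stage, the remainder shows up as a repeat factor in the device matrix. A standalone check of the arithmetic above with assumed numbers:

    #include <cstdint>
    #include <vector>

    int main() {
      const int64_t stage_device_size = 8;  // assumed devices in this stage
      const int64_t batch_split_num = 4;    // batch dimension split 4 ways
      std::vector<int64_t> dev_matrix_shape = {batch_split_num};
      if (stage_device_size > batch_split_num) {
        // 8 / 4 = 2: each batch slice is replicated on 2 devices.
        dev_matrix_shape.push_back(stage_device_size / batch_split_num);
      }
      // dev_matrix_shape == {4, 2}
      return 0;
    }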
@@ -156,11 +153,10 @@ Status VirtualDatasetInfo::GenerateStrategies(int64_t stage_id) {
     return FAILED;
   }

-  CheckGlobalDeviceManager();
   if (full_batch) {
     total_dev_num = 1;
   } else {
-    total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+    total_dev_num = stage_device_size_;
   }
   StrategyPtr sp;
   Strategys strategy;
@@ -1640,7 +1640,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
     if (full_batch) {
       dev_num = 1;
     } else {
-      dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+      dev_num = SizeToLong(g_device_manager->stage_device_num());
     }
     auto attrs_temp = prim->attrs();
     std::vector<Shapes> shape_list = ExtractShape(node);
@@ -1984,7 +1984,7 @@ std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
     return next_layout;
   }
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   TensorLayout input_tensor_layout;
   // create input_shape
   Shapes inputs_shape = GetNodeShape(node);
@@ -2009,7 +2009,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
   TensorRedistribution tensor_redistribution;

   // create stand alone layout:TensorMap:[all -1],dev_matrix:[dev_num].
   CheckGlobalDeviceManager();
-  int64_t dev_num = SizeToLong(g_device_manager->GetDeviceListByStageId(0).size());
+  int64_t dev_num = g_device_manager->stage_device_num();
   TensorLayout stand_alone_layout;
   Shapes inputs_shape = GetNodeShape(node);
   if (inputs_shape.empty()) {
@@ -2029,7 +2029,7 @@ RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const Te
   }

   // Infer Redistribution op list for stand alone and loss layout.
-  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
+  RankList dev_list = g_device_manager->GetDeviceListInThisStage();
   if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
     MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
   }
@@ -3093,7 +3093,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
   if (full_batch) {
     return;
   }
-  auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  auto dev_num = g_device_manager->stage_device_num();
   auto parameters = root->parameters();
   for (auto &parameter : parameters) {
     if (IsUsedParameter(root, parameter)) {