ALLGATHER, and REDUCESCATTER with different factors; change the BETA and GAMMA value in cost model.tags/v0.2.0-alpha
| @@ -34,7 +34,7 @@ namespace parallel { | |||
| #define OPERATOR_TO_OPERATOR_CONNECTOR "-" | |||
| #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) | |||
| #define DEFAULT_COST_MODEL_ALPHA 1.0 | |||
| #define DEFAULT_COST_MODEL_BETA 260.0 | |||
| #define DEFAULT_COST_MODEL_BETA 400.0 | |||
| #define DEFAULT_COST_MODEL_GAMMA 0.001 | |||
| #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true | |||
| #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 | |||
| @@ -23,7 +23,7 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, | |||
| RankList dev_list) { | |||
| RankList dev_list, bool is_cost_model) { | |||
| in_tensor_map_ = tensor_layout.tensor_map(); | |||
| dev_mat_ = tensor_layout.device_arrangement(); | |||
| @@ -51,6 +51,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, cons | |||
| for (int32_t item : map) { | |||
| map_[key++] = item; | |||
| } | |||
| is_cost_model_ = is_cost_model; | |||
| return Status::SUCCESS; | |||
| } | |||
| @@ -130,15 +132,26 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { | |||
| std::any_of(map_.begin(), map_.end(), | |||
| [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { | |||
| int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); | |||
| Args args_allconcat = {cat_dim, out_dim, dev_mat_.GetDimByReverseIdx(IntToUint(out_dim))}; | |||
| Args args_allsplit = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; | |||
| if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { | |||
| MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; | |||
| return Status::FAILED; | |||
| } | |||
| if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { | |||
| MS_LOG(ERROR) << "Insert SplitByAxis Error!"; | |||
| return Status::FAILED; | |||
| int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); | |||
| if (is_cost_model_) { | |||
| int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); | |||
| Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, | |||
| dev_num}; | |||
| if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { | |||
| MS_LOG(ERROR) << "Insert PermuteByAxis Error!"; | |||
| return Status::FAILED; | |||
| } | |||
| } else { | |||
| Args args_allconcat = {cat_dim, out_dim, dev_num}; | |||
| Args args_allsplit = {dev_num, UintToInt(index), out_dim}; | |||
| if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { | |||
| MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; | |||
| return Status::FAILED; | |||
| } | |||
| if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { | |||
| MS_LOG(ERROR) << "Insert SplitByAxis Error!"; | |||
| return Status::FAILED; | |||
| } | |||
| } | |||
| (void)map_.erase(iter++); | |||
| map_[IntToSize(cat_dim)] = NONE; | |||
| @@ -40,7 +40,8 @@ class RedistributionOperatorInfer { | |||
| public: | |||
| const int NONE = -1; | |||
| explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} | |||
| Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list); | |||
| Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list, | |||
| bool is_cost_model = false); | |||
| ~RedistributionOperatorInfer() = default; | |||
| OperatorList operator_list() const { return operator_list_; } | |||
| OperatorVector operator_vector() const { return operator_vector_; } | |||
| @@ -67,6 +68,7 @@ class RedistributionOperatorInfer { | |||
| ConstructOperator constructor_; | |||
| RankList dev_list_; | |||
| bool construct_op_flag_; | |||
| bool is_cost_model_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout& | |||
| return Status::SUCCESS; | |||
| } | |||
| RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList() { | |||
| RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) { | |||
| // Step 1: Match device arrangement between from_ and to_ | |||
| RedistributionLayoutTransfer layout_transfer; | |||
| Status status = layout_transfer.Init(from_, to_); | |||
| @@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL | |||
| MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); | |||
| // Step 2: Infer redistribution and insert operators | |||
| RedistributionOperatorInfer operator_infer(construct_op_flag_); | |||
| if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_) == Status::FAILED) { | |||
| if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { | |||
| MS_LOG(ERROR) << "Init operatorInfer failed!"; | |||
| return nullptr; | |||
| } | |||
| @@ -138,7 +138,7 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const | |||
| } | |||
| Status TensorRedistribution::ComputeCost() { | |||
| RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(); | |||
| RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true); | |||
| if (redistribution_oplist_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; | |||
| return Status::FAILED; | |||
| @@ -151,14 +151,22 @@ Status TensorRedistribution::ComputeCost() { | |||
| std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>()); | |||
| std::string str = op.first; | |||
| if (str == PERMUTE_BY_AXIS) { | |||
| // The shape does not change after PermuteByAxis operation. | |||
| // communication cost = all_to_all + all_to_all = 2 * slice_shape | |||
| // computation cost = slice_shape | |||
| forward_comm_cost_ += prod; | |||
| backward_comm_cost_ += prod; | |||
| comm_cost_ += 2.0 * prod; | |||
| computation_cost_ += prod; | |||
| memory_cost_ += prod; | |||
| // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost. | |||
| // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape | |||
| forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; | |||
| backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; | |||
| comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR; | |||
| int32_t concat_dim = op.second[2]; | |||
| if (concat_dim == 0) { | |||
| // memory cost = all_gather | |||
| computation_cost_ += prod; | |||
| memory_cost_ += prod; | |||
| } else { | |||
| // memory cost = all_gather + split + concat | |||
| int32_t dev_num = op.second[4]; | |||
| computation_cost_ += (prod + prod * dev_num + prod * dev_num); | |||
| memory_cost_ += (prod * dev_num + prod * dev_num + prod); | |||
| } | |||
| } else if (str == CONCAT_BY_AXIS) { | |||
| // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape | |||
| // computation cost = before_slice_shape | |||
| @@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() { | |||
| } | |||
| double dev_num = op.second[2]; | |||
| // here, communication cost = all_gather + reduce_scatter | |||
| forward_comm_cost_ += prod * dev_num; | |||
| backward_comm_cost_ += prod; | |||
| comm_cost_ += prod * (dev_num + 1.0); | |||
| forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; | |||
| backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; | |||
| comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; | |||
| int32_t concat_dim = op.second[0]; | |||
| if (concat_dim == 0) { | |||
| // computation cost = all_gather | |||
| @@ -33,6 +33,8 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| constexpr double ALLTOALL_SCALE_FACTOR = 2.0; | |||
| constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5; | |||
| class TensorRedistribution { | |||
| public: | |||
| explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) | |||
| @@ -46,7 +48,7 @@ class TensorRedistribution { | |||
| keep_reshape_(keep_reshape) {} | |||
| Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); | |||
| ~TensorRedistribution() = default; | |||
| RedistributionOpListPtr InferTensorRedistributionOperatorList(); | |||
| RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); | |||
| OperatorList operator_list() const { return operator_list_; } | |||
| bool reshape_flag() const { return reshape_flag_; } | |||
| Status ComputeCost(); | |||
| @@ -304,7 +304,7 @@ def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): | |||
| def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192 | |||
| cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) | |||
| cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) | |||
| @@ -651,7 +651,7 @@ def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): # | |||
| def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192 | |||
| dev_num = 8 | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) | |||
| cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) | |||
| cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) | |||
| set_algo_parameters(elementwise_op_strategy_follow=True) | |||
| resset_op_id() | |||
| np.random.seed(6) | |||
| @@ -86,7 +86,7 @@ def test_two_matmul(): | |||
| costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha") | |||
| assert costmodel_alpha == 1.0 | |||
| costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta") | |||
| assert costmodel_beta == 260.0 | |||
| assert costmodel_beta == 400.0 | |||
| costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma") | |||
| assert costmodel_gamma == 0.001 | |||
| costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold") | |||