ALLGATHER, and REDUCESCATTER with different factors; change the BETA and GAMMA value in cost model.tags/v0.2.0-alpha
| @@ -34,7 +34,7 @@ namespace parallel { | |||||
| #define OPERATOR_TO_OPERATOR_CONNECTOR "-" | #define OPERATOR_TO_OPERATOR_CONNECTOR "-" | ||||
| #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) | #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) | ||||
| #define DEFAULT_COST_MODEL_ALPHA 1.0 | #define DEFAULT_COST_MODEL_ALPHA 1.0 | ||||
| #define DEFAULT_COST_MODEL_BETA 260.0 | |||||
| #define DEFAULT_COST_MODEL_BETA 400.0 | |||||
| #define DEFAULT_COST_MODEL_GAMMA 0.001 | #define DEFAULT_COST_MODEL_GAMMA 0.001 | ||||
| #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true | #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true | ||||
| #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 | #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 | ||||
| @@ -23,7 +23,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, | Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, | ||||
| RankList dev_list) { | |||||
| RankList dev_list, bool is_cost_model) { | |||||
| in_tensor_map_ = tensor_layout.tensor_map(); | in_tensor_map_ = tensor_layout.tensor_map(); | ||||
| dev_mat_ = tensor_layout.device_arrangement(); | dev_mat_ = tensor_layout.device_arrangement(); | ||||
| @@ -51,6 +51,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, cons | |||||
| for (int32_t item : map) { | for (int32_t item : map) { | ||||
| map_[key++] = item; | map_[key++] = item; | ||||
| } | } | ||||
| is_cost_model_ = is_cost_model; | |||||
| return Status::SUCCESS; | return Status::SUCCESS; | ||||
| } | } | ||||
| @@ -130,15 +132,26 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { | |||||
| std::any_of(map_.begin(), map_.end(), | std::any_of(map_.begin(), map_.end(), | ||||
| [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { | [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { | ||||
| int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); | int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); | ||||
| Args args_allconcat = {cat_dim, out_dim, dev_mat_.GetDimByReverseIdx(IntToUint(out_dim))}; | |||||
| Args args_allsplit = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; | |||||
| if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { | |||||
| MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; | |||||
| return Status::FAILED; | |||||
| } | |||||
| if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { | |||||
| MS_LOG(ERROR) << "Insert SplitByAxis Error!"; | |||||
| return Status::FAILED; | |||||
| int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); | |||||
| if (is_cost_model_) { | |||||
| int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); | |||||
| Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, | |||||
| dev_num}; | |||||
| if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { | |||||
| MS_LOG(ERROR) << "Insert PermuteByAxis Error!"; | |||||
| return Status::FAILED; | |||||
| } | |||||
| } else { | |||||
| Args args_allconcat = {cat_dim, out_dim, dev_num}; | |||||
| Args args_allsplit = {dev_num, UintToInt(index), out_dim}; | |||||
| if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { | |||||
| MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; | |||||
| return Status::FAILED; | |||||
| } | |||||
| if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { | |||||
| MS_LOG(ERROR) << "Insert SplitByAxis Error!"; | |||||
| return Status::FAILED; | |||||
| } | |||||
| } | } | ||||
| (void)map_.erase(iter++); | (void)map_.erase(iter++); | ||||
| map_[IntToSize(cat_dim)] = NONE; | map_[IntToSize(cat_dim)] = NONE; | ||||
| @@ -40,7 +40,8 @@ class RedistributionOperatorInfer { | |||||
| public: | public: | ||||
| const int NONE = -1; | const int NONE = -1; | ||||
| explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} | explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} | ||||
| Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list); | |||||
| Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list, | |||||
| bool is_cost_model = false); | |||||
| ~RedistributionOperatorInfer() = default; | ~RedistributionOperatorInfer() = default; | ||||
| OperatorList operator_list() const { return operator_list_; } | OperatorList operator_list() const { return operator_list_; } | ||||
| OperatorVector operator_vector() const { return operator_vector_; } | OperatorVector operator_vector() const { return operator_vector_; } | ||||
| @@ -67,6 +68,7 @@ class RedistributionOperatorInfer { | |||||
| ConstructOperator constructor_; | ConstructOperator constructor_; | ||||
| RankList dev_list_; | RankList dev_list_; | ||||
| bool construct_op_flag_; | bool construct_op_flag_; | ||||
| bool is_cost_model_; | |||||
| }; | }; | ||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout& | |||||
| return Status::SUCCESS; | return Status::SUCCESS; | ||||
| } | } | ||||
| RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList() { | |||||
| RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) { | |||||
| // Step 1: Match device arrangement between from_ and to_ | // Step 1: Match device arrangement between from_ and to_ | ||||
| RedistributionLayoutTransfer layout_transfer; | RedistributionLayoutTransfer layout_transfer; | ||||
| Status status = layout_transfer.Init(from_, to_); | Status status = layout_transfer.Init(from_, to_); | ||||
| @@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL | |||||
| MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); | MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); | ||||
| // Step 2: Infer redistribution and insert operators | // Step 2: Infer redistribution and insert operators | ||||
| RedistributionOperatorInfer operator_infer(construct_op_flag_); | RedistributionOperatorInfer operator_infer(construct_op_flag_); | ||||
| if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_) == Status::FAILED) { | |||||
| if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { | |||||
| MS_LOG(ERROR) << "Init operatorInfer failed!"; | MS_LOG(ERROR) << "Init operatorInfer failed!"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -138,7 +138,7 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const | |||||
| } | } | ||||
| Status TensorRedistribution::ComputeCost() { | Status TensorRedistribution::ComputeCost() { | ||||
| RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(); | |||||
| RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true); | |||||
| if (redistribution_oplist_ptr == nullptr) { | if (redistribution_oplist_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; | MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; | ||||
| return Status::FAILED; | return Status::FAILED; | ||||
| @@ -151,14 +151,22 @@ Status TensorRedistribution::ComputeCost() { | |||||
| std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>()); | std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>()); | ||||
| std::string str = op.first; | std::string str = op.first; | ||||
| if (str == PERMUTE_BY_AXIS) { | if (str == PERMUTE_BY_AXIS) { | ||||
| // The shape does not change after PermuteByAxis operation. | |||||
| // communication cost = all_to_all + all_to_all = 2 * slice_shape | |||||
| // computation cost = slice_shape | |||||
| forward_comm_cost_ += prod; | |||||
| backward_comm_cost_ += prod; | |||||
| comm_cost_ += 2.0 * prod; | |||||
| computation_cost_ += prod; | |||||
| memory_cost_ += prod; | |||||
| // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost. | |||||
| // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape | |||||
| forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; | |||||
| backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; | |||||
| comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR; | |||||
| int32_t concat_dim = op.second[2]; | |||||
| if (concat_dim == 0) { | |||||
| // memory cost = all_gather | |||||
| computation_cost_ += prod; | |||||
| memory_cost_ += prod; | |||||
| } else { | |||||
| // memory cost = all_gather + split + concat | |||||
| int32_t dev_num = op.second[4]; | |||||
| computation_cost_ += (prod + prod * dev_num + prod * dev_num); | |||||
| memory_cost_ += (prod * dev_num + prod * dev_num + prod); | |||||
| } | |||||
| } else if (str == CONCAT_BY_AXIS) { | } else if (str == CONCAT_BY_AXIS) { | ||||
| // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape | // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape | ||||
| // computation cost = before_slice_shape | // computation cost = before_slice_shape | ||||
| @@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() { | |||||
| } | } | ||||
| double dev_num = op.second[2]; | double dev_num = op.second[2]; | ||||
| // here, communication cost = all_gather + reduce_scatter | // here, communication cost = all_gather + reduce_scatter | ||||
| forward_comm_cost_ += prod * dev_num; | |||||
| backward_comm_cost_ += prod; | |||||
| comm_cost_ += prod * (dev_num + 1.0); | |||||
| forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; | |||||
| backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; | |||||
| comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; | |||||
| int32_t concat_dim = op.second[0]; | int32_t concat_dim = op.second[0]; | ||||
| if (concat_dim == 0) { | if (concat_dim == 0) { | ||||
| // computation cost = all_gather | // computation cost = all_gather | ||||
| @@ -33,6 +33,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| constexpr double ALLTOALL_SCALE_FACTOR = 2.0; | |||||
| constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5; | |||||
| class TensorRedistribution { | class TensorRedistribution { | ||||
| public: | public: | ||||
| explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) | explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) | ||||
| @@ -46,7 +48,7 @@ class TensorRedistribution { | |||||
| keep_reshape_(keep_reshape) {} | keep_reshape_(keep_reshape) {} | ||||
| Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); | Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); | ||||
| ~TensorRedistribution() = default; | ~TensorRedistribution() = default; | ||||
| RedistributionOpListPtr InferTensorRedistributionOperatorList(); | |||||
| RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); | |||||
| OperatorList operator_list() const { return operator_list_; } | OperatorList operator_list() const { return operator_list_; } | ||||
| bool reshape_flag() const { return reshape_flag_; } | bool reshape_flag() const { return reshape_flag_; } | ||||
| Status ComputeCost(); | Status ComputeCost(); | ||||
| @@ -304,7 +304,7 @@ def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): | |||||
| def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192 | def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192 | ||||
| cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) | |||||
| cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) | |||||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) | cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) | ||||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) | cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) | ||||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) | cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) | ||||
| @@ -651,7 +651,7 @@ def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): # | |||||
| def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192 | def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192 | ||||
| dev_num = 8 | dev_num = 8 | ||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) | context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) | ||||
| cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) | |||||
| cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) | |||||
| set_algo_parameters(elementwise_op_strategy_follow=True) | set_algo_parameters(elementwise_op_strategy_follow=True) | ||||
| resset_op_id() | resset_op_id() | ||||
| np.random.seed(6) | np.random.seed(6) | ||||
| @@ -86,7 +86,7 @@ def test_two_matmul(): | |||||
| costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha") | costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha") | ||||
| assert costmodel_alpha == 1.0 | assert costmodel_alpha == 1.0 | ||||
| costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta") | costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta") | ||||
| assert costmodel_beta == 260.0 | |||||
| assert costmodel_beta == 400.0 | |||||
| costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma") | costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma") | ||||
| assert costmodel_gamma == 0.001 | assert costmodel_gamma == 0.001 | ||||
| costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold") | costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold") | ||||