Merge pull request !6782 from gziyan/enable_optimizer_shard_in_auto_parallel
| @@ -22,6 +22,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "frontend/parallel/allreduce_fusion/allreduce_graph.h" | |||
| #include "frontend/parallel/status.h" | |||
| #include "frontend/parallel/ops_info/ops_utils.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -35,7 +36,6 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0 | |||
| constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1; | |||
| constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1; | |||
| constexpr char FUSION[] = "fusion"; | |||
| constexpr char PARAMETER[] = "parameter"; | |||
| const uint32_t MAX_RECURSIVE_CALL_TIMES = 100; | |||
| class AllreduceFusion { | |||
| @@ -42,15 +42,11 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { | |||
| auto device_arrangement = tensor_layout->device_arrangement().array(); | |||
| auto tensor_map = tensor_layout->tensor_map().array(); | |||
| auto slice_shape = tensor_layout->slice_shape().array(); | |||
| Shape field_size = {tensor_layout->get_field_size()}; | |||
| Shape uniform_split; | |||
| if (tensor_layout->uniform_split()) { | |||
| uniform_split.push_back(1); | |||
| } else { | |||
| uniform_split.push_back(0); | |||
| } | |||
| std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split}; | |||
| int32_t field_size = tensor_layout->get_field_size(); | |||
| bool uniform_split = tensor_layout->uniform_split(); | |||
| std::string opt_shard_group = tensor_layout->opt_shard_group(); | |||
| py::tuple layout = | |||
| py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group); | |||
| dict[py::str(name)] = layout; | |||
| MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); | |||
| } | |||
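Note: with this change, each entry of the Python-side parameter layout dict becomes a six-element tuple (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group) instead of the old list of Shapes. A minimal illustration of consuming one entry; the concrete values are copied from the updated test_get_parameter_layout further down in this PR, and the unpacking itself is illustrative only:

```python
# Values taken from the updated test_get_parameter_layout in this PR.
x_layout = ([2, 4], [1, -1], [16, 32], 0, True, '')
dev_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group = x_layout
# opt_shard_group is the communication group name when the parallel optimizer shards
# this parameter; it stays '' when the feature is not applied.
assert uniform_split and opt_shard_group == ''
```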
| @@ -226,6 +226,21 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string & | |||
| return op; | |||
| } | |||
| Operator CreateAllGatherOp(const std::string &group) { | |||
| OperatorName operator_name = ALL_GATHER; | |||
| ValuePtr attr0_value = MakeValue(group); // group | |||
| Attr attr0 = std::make_pair(GROUP, attr0_value); | |||
| OperatorAttrs operator_attrs; | |||
| operator_attrs.push_back(attr0); | |||
| OperatorParams operator_param; | |||
| OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); | |||
| Operator op = std::make_pair(operator_name, operator_arg); | |||
| MS_LOG(INFO) << "Create allgather op success, the group is " << group; | |||
| return op; | |||
| } | |||
| // use for get tensor slice | |||
| Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { | |||
| Shape tensor_map = tensor_layout.tensor_map().array(); | |||
| @@ -164,6 +164,10 @@ class OperatorInfo { | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; } | |||
| int32_t stage_id() const { return stage_id_; } | |||
| void set_opt_shard_flag(bool flag) { opt_shard_flag_ = flag; } | |||
| bool opt_shard_flag() { return opt_shard_flag_; } | |||
| Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group); | |||
| // Key for user data. | |||
| constexpr static char key[] = "OpInfo"; | |||
| @@ -180,7 +184,6 @@ class OperatorInfo { | |||
| Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape); | |||
| void SetDeviceListByStrategy(); | |||
| void SetRepeatedCalcDevMatrix(); | |||
| Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group); | |||
| Status CreateGroupByDim(size_t axis, std::vector<Group> *group); | |||
| Status InferAttrs(); | |||
| void ResetQueueMember(); | |||
| @@ -263,6 +266,7 @@ class OperatorInfo { | |||
| private: | |||
| OperatorCostPtr operator_cost_; | |||
| std::vector<TypePtr> outputs_type_; | |||
| bool opt_shard_flag_ = false; | |||
| }; | |||
| Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy); | |||
| @@ -270,6 +274,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap | |||
| Operator CreateVirtualDivOp(int32_t div_num); | |||
| Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); | |||
| Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group); | |||
| Operator CreateAllGatherOp(const std::string &group); | |||
| Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); | |||
| OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); | |||
| int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); | |||
| @@ -98,6 +98,7 @@ constexpr char BEGIN[] = "begin"; | |||
| constexpr char END[] = "end"; | |||
| constexpr char STRIDES[] = "strides"; | |||
| constexpr char GROUP[] = "group"; | |||
| constexpr char FUSION[] = "fusion"; | |||
| constexpr char AXIS[] = "axis"; | |||
| constexpr char OUTPUT_NUM[] = "output_num"; | |||
| constexpr char SPLIT_COUNT[] = "split_count"; | |||
| @@ -140,6 +141,7 @@ constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter"; | |||
| constexpr char FIELD_SIZE[] = "field_size"; | |||
| constexpr char OPTIMIZER_SUB_STRING[] = "optimizer"; | |||
| constexpr char DEVICE[] = "Device"; | |||
| constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather"; | |||
| // Operator | |||
| constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; | |||
| @@ -121,6 +121,7 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An | |||
| new_node->set_scope(scope); | |||
| node_input[0]->set_scope(scope); | |||
| manager->SetEdge(node, SizeToInt(index), new_node); | |||
| MS_LOG(INFO) << "Insert " << instance_name << " success"; | |||
| } | |||
| std::string CreateInstanceName(const CNodePtr &node, size_t index) { | |||
| @@ -924,7 +925,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo | |||
| MirrorOps mirror_ops = distribute_operator->mirror_ops(); | |||
| VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op(); | |||
| // insert mirror op | |||
| if (!mirror_ops.empty()) { | |||
| if (!mirror_ops.empty() && !distribute_operator->opt_shard_flag()) { | |||
| MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); | |||
| InsertMirrorOps(mirror_ops, node); | |||
| } | |||
| @@ -1263,6 +1264,37 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode | |||
| return std::make_pair(nullptr, 0); | |||
| } | |||
| void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator, | |||
| const CNodePtr &cnode, const AnfNodePtr ¶meter, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(distribute_operator); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| std::vector<Group> dev_group; | |||
| // create communication group for allgather operator | |||
| if (distribute_operator->CreateGroupByTensorMap(tensor_layout->origin_tensor_map().array(), &dev_group) == | |||
| Status::SUCCESS && | |||
| !dev_group.empty()) { | |||
| // set optimizer shard split flag to avoid inserting mirror_ops | |||
| distribute_operator->set_opt_shard_flag(true); | |||
| // insert allgather operator between shard parameter and cnode | |||
| Operator op = CreateAllGatherOp(dev_group[0].name()); | |||
| auto graph = cnode->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| InsertNode(op, cnode, index, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER); | |||
| // set communication group in tensor layout for checkpoint saving | |||
| tensor_layout->set_opt_shard_group(dev_group[0].name()); | |||
| // add fusion flag | |||
| auto allgather = cnode->input(index)->cast<CNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(allgather->input(0)); | |||
| auto attrs = prim->attrs(); | |||
| attrs["fusion"] = MakeValue(1); | |||
| prim->SetAttrs(attrs); | |||
| MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString(); | |||
| } else { | |||
| MS_LOG(ERROR) << "Parallel optimizer applied on " << parameter->ToString() << "failed!"; | |||
| } | |||
| } | |||
| void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, int> &res) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| AbstractBasePtr abstract = parameter->abstract(); | |||
| @@ -1280,7 +1312,22 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||
| << distribute_operator->inputs_tensor_info().size(); | |||
| } | |||
| TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)]; | |||
| Shape slice_shape = tensorinfo_in.slice_shape(); | |||
| TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer(); | |||
| Shape slice_shape = tensor_layout.slice_shape().array(); | |||
| if (enable_parallel_optimizer) { | |||
| if (!ParameterRequireGrad(parameter)) { | |||
| // only trainable parameters need parallel optimizer | |||
| MS_LOG(INFO) << "Parallel optimizer is no need for " << parameter->ToString(); | |||
| } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) { | |||
| // get a fully sharded tensor slice shape if the weight is repeated on devices | |||
| // and the first dimension can be divided evenly | |||
| // apply parallel optimizer on parameters | |||
| ApplyParallelOptOnParam(&tensor_layout, distribute_operator, cnode, parameter, IntToSize(res.second)); | |||
| slice_shape = tensor_layout.opt_shard_slice_shape(); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape " | |||
| << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name(); | |||
| std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape); | |||
| @@ -1290,7 +1337,6 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||
| MS_EXCEPTION_IF_NULL(cloned_abstract); | |||
| cloned_abstract->set_shape(parallel_shape); | |||
| parameter->set_abstract(cloned_abstract); | |||
| TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); | |||
| ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(parameter_ptr); | |||
| parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout)); | |||
| @@ -160,6 +160,9 @@ RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode); | |||
| std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node); | |||
| ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)); | |||
| void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator, | |||
| const CNodePtr &cnode, const AnfNodePtr ¶meter, size_t index); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -389,5 +389,39 @@ TensorLayout TensorLayout::SqueezeShape() const { | |||
| (void)out.Init(device_arrangement_, out_map, out_shape); | |||
| return out; | |||
| } | |||
| // Generate a fully sharded tensor slice shape for the parallel optimizer | |||
| Status TensorLayout::GenerateOptShardSliceShape() { | |||
| MS_LOG(INFO) << "layout for GetOptShardSliceShape is " << StandardToString(); | |||
| Shape dev_max = device_arrangement_.array(); | |||
| Shape tensor_map = tensor_map_.array(); | |||
| Shape repeated_dev; | |||
| for (size_t i = 0; i < dev_max.size(); i++) { | |||
| if (tensor_map_.GetIndexByValue(i) == MAP_NONE) { | |||
| repeated_dev.push_back(dev_max[dev_max.size() - 1 - i]); | |||
| dev_max[dev_max.size() - 1 - i] = 1; | |||
| } | |||
| } | |||
| if (repeated_dev.empty()) { | |||
| MS_LOG(INFO) << "Tensor is totally shard already."; | |||
| return Status::FAILED; | |||
| } | |||
| int64_t repeated_num = | |||
| std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()); | |||
| int64_t split_num; | |||
| if (tensor_map[0] == MAP_NONE) { | |||
| split_num = repeated_num; | |||
| } else { | |||
| split_num = dev_max[dev_max.size() - 1 - tensor_map[0]] * repeated_num; | |||
| } | |||
| if (tensor_shape_.array()[0] % split_num != 0) { | |||
| MS_LOG(INFO) << "Tensor could not be shard on the first dimension."; | |||
| return Status::FAILED; | |||
| } | |||
| Shape origin_slice_shape = slice_shape().array(); | |||
| origin_slice_shape[0] = tensor_shape_.array()[0] / split_num; | |||
| opt_shard_slice_shape_ = origin_slice_shape; | |||
| return Status::SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
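For readers following the arithmetic, here is a simplified Python model of the slice-shape computation above. It is an illustrative sketch, not the C++ implementation (MAP_NONE is assumed to be -1), and it reproduces the expectations of the GenerateOptShardSliceShape unit tests added later in this PR, e.g. device_arrangement [8, 4], tensor_map [-1, 0], tensor_shape [512, 1024] gives an optimizer-shard slice shape of [64, 256]:

```python
# Simplified Python model of GenerateOptShardSliceShape (MAP_NONE assumed to be -1).
def opt_shard_slice_shape(dev_arrangement, tensor_map, tensor_shape):
    dev_max = list(dev_arrangement)
    repeated_dev = []
    for i in range(len(dev_max)):
        if i not in tensor_map:                         # device dim with value i is unused -> weight repeated
            repeated_dev.append(dev_max[len(dev_max) - 1 - i])
            dev_max[len(dev_max) - 1 - i] = 1
    if not repeated_dev:
        return None                                     # already fully sharded -> Status::FAILED
    repeated_num = 1
    for r in repeated_dev:
        repeated_num *= r
    if tensor_map[0] == -1:
        split_num = repeated_num
    else:
        split_num = dev_max[len(dev_max) - 1 - tensor_map[0]] * repeated_num
    if tensor_shape[0] % split_num != 0:
        return None                                     # first dim not divisible -> Status::FAILED
    # regular slice shape from the original layout, then fully shard dim 0
    slice_shape = [dim if m == -1 else dim // dev_arrangement[len(dev_arrangement) - 1 - m]
                   for dim, m in zip(tensor_shape, tensor_map)]
    slice_shape[0] = tensor_shape[0] // split_num
    return slice_shape

# Reproduces the expectations of the GenerateOptShardSliceShape tests added below.
assert opt_shard_slice_shape([8, 4], [-1, 0], [512, 1024]) == [64, 256]
assert opt_shard_slice_shape([4, 4, 2], [1, 0], [512, 1024]) == [32, 512]
assert opt_shard_slice_shape([8, 4], [1, 0], [512, 1024]) is None      # already fully sharded
assert opt_shard_slice_shape([4, 4, 2], [1, 0], [20, 1024]) is None    # 20 not divisible by 16
```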
| @@ -21,7 +21,9 @@ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <functional> | |||
| #include "frontend/parallel/device_manager.h" | |||
| #include "frontend/parallel/status.h" | |||
| #include "frontend/parallel/tensor_layout/arrangement.h" | |||
| @@ -86,6 +88,14 @@ class TensorLayout { | |||
| TensorLayout SqueezeShape() const; | |||
| Status GenerateOptShardSliceShape(); | |||
| Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; } | |||
| void set_opt_shard_group(std::string name) { opt_shard_group_ = std::move(name); } | |||
| std::string opt_shard_group() { return opt_shard_group_; } | |||
| // Key for user data. | |||
| constexpr static char key[] = "TLayout"; | |||
| @@ -109,6 +119,8 @@ class TensorLayout { | |||
| bool skip_redistribution_ = false; | |||
| int32_t field_size_ = 0; | |||
| bool uniform_split_ = true; | |||
| Shape opt_shard_slice_shape_; | |||
| std::string opt_shard_group_ = ""; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -396,8 +396,8 @@ class Parameter(MetaTensor): | |||
| if self.inited_param is not None: | |||
| return self.inited_param | |||
| if layout is not None: | |||
| if not isinstance(layout, list): | |||
| raise TypeError("The layout should be list! layout is {}.".format(layout)) | |||
| if not isinstance(layout, tuple): | |||
| raise TypeError("The layout should be tuple! layout is {}.".format(layout)) | |||
| if len(layout) < 3: | |||
| raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) | |||
| slice_index = int(_get_slice_index(layout[0], layout[1])) | |||
| @@ -334,7 +334,7 @@ def _context(): | |||
| all_reduce_fusion_config=list, pipeline_stages=int) | |||
| def set_auto_parallel_context(**kwargs): | |||
| r""" | |||
| Set auto parallel context. | |||
| Set auto parallel context, which is valid only for Ascend and GPU target. | |||
| Auto parallel context should be configured before the initialization of your network. | |||
| @@ -348,17 +348,17 @@ def set_auto_parallel_context(**kwargs): | |||
| Some configurations are parallel mode specific, see the below table for details: | |||
| =========================== =========================== ================= | |||
| Common AUTO_PARALLEL DATA_PARALLEL | |||
| =========================== =========================== ================= | |||
| device_num gradient_fp32_sync enable_parallel_optimizer | |||
| =========================== =========================== | |||
| Common AUTO_PARALLEL | |||
| =========================== =========================== | |||
| device_num gradient_fp32_sync | |||
| global_rank loss_repeated_mean | |||
| gradients_mean auto_parallel_search_mode | |||
| parallel_mode strategy_ckpt_load_file | |||
| all_reduce_fusion_config strategy_ckpt_save_file | |||
| \ full_batch | |||
| \ pipeline_stages | |||
| =========================== =========================== ================= | |||
| enable_parallel_optimizer full_batch | |||
| \ pipeline_stages | |||
| =========================== =========================== | |||
| Args: | |||
| device_num (int): Available device number, the value must be in [1, 4096]. Default: 1. | |||
| @@ -387,7 +387,7 @@ def set_auto_parallel_context(**kwargs): | |||
| - recursive_programming: Recursive programming search mode. | |||
| - dynamic_programming: Dynamic programming search mode. | |||
| parameter_broadcast (bool): Whether to broadcast parameters before training. | |||
| parameter_broadcast (bool): A developing feature. Whether to broadcast parameters before training. | |||
| "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter | |||
| broadcast. Default: False. | |||
| strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' | |||
| @@ -395,9 +395,9 @@ def set_auto_parallel_context(**kwargs): | |||
| full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter | |||
| should be set with True. Default: False. | |||
| enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for | |||
| data parallel training in the benefit of time and memory saving. For now, | |||
| `Lamb` and `AdamWeightDecay` are supported in data parallel mode. No Default, if it is not set, | |||
| the fusion is closed. | |||
| data parallel training to save time and memory. For now, auto parallel mode | |||
| supports all optimizers. Data parallel mode only supports `Lamb` and `AdamWeightDecay`. | |||
| Default: False. | |||
| all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM | |||
| and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed. | |||
| pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how | |||
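As a usage sketch, the flag is switched on through the auto parallel context before the network is built. The two configurations below mirror the tests added in this PR (data parallel with AdamWeightDecay/Lamb, and auto parallel where the new graph pass handles any optimizer such as Momentum):

```python
from mindspore import context

# Data parallel: the Python-side optimizer sharding path, limited to Lamb / AdamWeightDecay.
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2,
                                  enable_parallel_optimizer=True)

# Auto parallel: the new C++ pass (ApplyParallelOptOnParam) shards repeated weights,
# so any optimizer can be used.
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8,
                                  enable_parallel_optimizer=True)
```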
| @@ -148,15 +148,18 @@ class Optimizer(Cell): | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| self.param_length = len(self.parameters) | |||
| self.map_ = C.Map() | |||
| use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer") | |||
| self.use_parallel = use_parallel | |||
| if use_parallel: | |||
| if context.get_auto_parallel_context("enable_parallel_optimizer"): | |||
| if _get_parallel_mode() == ParallelMode.DATA_PARALLEL: | |||
| self.use_parallel = True | |||
| elif _get_parallel_mode() == ParallelMode.STAND_ALONE: | |||
| raise RuntimeError("Parallel optimizer is not supported in stand alone mode.") | |||
| else: | |||
| self.use_parallel = False | |||
| else: | |||
| self.use_parallel = False | |||
| if self.use_parallel: | |||
| if self.cls_name not in ["Lamb", "AdamWeightDecay"]: | |||
| raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) | |||
| if _get_parallel_mode() != ParallelMode.DATA_PARALLEL: | |||
| raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format | |||
| (_get_parallel_mode())) | |||
| self.dev_num = _get_device_num() | |||
| if self.dev_num > self.param_length: | |||
| raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is" | |||
| @@ -83,8 +83,10 @@ def get_bprop_broad_cast(self): | |||
| def get_bprop_all_gather(self): | |||
| """Generate bprop for AllGather""" | |||
| all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group) | |||
| fusion = self.get_attr_dict()["fusion"] | |||
| all_gather_grad.add_prim_attr("fusion", fusion) | |||
| if self.instance_name: | |||
| instance_name = "grad" + self.instance_name | |||
| instance_name = "grad_" + self.instance_name | |||
| all_gather_grad.set_prim_instance_name(instance_name) | |||
| def bprop(x, out, dout): | |||
| @@ -15,6 +15,7 @@ | |||
| """comm_ops""" | |||
| from mindspore.common import Tensor | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group | |||
| @@ -158,6 +159,7 @@ class AllGather(PrimitiveWithInfer): | |||
| validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) | |||
| self.add_prim_attr('rank_size', self.rank_size) | |||
| self.add_prim_attr('group', _get_group(group)) | |||
| self.add_prim_attr('fusion', 0) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name) | |||
| @@ -268,6 +270,7 @@ class ReduceScatter(PrimitiveWithInfer): | |||
| self.rank_size = get_group_size(_get_group(group)) | |||
| self.add_prim_attr('rank_size', self.rank_size) | |||
| self.add_prim_attr('group', _get_group(group)) | |||
| self.add_prim_attr('fusion', 0) | |||
| def infer_shape(self, x_shape): | |||
| if x_shape[0] % self.rank_size != 0: | |||
| @@ -526,4 +529,4 @@ class _GetTensorSlice(PrimitiveWithInfer): | |||
| from mindspore.parallel._tensor import _load_tensor | |||
| validator.check_value_type("dev_mat", dev_mat, [tuple], self.name) | |||
| validator.check_value_type("tensor_map", tensor_map, [tuple], self.name) | |||
| return _load_tensor(x, dev_mat, tensor_map) | |||
| return Tensor(_load_tensor(x, dev_mat, tensor_map)) | |||
| @@ -37,12 +37,34 @@ class AllGatherCell(Cell): | |||
| return x | |||
| def get_allgather_cell(): | |||
| class SaveOptShardCkptCell(Cell): | |||
| """ | |||
| Allgather cell, used in optimizer parallel scenario. | |||
| First gather the tensor back to its original layout within the specified device group. | |||
| Then gather the whole parameter slices from all devices. | |||
| Note: | |||
| This could be optimized later to use less communication. | |||
| """ | |||
| def __init__(self, group): | |||
| super(SaveOptShardCkptCell, self).__init__(auto_prefix=False) | |||
| self.allgather1 = AllGather(group) | |||
| self.allgather2 = AllGather() | |||
| def construct(self, x): | |||
| x = self.allgather1(x) | |||
| x = self.allgather2(x) | |||
| return x | |||
| def get_allgather_cell(group): | |||
| """Get AllGatherCell object.""" | |||
| global _allgather_cell | |||
| if not _allgather_cell: | |||
| if group: | |||
| _allgather_cell = SaveOptShardCkptCell(group) | |||
| else: | |||
| _allgather_cell = AllGatherCell() | |||
| return _allgather_cell | |||
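A hedged sketch of how the checkpoint-saving path picks a gather cell once an optimizer shard group is recorded in the layout; the function name is an assumption for illustration, and in practice the group string comes from layout[5] in _get_merged_param_data further down:

```python
from mindspore.parallel._cell_wrapper import get_allgather_cell

def merge_for_checkpoint(param_data, opt_shard_group):
    """Illustration only: needs an initialized distributed backend to actually run."""
    # Non-empty group -> SaveOptShardCkptCell (gather inside the optimizer shard group,
    # then across all devices); empty group -> the original AllGatherCell.
    allgather_net = get_allgather_cell(opt_shard_group)
    return allgather_net(param_data)
```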
| @@ -16,7 +16,7 @@ | |||
| import numpy as np | |||
| from mindspore.common.tensor import Tensor | |||
| from ..communication.management import get_rank | |||
| from ..communication.management import get_rank, get_group_size | |||
| def _get_tensor_strategy(dev_mat, tensor_map): | |||
| @@ -168,6 +168,7 @@ def _chunk_tensor_by_strategy(np_tensor, strategy): | |||
| raise ValueError("The length of np_tensor does not match the length of strategy!") | |||
| return _chunk_tensor(np_tensor, strategy, len(strategy)) | |||
| def _get_slice_index(dev_mat, tensor_map): | |||
| """ | |||
| Get the slice index for current slice. | |||
| @@ -184,6 +185,7 @@ def _get_slice_index(dev_mat, tensor_map): | |||
| tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) | |||
| return tensor_slice_index | |||
| def _load_tensor(tensor, dev_mat, tensor_map): | |||
| """ | |||
| Get the tensor slice of the local device by the device matrix and the tensor map | |||
| @@ -194,7 +196,7 @@ def _load_tensor(tensor, dev_mat, tensor_map): | |||
| tensor_map (list): The split strategy of tensor. | |||
| Returns: | |||
| Tensor, the sliced tensor. | |||
| numpy.array, the sliced array. | |||
| Examples: | |||
| >>> tensor = Tensor(np.ones([32, 32])) | |||
| @@ -208,8 +210,7 @@ def _load_tensor(tensor, dev_mat, tensor_map): | |||
| np_tensor = tensor.asnumpy() | |||
| np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy) | |||
| np_tensor_slice = np_tensor_list[int(tensor_slice_index)] | |||
| tensor_slice = Tensor(np_tensor_slice) | |||
| return tensor_slice | |||
| return np_tensor_slice | |||
| def _load_tensor_by_layout(tensor, layout): | |||
| @@ -227,18 +228,25 @@ def _load_tensor_by_layout(tensor, layout): | |||
| TypeError: If layout is not tuple. | |||
| ValueError: If the length of layout is less than 6. | |||
| """ | |||
| if not isinstance(layout, list): | |||
| raise TypeError("The layout should be list! layout is {}".format(layout)) | |||
| if len(layout) < 5: | |||
| if not isinstance(layout, tuple): | |||
| raise TypeError("The layout should be tuple! layout is {}".format(layout)) | |||
| if len(layout) < 6: | |||
| raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout)) | |||
| dev_mat = layout[0] | |||
| tensor_map = layout[1] | |||
| uniform_split = layout[4] | |||
| if uniform_split[0] == 0: | |||
| group = layout[5] | |||
| if uniform_split == 0: | |||
| raise RuntimeError("The load tensor only support uniform split now") | |||
| if tensor.size() == 1: | |||
| return tensor | |||
| return _load_tensor(tensor, dev_mat, tensor_map) | |||
| tensor_slice = _load_tensor(tensor, dev_mat, tensor_map) | |||
| if group: | |||
| # get a totally shard tensor slice for parallel optimizer | |||
| rank = get_rank(group) | |||
| size = get_group_size(group) | |||
| tensor_slice = np.split(tensor_slice, size)[rank] | |||
| return Tensor(tensor_slice) | |||
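The extra step for optimizer-sharded parameters can be seen in isolation with plain numpy; here rank and size are stand-ins for get_rank(group) and get_group_size(group), chosen only for illustration:

```python
import numpy as np

# After the regular slice is taken, it is split again along dim 0 inside the optimizer
# shard group and each device keeps only its own piece.
tensor_slice = np.arange(64).reshape(8, 8)   # regular slice for this device
size, rank = 4, 1                            # group size and this device's rank in the group
opt_slice = np.split(tensor_slice, size)[rank]
print(opt_slice.shape)                       # (2, 8): first dimension divided by the group size
```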
| def _reshape_param_data(param_data, dev_mat, tensor_map): | |||
| @@ -294,6 +302,7 @@ def _reshape_param_data(param_data, dev_mat, tensor_map): | |||
| return Tensor(tensor_slices_new[0]) | |||
| def _reshape_param_data_with_weight(param_data, dev_mat, field_size): | |||
| """ | |||
| Combine param slice by the device matrix, used in model parallel scenario. | |||
| @@ -318,10 +327,10 @@ def _reshape_param_data_with_weight(param_data, dev_mat, field_size): | |||
| tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0) | |||
| tensor_slices_col = [] | |||
| for i in range(len(tensor_slices[0][0])): | |||
| tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size[0], -1) | |||
| tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1) | |||
| for j in range(1, device_count): | |||
| tensor_slices_new = np.concatenate((tensor_slices_new,\ | |||
| np.array(tensor_slices[j][:, i]).reshape(field_size[0], -1)), axis=1) | |||
| np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1) | |||
| tensor_slices_col.append(tensor_slices_new) | |||
| new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1) | |||
| for i in range(1, len(tensor_slices_col)): | |||
| @@ -398,7 +398,7 @@ def _get_merged_param_data(net, param_name, param_data): | |||
| Tensor, the combined tensor which with the whole data value. | |||
| """ | |||
| layout = net.parameter_layout_dict[param_name] | |||
| if len(layout) < 5: | |||
| if len(layout) < 6: | |||
| logger.info("layout dict does not contain the key %s", param_name) | |||
| return param_data | |||
| @@ -406,17 +406,19 @@ def _get_merged_param_data(net, param_name, param_data): | |||
| tensor_map = layout[1] | |||
| field_size = layout[3] | |||
| uniform_split = layout[4] | |||
| if uniform_split[0] == 0: | |||
| opt_shard_group = layout[5] | |||
| if uniform_split == 0: | |||
| raise RuntimeError("Save checkpoint only support uniform split tensor now.") | |||
| from mindspore.parallel._cell_wrapper import get_allgather_cell | |||
| from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight | |||
| # while any dim is not equal to -1, means param is splited and needs to be merged | |||
| # if any dim is not equal to -1, the param is split and needs to be merged | |||
| # pipeline parallel need to be supported here later | |||
| for dim in tensor_map: | |||
| if dim != -1: | |||
| allgather_net = get_allgather_cell() | |||
| if dim != -1 or opt_shard_group: | |||
| allgather_net = get_allgather_cell(opt_shard_group) | |||
| param_data = allgather_net(param_data) | |||
| if field_size[0]: | |||
| if field_size: | |||
| return _reshape_param_data_with_weight(param_data, dev_mat, field_size) | |||
| return _reshape_param_data(param_data, dev_mat, tensor_map) | |||
| @@ -35,11 +35,11 @@ class TestRedistributionLayoutTransfer : public UT::Common { | |||
| }; | |||
| void RedistributionLayoutTransferTestFunction( | |||
| const DeviceArrangement& in_device_arrangement_shape, const TensorMap& in_tensor_map_shape, | |||
| const TensorShape& tensor_shape_shape, const DeviceArrangement& out_device_arrangement_shape, | |||
| const TensorMap& out_tensor_map_shape, DeviceArrangement* unified_device_arrangement_shape, | |||
| TensorMap* unified_in_tensor_map_shape, TensorMap* unified_out_tensor_map_shape, | |||
| TensorMap* unified_tensor_shape_shape) { | |||
| const DeviceArrangement &in_device_arrangement_shape, const TensorMap &in_tensor_map_shape, | |||
| const TensorShape &tensor_shape_shape, const DeviceArrangement &out_device_arrangement_shape, | |||
| const TensorMap &out_tensor_map_shape, DeviceArrangement *unified_device_arrangement_shape, | |||
| TensorMap *unified_in_tensor_map_shape, TensorMap *unified_out_tensor_map_shape, | |||
| TensorMap *unified_tensor_shape_shape) { | |||
| Arrangement in_device_arrangement; | |||
| Status status = in_device_arrangement.Init(in_device_arrangement_shape); | |||
| ASSERT_EQ(Status::SUCCESS, status); | |||
| @@ -86,13 +86,13 @@ void RedistributionLayoutTransferTestFunction( | |||
| *unified_tensor_shape_shape = unified_in_tensor_shape.array(); | |||
| } | |||
| void RedistributionLayoutCheck(const DeviceArrangement& in_device_arrangement, const TensorMap& in_tensor_map, | |||
| const TensorShape& tensor_shape, const DeviceArrangement& out_device_arrangement, | |||
| const TensorMap& out_tensor_map, | |||
| const DeviceArrangement& unified_device_arrangement_expect, | |||
| const TensorMap& unified_in_tensor_map_expect, | |||
| const TensorMap& unified_out_tensor_map_expect, | |||
| const TensorMap& unified_tensor_shape_expect) { | |||
| void RedistributionLayoutCheck(const DeviceArrangement &in_device_arrangement, const TensorMap &in_tensor_map, | |||
| const TensorShape &tensor_shape, const DeviceArrangement &out_device_arrangement, | |||
| const TensorMap &out_tensor_map, | |||
| const DeviceArrangement &unified_device_arrangement_expect, | |||
| const TensorMap &unified_in_tensor_map_expect, | |||
| const TensorMap &unified_out_tensor_map_expect, | |||
| const TensorMap &unified_tensor_shape_expect) { | |||
| DeviceArrangement unified_device_arrangement; | |||
| TensorMap unified_in_tensor_map; | |||
| TensorMap unified_out_tensor_map; | |||
| @@ -224,9 +224,9 @@ TEST_F(TestRedistributionLayoutTransfer, RedistributionLayoutTransfer5) { | |||
| unified_out_tensor_map_expect, unified_tensor_shape_expect); | |||
| } | |||
| void ValidRedistributionLayoutCheck(const DeviceArrangement& in_device_arrangement, const TensorMap& in_tensor_map, | |||
| const TensorShape& tensor_shape, const DeviceArrangement& out_device_arrangement, | |||
| const TensorMap& out_tensor_map) { | |||
| void ValidRedistributionLayoutCheck(const DeviceArrangement &in_device_arrangement, const TensorMap &in_tensor_map, | |||
| const TensorShape &tensor_shape, const DeviceArrangement &out_device_arrangement, | |||
| const TensorMap &out_tensor_map) { | |||
| DeviceArrangement unified_device_arrangement; | |||
| TensorMap unified_in_tensor_map; | |||
| TensorMap unified_out_tensor_map; | |||
| @@ -242,8 +242,8 @@ void ValidRedistributionLayoutCheck(const DeviceArrangement& in_device_arrangeme | |||
| unified_out_tensor_map, unified_tensor_shape); | |||
| } | |||
| void ValidRedistributionLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size, | |||
| int64_t max_device_dim, int64_t max_shape_dim) { | |||
| void ValidRedistributionLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, | |||
| int64_t max_shape_dim) { | |||
| std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> layout_list; | |||
| GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim, | |||
| &layout_list); | |||
| @@ -49,7 +49,7 @@ class TestRedistributionOperatorInfer : public UT::Common { | |||
| }; | |||
| // check if in_tensor_map could be changed to out_tensor_map with operator_list | |||
| void InferOperatorCheck(Shape in_tensor_map, const Shape& out_tensor_map, const OperatorList& operator_list) { | |||
| void InferOperatorCheck(Shape in_tensor_map, const Shape &out_tensor_map, const OperatorList &operator_list) { | |||
| for (auto op_cost : operator_list) { | |||
| OperatorR op = op_cost.first; | |||
| Args args = op.second; | |||
| @@ -35,11 +35,11 @@ class TestReshapeLayoutTransfer : public UT::Common { | |||
| virtual void TearDown() {} | |||
| }; | |||
| void InferUnifiedLayout(const DeviceArrangement& device_arrangement_shape, const TensorMap& in_tensor_map_shape, | |||
| const TensorShape& in_tensor_shape_shape, const TensorMap& out_tensor_map_shape, | |||
| const TensorShape& out_tensor_shape_shape, DeviceArrangement* unified_device_arrangement_shape, | |||
| TensorMap* unified_in_tensor_map_shape, TensorMap* unified_out_tensor_map_shape, | |||
| TensorMap* unified_tensor_shape_shape) { | |||
| void InferUnifiedLayout(const DeviceArrangement &device_arrangement_shape, const TensorMap &in_tensor_map_shape, | |||
| const TensorShape &in_tensor_shape_shape, const TensorMap &out_tensor_map_shape, | |||
| const TensorShape &out_tensor_shape_shape, DeviceArrangement *unified_device_arrangement_shape, | |||
| TensorMap *unified_in_tensor_map_shape, TensorMap *unified_out_tensor_map_shape, | |||
| TensorMap *unified_tensor_shape_shape) { | |||
| Arrangement device_arrangement; | |||
| Status status = device_arrangement.Init(device_arrangement_shape); | |||
| ASSERT_EQ(Status::SUCCESS, status); | |||
| @@ -85,13 +85,13 @@ void InferUnifiedLayout(const DeviceArrangement& device_arrangement_shape, const | |||
| *unified_out_tensor_map_shape = unified_out_tensor_map.array(); | |||
| } | |||
| void InferUnifiedLayoutCheck(const DeviceArrangement& device_arrangement, const TensorMap& in_tensor_map, | |||
| const TensorShape& in_tensor_shape, const TensorMap& out_tensor_map, | |||
| const TensorShape& out_tensor_shape, | |||
| const DeviceArrangement& unified_device_arrangement_expect, | |||
| const TensorMap& unified_in_tensor_map_expect, | |||
| const TensorMap& unified_out_tensor_map_expect, | |||
| const TensorMap& unified_tensor_shape_expect) { | |||
| void InferUnifiedLayoutCheck(const DeviceArrangement &device_arrangement, const TensorMap &in_tensor_map, | |||
| const TensorShape &in_tensor_shape, const TensorMap &out_tensor_map, | |||
| const TensorShape &out_tensor_shape, | |||
| const DeviceArrangement &unified_device_arrangement_expect, | |||
| const TensorMap &unified_in_tensor_map_expect, | |||
| const TensorMap &unified_out_tensor_map_expect, | |||
| const TensorMap &unified_tensor_shape_expect) { | |||
| DeviceArrangement unified_device_arrangement; | |||
| TensorMap unified_in_tensor_map; | |||
| TensorMap unified_out_tensor_map; | |||
| @@ -109,9 +109,9 @@ void InferUnifiedLayoutCheck(const DeviceArrangement& device_arrangement, const | |||
| ASSERT_EQ(unified_tensor_shape_expect, unified_tensor_shape); | |||
| } | |||
| void ValidUnifiedLayoutCheck(const DeviceArrangement& device_arrangement, const TensorMap& in_tensor_map, | |||
| const TensorShape& in_tensor_shape, const TensorMap& out_tensor_map, | |||
| const TensorShape& out_tensor_shape) { | |||
| void ValidUnifiedLayoutCheck(const DeviceArrangement &device_arrangement, const TensorMap &in_tensor_map, | |||
| const TensorShape &in_tensor_shape, const TensorMap &out_tensor_map, | |||
| const TensorShape &out_tensor_shape) { | |||
| DeviceArrangement unified_device_arrangement; | |||
| TensorMap unified_in_tensor_map; | |||
| TensorMap unified_out_tensor_map; | |||
| @@ -257,8 +257,8 @@ TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheck11) { | |||
| ValidUnifiedLayoutCheck(device_arrangement, in_tensor_map, in_tensor_shape, out_tensor_map, out_tensor_shape); | |||
| } | |||
| void ValidInferUnifiedLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size, | |||
| int64_t max_device_dim, int64_t max_shape_dim) { | |||
| void ValidInferUnifiedLayoutCheckAll(int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, | |||
| int64_t max_shape_dim) { | |||
| std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> layout_list; | |||
| GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim, | |||
| &layout_list); | |||
| @@ -297,7 +297,7 @@ TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheckAll) { | |||
| ValidInferUnifiedLayoutCheckAll(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim); | |||
| tensor_pow_size++; | |||
| } | |||
| device_pow_size++; | |||
| device_pow_size++; | |||
| } | |||
| } | |||
| @@ -32,12 +32,12 @@ class TestTensorLayout : public UT::Common { | |||
| virtual void TearDown() {} | |||
| }; | |||
| void ReshapeExpandDeviceArrangementTestFunction(const DeviceArrangement& in_device_arrangement_shape, | |||
| const TensorMap& in_tensor_map_shape, | |||
| const TensorShape& in_tensor_shape_shape, | |||
| const DeviceArrangement& out_device_arrangement_shape, | |||
| const TensorMap& out_tensor_map_shape, | |||
| const TensorShape& out_tensor_shape_shape) { | |||
| void ReshapeExpandDeviceArrangementTestFunction(const DeviceArrangement &in_device_arrangement_shape, | |||
| const TensorMap &in_tensor_map_shape, | |||
| const TensorShape &in_tensor_shape_shape, | |||
| const DeviceArrangement &out_device_arrangement_shape, | |||
| const TensorMap &out_tensor_map_shape, | |||
| const TensorShape &out_tensor_shape_shape) { | |||
| Arrangement device_arrangement; | |||
| Status status = device_arrangement.Init(in_device_arrangement_shape); | |||
| ASSERT_EQ(Status::SUCCESS, status); | |||
| @@ -154,12 +154,10 @@ TEST_F(TestTensorLayout, ReshapeExpandDeviceArrangement5) { | |||
| tensor_map_expect, tensor_shape_expect); | |||
| } | |||
| void ExpandTensorShapeTestFunction(const DeviceArrangement& in_device_arrangement_shape, | |||
| const TensorMap& in_tensor_map_shape, | |||
| const TensorShape& in_tensor_shape_shape, | |||
| const DeviceArrangement& out_device_arrangement_shape, | |||
| const TensorMap& out_tensor_map_shape, | |||
| const TensorShape& out_tensor_shape_shape) { | |||
| void ExpandTensorShapeTestFunction(const DeviceArrangement &in_device_arrangement_shape, | |||
| const TensorMap &in_tensor_map_shape, const TensorShape &in_tensor_shape_shape, | |||
| const DeviceArrangement &out_device_arrangement_shape, | |||
| const TensorMap &out_tensor_map_shape, const TensorShape &out_tensor_shape_shape) { | |||
| Arrangement device_arrangement; | |||
| Status status = device_arrangement.Init(in_device_arrangement_shape); | |||
| ASSERT_EQ(Status::SUCCESS, status); | |||
| @@ -251,12 +249,12 @@ TEST_F(TestTensorLayout, UpdateTensorMap) { | |||
| ASSERT_EQ(in_tensor_map, new_tensor_map); | |||
| } | |||
| void RemoveElementEqualToOneInDeviceArrangementTestFunction(const DeviceArrangement& in_device_arrangement_shape, | |||
| const TensorMap& in_tensor_map_shape, | |||
| const TensorShape& in_tensor_shape_shape, | |||
| const DeviceArrangement& out_device_arrangement_shape, | |||
| const TensorMap& out_tensor_map_shape, | |||
| const TensorShape& out_tensor_shape_shape) { | |||
| void RemoveElementEqualToOneInDeviceArrangementTestFunction(const DeviceArrangement &in_device_arrangement_shape, | |||
| const TensorMap &in_tensor_map_shape, | |||
| const TensorShape &in_tensor_shape_shape, | |||
| const DeviceArrangement &out_device_arrangement_shape, | |||
| const TensorMap &out_tensor_map_shape, | |||
| const TensorShape &out_tensor_shape_shape) { | |||
| Arrangement device_arrangement; | |||
| Status status = device_arrangement.Init(in_device_arrangement_shape); | |||
| ASSERT_EQ(Status::SUCCESS, status); | |||
| @@ -310,15 +308,82 @@ TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement3) { | |||
| device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new); | |||
| } | |||
| TEST_F(TestTensorLayout, RemoveElementEqualToOneInDeviceArrangement4) { | |||
| DeviceArrangement device_arrangement = {1, 1, 1}; | |||
| TensorMap tensor_map = {2, 1}; | |||
| TensorShape tensor_shape = {128, 4096}; | |||
| DeviceArrangement device_arrangement_expect = {}; | |||
| TensorMap tensor_map_expect = {-1, -1}; | |||
| TensorShape tensor_shape_new = {128, 4096}; | |||
| RemoveElementEqualToOneInDeviceArrangementTestFunction( | |||
| device_arrangement, tensor_map, tensor_shape, device_arrangement_expect, tensor_map_expect, tensor_shape_new); | |||
| /* | |||
| * example: | |||
| * device_arrangement = [8, 4], | |||
| * tensor_map = [1, 0], | |||
| * tensor_shape = [512, 1024], | |||
| */ | |||
| TEST_F(TestTensorLayout, GenerateOptShardSliceShape1) { | |||
| Arrangement device_arrangement; | |||
| device_arrangement.Init({8, 4}); | |||
| Map tensor_map; | |||
| tensor_map.Init({1, 0}); | |||
| Arrangement tensor_shape; | |||
| tensor_shape.Init({512, 1024}); | |||
| TensorLayout tensor_layout; | |||
| tensor_layout.Init(device_arrangement, tensor_map, tensor_shape); | |||
| ASSERT_EQ(Status::FAILED, tensor_layout.GenerateOptShardSliceShape()); | |||
| } | |||
| /* | |||
| * example: | |||
| * device_arrangement = [8, 4], | |||
| * tensor_map = [-1, 0], | |||
| * tensor_shape = [512, 1024], | |||
| */ | |||
| TEST_F(TestTensorLayout, GenerateOptShardSliceShape2) { | |||
| Arrangement device_arrangement; | |||
| device_arrangement.Init({8, 4}); | |||
| Map tensor_map; | |||
| tensor_map.Init({-1, 0}); | |||
| Arrangement tensor_shape; | |||
| tensor_shape.Init({512, 1024}); | |||
| TensorLayout tensor_layout; | |||
| tensor_layout.Init(device_arrangement, tensor_map, tensor_shape); | |||
| ASSERT_EQ(Status::SUCCESS, tensor_layout.GenerateOptShardSliceShape()); | |||
| Shape slice_shape_expect = {64, 256}; | |||
| ASSERT_EQ(tensor_layout.opt_shard_slice_shape(), slice_shape_expect); | |||
| } | |||
| /* | |||
| * example: | |||
| * device_arrangement = [4, 4, 2], | |||
| * tensor_map = [1, 0], | |||
| * tensor_shape = [512, 1024], | |||
| */ | |||
| TEST_F(TestTensorLayout, GenerateOptShardSliceShape3) { | |||
| Arrangement device_arrangement; | |||
| device_arrangement.Init({4, 4, 2}); | |||
| Map tensor_map; | |||
| tensor_map.Init({1, 0}); | |||
| Arrangement tensor_shape; | |||
| tensor_shape.Init({512, 1024}); | |||
| TensorLayout tensor_layout; | |||
| tensor_layout.Init(device_arrangement, tensor_map, tensor_shape); | |||
| ASSERT_EQ(Status::SUCCESS, tensor_layout.GenerateOptShardSliceShape()); | |||
| Shape slice_shape_expect = {32, 512}; | |||
| ASSERT_EQ(tensor_layout.opt_shard_slice_shape(), slice_shape_expect); | |||
| } | |||
| /* | |||
| * example: | |||
| * device_arrangement = [4, 4, 2], | |||
| * tensor_map = [1, 0], | |||
| * tensor_shape = [20, 1024], | |||
| */ | |||
| TEST_F(TestTensorLayout, GenerateOptShardSliceShape4) { | |||
| Arrangement device_arrangement; | |||
| device_arrangement.Init({4, 4, 2}); | |||
| Map tensor_map; | |||
| tensor_map.Init({1, 0}); | |||
| Arrangement tensor_shape; | |||
| tensor_shape.Init({20, 1024}); | |||
| TensorLayout tensor_layout; | |||
| tensor_layout.Init(device_arrangement, tensor_map, tensor_shape); | |||
| ASSERT_EQ(Status::FAILED, tensor_layout.GenerateOptShardSliceShape()); | |||
| } | |||
| } // namespace parallel | |||
| @@ -28,7 +28,7 @@ using std::pow; | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| std::vector<Shape> combine(const Shape& in, int64_t target) { | |||
| std::vector<Shape> combine(const Shape &in, int64_t target) { | |||
| std::vector<Shape> output; | |||
| for (int64_t i = 0; i < pow(2, in.size()); i++) { | |||
| size_t temp = 0; | |||
| @@ -54,7 +54,7 @@ std::vector<Shape> combine(const Shape& in, int64_t target) { | |||
| return output; | |||
| } | |||
| void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<Shape>* out) { | |||
| void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<Shape> *out) { | |||
| out->clear(); | |||
| Shape in; | |||
| for (int64_t i = 1; i < pow_size; i++) { | |||
| @@ -80,7 +80,7 @@ void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<S | |||
| return; | |||
| } | |||
| void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape>* out) { | |||
| void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape> *out) { | |||
| out->clear(); | |||
| for (int64_t dim = 1; dim <= pow_size; dim++) { | |||
| std::vector<Shape> combine_result; | |||
| @@ -92,7 +92,7 @@ void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape>* out) { | |||
| return; | |||
| } | |||
| TensorMap GenerateTensorMap(const int64_t& map_size, const Shape& pos_index, const Shape& pos_value) { | |||
| TensorMap GenerateTensorMap(const int64_t &map_size, const Shape &pos_index, const Shape &pos_value) { | |||
| TensorMap tensor_map(map_size, -1); | |||
| for (size_t i = 0; i < pos_index.size() && i < pos_value.size(); i++) { | |||
| if (pos_index[i] >= map_size) { | |||
| @@ -103,8 +103,8 @@ TensorMap GenerateTensorMap(const int64_t& map_size, const Shape& pos_index, con | |||
| return tensor_map; | |||
| } | |||
| void GenerateValidTensorMap(const DeviceArrangement& device_arrangement, const TensorShape& tensor_shape, | |||
| std::vector<TensorMap>* tensor_map_list) { | |||
| void GenerateValidTensorMap(const DeviceArrangement &device_arrangement, const TensorShape &tensor_shape, | |||
| std::vector<TensorMap> *tensor_map_list) { | |||
| tensor_map_list->clear(); | |||
| int64_t device_size = device_arrangement.size(); | |||
| int64_t shape_size = tensor_shape.size(); | |||
| @@ -149,9 +149,8 @@ void GenerateValidTensorMap(const DeviceArrangement& device_arrangement, const T | |||
| } | |||
| void GenerateValidLayoutByDeviceSizeAndTensorSize( | |||
| int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, | |||
| int64_t max_shape_dim, | |||
| std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>>* layout_list) { | |||
| int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, int64_t max_shape_dim, | |||
| std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> *layout_list) { | |||
| layout_list->clear(); | |||
| std::vector<DeviceArrangement> device_arrangement_list; | |||
| GenerateValidShapeBySize(device_pow_size, &device_arrangement_list); | |||
| @@ -174,8 +173,8 @@ void GenerateValidLayoutByDeviceSizeAndTensorSize( | |||
| return; | |||
| } | |||
| bool CheckLayoutValid(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map, | |||
| const TensorShape& tensor_shape) { | |||
| bool CheckLayoutValid(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map, | |||
| const TensorShape &tensor_shape) { | |||
| bool flag = false; | |||
| if ((tensor_map.size() - ComputeNoneNumber(tensor_map)) > device_arrangement.size()) { | |||
| return flag; | |||
| @@ -186,7 +185,7 @@ bool CheckLayoutValid(const DeviceArrangement& device_arrangement, const TensorM | |||
| return true; | |||
| } | |||
| size_t ComputeNoneNumber(const TensorMap& tensor_map) { | |||
| size_t ComputeNoneNumber(const TensorMap &tensor_map) { | |||
| size_t num = 0; | |||
| for (size_t i = 0; i < tensor_map.size(); i++) { | |||
| if (tensor_map[i] == -1) { | |||
| @@ -196,8 +195,8 @@ size_t ComputeNoneNumber(const TensorMap& tensor_map) { | |||
| return num; | |||
| } | |||
| bool ShapeIsDividedByDevice(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map, | |||
| const TensorShape& tensor_shape) { | |||
| bool ShapeIsDividedByDevice(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map, | |||
| const TensorShape &tensor_shape) { | |||
| bool flag = false; | |||
| for (uint32_t i = 0; i < tensor_map.size() && i < tensor_shape.size(); i++) { | |||
| if (tensor_map[i] == -1) { | |||
| @@ -211,7 +210,7 @@ bool ShapeIsDividedByDevice(const DeviceArrangement& device_arrangement, const T | |||
| return true; | |||
| } | |||
| bool IsExpended(const Shape& in1, const Shape& in2) { | |||
| bool IsExpended(const Shape &in1, const Shape &in2) { | |||
| int64_t size = 1; | |||
| uint32_t ind = 0; | |||
| for (uint32_t i = 0; i < in1.size(); i++) { | |||
| @@ -234,9 +233,9 @@ bool IsExpended(const Shape& in1, const Shape& in2) { | |||
| return true; | |||
| } | |||
| void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement& device_arrangement, | |||
| const TensorMap& tensor_map, const TensorShape& tensor_shape, | |||
| std::map<int64_t, int64_t>* accum_device_to_accum_shape_map) { | |||
| void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map, | |||
| const TensorShape &tensor_shape, | |||
| std::map<int64_t, int64_t> *accum_device_to_accum_shape_map) { | |||
| accum_device_to_accum_shape_map->clear(); | |||
| std::vector<int64_t> shape_accum_reverse; | |||
| Status status = ShapeToAccumulateProductReverse(tensor_shape, &shape_accum_reverse); | |||
| @@ -263,12 +262,10 @@ void IsLinearValue(int64_t small, int64_t big, int64_t small_value, int64_t big_ | |||
| ASSERT_EQ(middle_value, value); | |||
| } | |||
| void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement, | |||
| const TensorMap& in_tensor_map, | |||
| const TensorShape& in_tensor_shape, | |||
| const DeviceArrangement& out_device_arrangement, | |||
| const TensorMap& out_tensor_map, | |||
| const TensorShape& out_tensor_shape) { | |||
| void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement &in_device_arrangement, | |||
| const TensorMap &in_tensor_map, const TensorShape &in_tensor_shape, | |||
| const DeviceArrangement &out_device_arrangement, | |||
| const TensorMap &out_tensor_map, const TensorShape &out_tensor_shape) { | |||
| bool is_expended = IsExpended(out_device_arrangement, in_device_arrangement); | |||
| ASSERT_EQ(true, is_expended); | |||
| is_expended = IsExpended(out_tensor_shape, in_tensor_shape); | |||
| @@ -317,10 +314,9 @@ void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement& in_device_arr | |||
| } | |||
| } | |||
| void ValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement, | |||
| const TensorMap& in_tensor_map, const TensorShape& in_tensor_shape, | |||
| const DeviceArrangement& out_device_arrangement, | |||
| const TensorMap& out_tensor_map, const TensorShape& out_tensor_shape) { | |||
| void ValidLayoutChangeCheck(const DeviceArrangement &in_device_arrangement, const TensorMap &in_tensor_map, | |||
| const TensorShape &in_tensor_shape, const DeviceArrangement &out_device_arrangement, | |||
| const TensorMap &out_tensor_map, const TensorShape &out_tensor_shape) { | |||
| LayoutTransferValidLayoutChangeCheck(in_device_arrangement, in_tensor_map, in_tensor_shape, out_device_arrangement, | |||
| out_tensor_map, out_tensor_shape); | |||
| } | |||
| @@ -26,45 +26,41 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| std::vector<Shape> combine(const Shape& in, int64_t target); | |||
| std::vector<Shape> combine(const Shape &in, int64_t target); | |||
| void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<Shape>* out); | |||
| void GenerateValidShapeBySizeAndDim(int64_t pow_size, int64_t dim, std::vector<Shape> *out); | |||
| void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape>* out); | |||
| void GenerateValidShapeBySize(int64_t pow_size, std::vector<Shape> *out); | |||
| TensorMap GenerateTensorMap(const int64_t& map_size, const Shape& pos_index, const Shape& pos_value); | |||
| TensorMap GenerateTensorMap(const int64_t &map_size, const Shape &pos_index, const Shape &pos_value); | |||
| void GenerateValidTensorMap(const DeviceArrangement& device_arrangement, const TensorMap& tensor_shape, | |||
| std::vector<TensorMap>* tensor_map_list); | |||
| void GenerateValidTensorMap(const DeviceArrangement &device_arrangement, const TensorMap &tensor_shape, | |||
| std::vector<TensorMap> *tensor_map_list); | |||
| void GenerateValidLayoutByDeviceSizeAndTensorSize( | |||
| int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, | |||
| int64_t max_shape_dim, | |||
| std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>>* layout_list); | |||
| int64_t device_pow_size, int64_t tensor_pow_size, int64_t max_device_dim, int64_t max_shape_dim, | |||
| std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> *layout_list); | |||
| size_t ComputeNoneNumber(const TensorMap& tensor_map); | |||
| size_t ComputeNoneNumber(const TensorMap &tensor_map); | |||
| bool ShapeIsDividedByDevice(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map, | |||
| const TensorShape& tensor_shape); | |||
| bool ShapeIsDividedByDevice(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map, | |||
| const TensorShape &tensor_shape); | |||
| bool CheckLayoutValid(const DeviceArrangement& device_arrangement, const TensorMap& tensor_map, | |||
| const TensorShape& tensor_shape); | |||
| bool CheckLayoutValid(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map, | |||
| const TensorShape &tensor_shape); | |||
| void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement& device_arrangement, | |||
| const TensorMap& tensor_map, const TensorShape& tensor_shape, | |||
| std::map<int64_t, int64_t>* accum_device_to_accum_shape_map); | |||
| void ComputeAccumDeviceTOAccumShapeMap(const DeviceArrangement &device_arrangement, const TensorMap &tensor_map, | |||
| const TensorShape &tensor_shape, | |||
| std::map<int64_t, int64_t> *accum_device_to_accum_shape_map); | |||
| void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement, | |||
| const TensorMap& in_tensor_map, | |||
| const TensorShape& in_tensor_shape, | |||
| const DeviceArrangement& out_device_arrangement, | |||
| const TensorMap& out_tensor_map, | |||
| const TensorShape& out_tensor_shape); | |||
| void LayoutTransferValidLayoutChangeCheck(const DeviceArrangement &in_device_arrangement, | |||
| const TensorMap &in_tensor_map, const TensorShape &in_tensor_shape, | |||
| const DeviceArrangement &out_device_arrangement, | |||
| const TensorMap &out_tensor_map, const TensorShape &out_tensor_shape); | |||
| void ValidLayoutChangeCheck(const DeviceArrangement& in_device_arrangement, | |||
| const TensorMap& in_tensor_map, const TensorShape& in_tensor_shape, | |||
| const DeviceArrangement& out_device_arrangement, | |||
| const TensorMap& out_tensor_map, const TensorShape& out_tensor_shape); | |||
| void ValidLayoutChangeCheck(const DeviceArrangement &in_device_arrangement, const TensorMap &in_tensor_map, | |||
| const TensorShape &in_tensor_shape, const DeviceArrangement &out_device_arrangement, | |||
| const TensorMap &out_tensor_map, const TensorShape &out_tensor_shape); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """api definition""" | |||
| import threading | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| class Hccl(): | |||
| @@ -62,7 +63,9 @@ def get_rank_id(group=None): | |||
| def get_rank_size(group=None): | |||
| hccl = Hccl() | |||
| if group is None or "nccl_world_group" in group: | |||
| return hccl.rank_size | |||
| if auto_parallel_context().get_device_num_is_set() is False: | |||
| return 1 | |||
| return auto_parallel_context().get_device_num() | |||
| if isinstance(group, str): | |||
| return int(group.split("-")[0]) | |||
| raise ValueError | |||
| @@ -49,8 +49,8 @@ def test_get_parameter_layout(): | |||
| net.set_auto_parallel() | |||
| exe = me._executor | |||
| exe.compile(net, x, phase='train', auto_parallel_mode=True) | |||
| x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1] | |||
| weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1] | |||
| x_layout = ([2, 4], [1, -1], [16, 32], 0, True, '') # device_arrangement = [2, 4], tensor_map = [1, -1] | |||
| weight_layout = ([2, 4], [0, -1], [16, 32], 0, True, '') # device_arrangement = [2, 4], tensor_map = [0, -1] | |||
| expect_dict = {'x': x_layout, 'w1': weight_layout} | |||
| # to be resolved: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut | |||
| assert net.parameter_layout_dict == expect_dict | |||
| @@ -17,16 +17,17 @@ import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb | |||
| from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb, Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore import context | |||
| class Net(nn.Cell): | |||
| """Net definition""" | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.fc1 = nn.Dense(128, 768, activation='relu') | |||
| @@ -50,6 +51,56 @@ class Net(nn.Cell): | |||
| return s | |||
| class Net2(nn.Cell): | |||
| """Net definition""" | |||
| def __init__(self, strategy1, strategy2): | |||
| super(Net2, self).__init__() | |||
| self.fc1 = P.MatMul().shard(strategy=strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy=strategy2) | |||
| self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1") | |||
| self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2") | |||
| def construct(self, x, y): | |||
| x = self.fc1(x, self.p1) | |||
| x = self.fc2(x, self.p2) | |||
| return x - y | |||
| def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True) | |||
| inputs = Tensor(np.ones([32, 48]).astype(np.float32)) | |||
| label = Tensor(np.zeros([32, 16]).astype(np.float32)) | |||
| net = Net2(strategy1, strategy2) | |||
| net = _VirtualDatasetCell(net) | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| train_network.set_auto_parallel() | |||
| _executor.compile(train_network, inputs, label) | |||
| context.reset_auto_parallel_context() | |||
| def test_auto_parallel_momentum_1(): | |||
| auto_parallel_compile_net("auto_parallel", 8) | |||
| def test_auto_parallel_momentum_2(): | |||
| # data parallel case | |||
| auto_parallel_compile_net("auto_parallel", 8, ((8, 1), (1, 1)), ((8, 1), (1, 1))) | |||
| def test_auto_parallel_momentum_3(): | |||
| # hybrid parallel case | |||
| # weight1 could not be sharded and weight2 is repeated | |||
| auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2))) | |||
| def test_auto_parallel_momentum_4(): | |||
| # hybrid parallel cases | |||
| # devices are repeatedly used | |||
| auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 4), (4, 1)), ((4, 4), (4, 2))) | |||
| def test_AdamWeightDecay(): | |||
| """ test_AdamWeightDecay """ | |||
| context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) | |||
| @@ -98,6 +149,7 @@ def test_lamb_split_fusion(): | |||
| _executor.compile(train_network, inputs, label) | |||
| context.reset_auto_parallel_context() | |||
| def test_edge_case(): | |||
| """ test_edge_case """ | |||
| context.set_auto_parallel_context(enable_parallel_optimizer=True) | |||
| @@ -121,10 +121,10 @@ def test_grad_sens_parameter_type(): | |||
| sens = Tensor(np.ones([128, 64]), dtype=ms.float32) | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x, y, b, sens, phase='train', auto_parallel_mode=True) | |||
| x_layout = [[8, 8], [1, -1], [16, 32], [0], [1]] | |||
| y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]] | |||
| b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]] | |||
| sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]] | |||
| x_layout = ([8, 8], [1, -1], [16, 32], 0, True, '') | |||
| y_layout = ([8, 8], [-1, 0], [32, 8], 0, True, '') | |||
| b_layout = ([8, 8], [0, -1], [8, 64], 0, True, '') | |||
| sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '') | |||
| expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout} | |||
| assert net.parameter_layout_dict == expect_dict | |||