| @@ -942,6 +942,29 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) { | |||
| return (type_id != kNumberTypeFloat32); | |||
| } | |||
| static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node) { | |||
| MS_EXCEPTION_IF_NULL(comm_node); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) { | |||
| MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now."; | |||
| return; | |||
| } | |||
| auto param = param_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto attrs = prim->attrs(); | |||
| auto param_info = param->param_info(); | |||
| if (!param_info) { | |||
| MS_LOG(WARNING) << param->ToString() << "does not have parameter info."; | |||
| return; | |||
| } | |||
| int32_t fusion_type = param_info->comm_fusion(); | |||
| attrs[FUSION] = MakeValue<int64_t>(fusion_type); | |||
| prim->SetAttrs(attrs); | |||
| MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; | |||
| } | |||
| void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| size_t node_size = node->inputs().size(); | |||
| @@ -1006,11 +1029,19 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| AnfNodePtr pre_node = cnode->input(1); | |||
| InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); | |||
| auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>(); | |||
| // add fusion flag | |||
| // pipeline mirror would not be set, which should be supported later | |||
| AddCommOpFusionType(comm_op, param_node_pair.first); | |||
| } | |||
| } else { | |||
| for (auto &op : backward_op) { | |||
| AnfNodePtr pre_node = node->input(index); | |||
| InsertNode(op, node, index, pre_node, func_graph, instance_name); | |||
| auto comm_op = node->input(index)->cast<CNodePtr>(); | |||
| // add fusion flag | |||
| // pipeline mirror would not be set, which should be supported later | |||
| AddCommOpFusionType(comm_op, param_node_pair.first); | |||
| } | |||
| } | |||
| } | |||
| @@ -1342,7 +1373,8 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf | |||
| return std::make_pair(nullptr, 0); | |||
| } | |||
| void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr ¶meter) { | |||
| static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, | |||
| const AnfNodePtr ¶meter) { | |||
| Operator op = CreateAllGatherOp(group); | |||
| MS_EXCEPTION_IF_NULL(res.first); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| @@ -1360,11 +1392,7 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int | |||
| } | |||
| // add fusion flag | |||
| MS_EXCEPTION_IF_NULL(allgather); | |||
| auto prim = GetValueNode<PrimitivePtr>(allgather->input(0)); | |||
| auto attrs = prim->attrs(); | |||
| // enable fusion flag later when it's supported in backend | |||
| attrs["fusion"] = MakeValue<int64_t>(1); | |||
| prim->SetAttrs(attrs); | |||
| AddCommOpFusionType(allgather, parameter); | |||
| } | |||
| static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, | |||
| @@ -1419,6 +1447,9 @@ std::string SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNod | |||
| if (!ParameterRequireGrad(parameter)) { | |||
| // only trainable parameters need parallel optimizer | |||
| MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter."; | |||
| } else if (parameter->cast<ParameterPtr>()->param_info() && | |||
| !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) { | |||
| MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard."; | |||
| } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) { | |||
| // get a totally shard tensor slice shape if the weight is repeated on devices | |||
| // and the shape of the first dimension could be divided | |||
| @@ -29,6 +29,9 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) { | |||
| .def_property("init_in_server", &ParamInfo::init_in_server, &ParamInfo::set_init_in_server) | |||
| .def_property("layerwise_parallel", &ParamInfo::layerwise_parallel, | |||
| &ParamInfo::set_layerwise_parallel) | |||
| .def_property("parallel_optimizer", &ParamInfo::parallel_optimizer, | |||
| &ParamInfo::set_parallel_optimizer) | |||
| .def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion) | |||
| .def(py::pickle( | |||
| [](const ParamInfo &p) { // __getstate__ | |||
| return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel()); | |||
| @@ -75,8 +75,11 @@ class Parameter(MetaTensor_): | |||
| default_input (Union[Tensor, MetaTensor, Number]): Parameter data, to be set initialized. | |||
| name (str): Name of the child parameter. Default: None. | |||
| requires_grad (bool): True if the parameter requires gradient. Default: True. | |||
| layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode, | |||
| layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode, | |||
| broadcast and gradients communication would not be applied to parameters. Default: False. | |||
| parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel | |||
| mode. It works only when enable parallel optimizer in `mindspore.context.set_auto_parallel_context()`. | |||
| Default: True. | |||
| Example: | |||
| >>> from mindspore import Parameter, Tensor | |||
| @@ -132,19 +135,21 @@ class Parameter(MetaTensor_): | |||
| return ( | |||
| Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) | |||
| def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False): | |||
| def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True): | |||
| self._param_info = ParamInfo() | |||
| self.init_in_server = False | |||
| self.cache_enable = False | |||
| self.name = name | |||
| self.requires_grad = requires_grad | |||
| self.layerwise_parallel = layerwise_parallel | |||
| self.parallel_optimizer = parallel_optimizer | |||
| # this flag for tensor copy data. | |||
| self.init_flag = False | |||
| # this flag is for ge variable copy data. | |||
| self._is_init = False | |||
| self._inited_param = None | |||
| self._sliced = False | |||
| self.comm_fusion = 1 | |||
| self.is_param_ps = False | |||
| self._cast_type = None | |||
| self._unique = False | |||
| @@ -210,7 +215,6 @@ class Parameter(MetaTensor_): | |||
| raise RuntimeError("Must complete following two steps before calling set_param_ps: \ | |||
| 1. set_ps_context(enable_ps=True) \ | |||
| 2. export MS_ROLE environment variable.") | |||
| if init_in_server and (not self.name.endswith("embedding_table")): | |||
| raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of " | |||
| "sparse operator support initialization in server.".format(self.name)) | |||
| @@ -218,7 +222,6 @@ class Parameter(MetaTensor_): | |||
| self.init_in_server = init_in_server | |||
| self._param_info.init_in_server = init_in_server | |||
| @property | |||
| def inited_param(self): | |||
| """ | |||
| @@ -273,6 +276,16 @@ class Parameter(MetaTensor_): | |||
| def sliced(self, sliced_): | |||
| self._sliced = sliced_ | |||
| @property | |||
| def comm_fusion(self): | |||
| """Get the fusion type for communication operators corresponding to this parameter.""" | |||
| return self._param_info.comm_fusion | |||
| @comm_fusion.setter | |||
| def comm_fusion(self, comm_fusion_): | |||
| """Set the fusion type for communication operators corresponding to this parameter.""" | |||
| self._param_info.comm_fusion = comm_fusion_ | |||
| @property | |||
| def unique(self): | |||
| """whether the parameter is already unique or not.""" | |||
| @@ -338,6 +351,17 @@ class Parameter(MetaTensor_): | |||
| raise TypeError("`layerwise_parallel` parameter must be bool type") | |||
| self._param_info.layerwise_parallel = value | |||
| @property | |||
| def parallel_optimizer(self): | |||
| """Return whether the parameter requires weight shard for parallel optimizer.""" | |||
| return self._param_info.parallel_optimizer | |||
| @parallel_optimizer.setter | |||
| def parallel_optimizer(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`parallel_optimizer` parameter must be bool type") | |||
| self._param_info.parallel_optimizer = value | |||
| @property | |||
| def requires_grad(self): | |||
| """Return whether the parameter requires gradient.""" | |||
| @@ -75,6 +75,12 @@ class ParamInfo { | |||
| return clone; | |||
| } | |||
| int32_t comm_fusion() const { return fusion_type_; } | |||
| void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; } | |||
| bool parallel_optimizer() const { return parallel_optimizer_; } | |||
| void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; } | |||
| private: | |||
| std::string name_{"Parameter"}; | |||
| bool requires_grad_{true}; | |||
| @@ -84,6 +90,8 @@ class ParamInfo { | |||
| bool cloned_{false}; | |||
| std::vector<int32_t> be_cloned_index_; | |||
| int32_t cloned_index_{0}; | |||
| int32_t fusion_type_{1}; | |||
| bool parallel_optimizer_{true}; | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_PARAM_INFO_H_ | |||
| @@ -1075,6 +1075,12 @@ class Cell(Cell_): | |||
| for param in params: | |||
| param.set_param_ps(init_in_server) | |||
| def set_comm_fusion(self, fusion_type, recurse=True): | |||
| Validator.check_is_int(fusion_type) | |||
| for param in self.trainable_params(recurse): | |||
| param.comm_fusion = fusion_type | |||
| return self | |||
| class GraphKernel(Cell): | |||
| """ | |||
| @@ -125,7 +125,7 @@ def get_bprop_all_gather(self): | |||
| instance_name = "grad_" + self.instance_name | |||
| reduce_scatter.set_prim_instance_name(instance_name) | |||
| else: | |||
| all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1) | |||
| all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) | |||
| if self.instance_name: | |||
| instance_name = "grad_" + self.instance_name | |||
| all_reduce.set_prim_instance_name(instance_name) | |||
| @@ -240,9 +240,7 @@ def get_bprop_mirror_operator(self): | |||
| mul = P.Mul() | |||
| cast = P.Cast() | |||
| fusion = 1 | |||
| if hasattr(self, 'fusion'): | |||
| fusion = self.fusion | |||
| fusion = self.get_attr_dict()["fusion"] | |||
| all_reduce.add_prim_attr("fusion", fusion) | |||
| if hasattr(self, 'parameter'): | |||
| parameter = self.parameter | |||
| @@ -534,6 +534,7 @@ class _MirrorOperator(PrimitiveWithInfer): | |||
| self.group = group | |||
| self.dev_num = dev_num | |||
| self.mean_flag = mean_flag | |||
| self.add_prim_attr("fusion", 1) | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| @@ -25,6 +25,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.train import Model | |||
| from mindspore.context import ParallelMode | |||
| from tests.dataset_mock import MindData | |||
| import pytest | |||
| class Dataset(MindData): | |||
| @@ -125,6 +126,7 @@ def train_common(net): | |||
| return allreduce_fusion_dict | |||
| @pytest.mark.skip(reason="depreciated feature") | |||
| def test_allreduce_fusion_parameters(): | |||
| cost_model_context.reset_cost_model_context() | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2) | |||
| @@ -181,6 +183,7 @@ def test_allreduce_fusion_parameters(): | |||
| assert computation_time_parameter == 0.1 | |||
| @pytest.mark.skip(reason="depreciated feature") | |||
| def test_allreduce_fusion1(): | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) | |||
| @@ -205,6 +208,7 @@ def test_allreduce_fusion1(): | |||
| cost_model_context.reset_cost_model_context() | |||
| @pytest.mark.skip(reason="depreciated feature") | |||
| # reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0, step_allreduce_fusion | |||
| # is bypassed. | |||
| def test_allreduce_fusion2(): | |||
| @@ -220,6 +224,7 @@ def test_allreduce_fusion2(): | |||
| cost_model_context.reset_cost_model_context() | |||
| @pytest.mark.skip(reason="depreciated feature") | |||
| def test_allreduce_fusion3(): | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3) | |||
| @@ -248,6 +253,7 @@ def test_allreduce_fusion3(): | |||
| cost_model_context.reset_cost_model_context() | |||
| @pytest.mark.skip(reason="depreciated feature") | |||
| def test_allreduce_fusion4(): | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) | |||
| @@ -277,6 +283,7 @@ def test_allreduce_fusion4(): | |||
| cost_model_context.reset_cost_model_context() | |||
| @pytest.mark.skip(reason="depreciated feature") | |||
| def test_allreduce_fusion5(): | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2) | |||
| cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1) | |||
| @@ -66,15 +66,30 @@ class Net2(nn.Cell): | |||
| return x - y | |||
| def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None): | |||
| class Net3(nn.Cell): | |||
| """Net definition""" | |||
| def __init__(self, strategy1, strategy2): | |||
| super(Net3, self).__init__() | |||
| self.fc1 = P.MatMul().shard(strategy=strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy=strategy2) | |||
| self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1") | |||
| self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False) | |||
| def construct(self, x, y): | |||
| x = self.fc1(x, self.p1) | |||
| x = self.fc2(x, self.p2) | |||
| return x - y | |||
| def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True) | |||
| inputs = Tensor(np.ones([32, 48]).astype(np.float32)) | |||
| label = Tensor(np.zeros([32, 16]).astype(np.float32)) | |||
| net = Net2(strategy1, strategy2) | |||
| net = net(strategy1, strategy2) | |||
| net = _VirtualDatasetCell(net) | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| train_network = TrainOneStepCell(net, optimizer) | |||
| train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4) | |||
| train_network.set_auto_parallel() | |||
| train_network.set_train() | |||
| _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True) | |||
| @@ -83,18 +98,18 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None): | |||
| def test_auto_parallel_momentum_1(): | |||
| auto_parallel_compile_net("auto_parallel", 8) | |||
| auto_parallel_compile_net("auto_parallel", 8, Net2) | |||
| def test_auto_parallel_momentum_2(): | |||
| # data parallel case | |||
| auto_parallel_compile_net("auto_parallel", 8, ((8, 1), (1, 1)), ((8, 1), (1, 1))) | |||
| auto_parallel_compile_net("auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1))) | |||
| def test_auto_parallel_momentum_3(): | |||
| # hybrid parallel case | |||
| # weight1 could not be shard and weight2 is repeated | |||
| train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2))) | |||
| train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2))) | |||
| param_dict = train_network.parameter_layout_dict | |||
| # validate opt_shard_group | |||
| assert not param_dict["weight1"][5] | |||
| @@ -104,7 +119,16 @@ def test_auto_parallel_momentum_3(): | |||
| def test_auto_parallel_momentum_4(): | |||
| # hybrid parallel cases | |||
| # devices are repeatedly used | |||
| auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 4), (4, 1)), ((4, 4), (4, 2))) | |||
| auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 4), (4, 1)), ((4, 4), (4, 2))) | |||
| def test_auto_parallel_momentum_5(): | |||
| # test parallel optimizer filter | |||
| train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2))) | |||
| param_dict = train_network.parameter_layout_dict | |||
| # validate opt_shard_group | |||
| assert not param_dict["weight1"][5] | |||
| assert not param_dict["weight2"][5] | |||
| def test_AdamWeightDecay(): | |||