@@ -942,6 +942,29 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
   return (type_id != kNumberTypeFloat32);
 }
 
+static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
+  MS_EXCEPTION_IF_NULL(comm_node);
+  MS_EXCEPTION_IF_NULL(param_node);
+  if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
+    MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
+    return;
+  }
+  auto param = param_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(param);
+  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
+  MS_EXCEPTION_IF_NULL(prim);
+  auto attrs = prim->attrs();
+  auto param_info = param->param_info();
+  if (!param_info) {
+    MS_LOG(WARNING) << param->ToString() << " does not have parameter info.";
+    return;
+  }
+  int32_t fusion_type = param_info->comm_fusion();
+  attrs[FUSION] = MakeValue<int64_t>(fusion_type);
+  prim->SetAttrs(attrs);
+  MS_LOG(INFO) << "Set comm fusion: " << param->param_info()->name() << "'s fusion type is " << fusion_type;
+}
+
 void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
@@ -1006,11 +1029,19 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
         MS_EXCEPTION_IF_NULL(cnode);
         AnfNodePtr pre_node = cnode->input(1);
         InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name);
+        auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
+        // add fusion flag
+        // pipeline mirror would not be set, which should be supported later
+        AddCommOpFusionType(comm_op, param_node_pair.first);
       }
     } else {
       for (auto &op : backward_op) {
         AnfNodePtr pre_node = node->input(index);
         InsertNode(op, node, index, pre_node, func_graph, instance_name);
+        auto comm_op = node->input(index)->cast<CNodePtr>();
+        // add fusion flag
+        // pipeline mirror would not be set, which should be supported later
+        AddCommOpFusionType(comm_op, param_node_pair.first);
       }
     }
   }
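Both branches now tag the freshly inserted mirror operator with the parameter's fusion index via AddCommOpFusionType, so gradient AllReduce ops that share an index can be grouped by the backend. A minimal user-side sketch of where that index comes from (hypothetical shapes and names; Parameter, Tensor and set_auto_parallel_context are the same APIs exercised in the tests further down):

    import numpy as np
    from mindspore import Parameter, Tensor, context

    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)

    # The value stored in ParamInfo.comm_fusion is what AddCommOpFusionType copies
    # onto the mirror operator inserted for this weight's gradient.
    weight = Parameter(Tensor(np.ones([64, 64]).astype(np.float32)), name="weight")
    weight.comm_fusion = 2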
@@ -1342,7 +1373,8 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
   return std::make_pair(nullptr, 0);
 }
 
-void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr &parameter) {
+static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res,
+                              const AnfNodePtr &parameter) {
   Operator op = CreateAllGatherOp(group);
   MS_EXCEPTION_IF_NULL(res.first);
   MS_EXCEPTION_IF_NULL(parameter);
@@ -1360,11 +1392,7 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int
   }
   // add fusion flag
   MS_EXCEPTION_IF_NULL(allgather);
-  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
-  auto attrs = prim->attrs();
-  // enable fusion flag later when it's supported in backend
-  attrs["fusion"] = MakeValue<int64_t>(1);
-  prim->SetAttrs(attrs);
+  AddCommOpFusionType(allgather, parameter);
 }
 
 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
@@ -1419,6 +1447,9 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
   if (!ParameterRequireGrad(parameter)) {
     // only trainable parameters need parallel optimizer
     MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
+  } else if (parameter->cast<ParameterPtr>()->param_info() &&
+             !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
+    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
   } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
     // get a totally shard tensor slice shape if the weight is repeated on devices
     // and the shape of the first dimension could be divided
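The new branch skips optimizer weight sharding for any parameter whose ParamInfo reports parallel_optimizer as false, which is exactly what Net3 in the updated test relies on. A short sketch of the Python side (weights and shapes are illustrative):

    import numpy as np
    from mindspore import Parameter, Tensor, context

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8,
                                      enable_parallel_optimizer=True)

    # weight1 keeps the default parallel_optimizer=True and may be sharded;
    # weight2 opts out, so SetParallelShape leaves its slice shape untouched.
    weight1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
    weight2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2",
                        parallel_optimizer=False)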
@@ -29,6 +29,9 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
       .def_property("init_in_server", &ParamInfo::init_in_server, &ParamInfo::set_init_in_server)
       .def_property("layerwise_parallel", &ParamInfo::layerwise_parallel,
                     &ParamInfo::set_layerwise_parallel)
+      .def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
+                    &ParamInfo::set_parallel_optimizer)
+      .def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
       .def(py::pickle(
         [](const ParamInfo &p) {  // __getstate__
           return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
@@ -75,8 +75,11 @@ class Parameter(MetaTensor_):
         default_input (Union[Tensor, MetaTensor, Number]): Parameter data, to be set initialized.
         name (str): Name of the child parameter. Default: None.
         requires_grad (bool): True if the parameter requires gradient. Default: True.
-        layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
+        layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode,
             broadcast and gradients communication would not be applied to parameters. Default: False.
+        parallel_optimizer (bool): Whether the parameter takes part in optimizer weight sharding in semi auto
+            parallel or auto parallel mode. It only takes effect when the parallel optimizer is enabled in
+            `mindspore.context.set_auto_parallel_context()`. Default: True.
 
     Example:
         >>> from mindspore import Parameter, Tensor
@@ -132,19 +135,21 @@ class Parameter(MetaTensor_):
         return (
             Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
 
-    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False):
+    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
         self._param_info = ParamInfo()
         self.init_in_server = False
         self.cache_enable = False
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
+        self.parallel_optimizer = parallel_optimizer
         # this flag for tensor copy data.
         self.init_flag = False
         # this flag is for ge variable copy data.
         self._is_init = False
         self._inited_param = None
         self._sliced = False
+        self.comm_fusion = 1
         self.is_param_ps = False
         self._cast_type = None
         self._unique = False
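With the two new attributes, a freshly constructed Parameter defaults to fusion bucket 1 and to being eligible for optimizer weight sharding; both can be overridden at construction time. A small sketch (assumes a build that contains this change):

    import numpy as np
    from mindspore import Parameter, Tensor

    w = Parameter(Tensor(np.ones([8, 8]).astype(np.float32)), name="w")
    assert w.comm_fusion == 1              # default fusion bucket
    assert w.parallel_optimizer is True    # sharded by the parallel optimizer by default

    v = Parameter(Tensor(np.ones([8, 8]).astype(np.float32)), name="v", parallel_optimizer=False)
    assert v.parallel_optimizer is False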
@@ -210,7 +215,6 @@ class Parameter(MetaTensor_):
            raise RuntimeError("Must complete following two steps before calling set_param_ps: \
                               1. set_ps_context(enable_ps=True) \
                               2. export MS_ROLE environment variable.")
-
        if init_in_server and (not self.name.endswith("embedding_table")):
            raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
                               "sparse operator support initialization in server.".format(self.name))
@@ -218,7 +222,6 @@ class Parameter(MetaTensor_):
         self.init_in_server = init_in_server
         self._param_info.init_in_server = init_in_server
 
-
     @property
     def inited_param(self):
         """
@@ -273,6 +276,16 @@ class Parameter(MetaTensor_):
     def sliced(self, sliced_):
         self._sliced = sliced_
 
+    @property
+    def comm_fusion(self):
+        """Get the fusion type for communication operators corresponding to this parameter."""
+        return self._param_info.comm_fusion
+
+    @comm_fusion.setter
+    def comm_fusion(self, comm_fusion_):
+        """Set the fusion type for communication operators corresponding to this parameter."""
+        self._param_info.comm_fusion = comm_fusion_
+
     @property
     def unique(self):
         """whether the parameter is already unique or not."""
@@ -338,6 +351,17 @@ class Parameter(MetaTensor_):
             raise TypeError("`layerwise_parallel` parameter must be bool type")
         self._param_info.layerwise_parallel = value
 
+    @property
+    def parallel_optimizer(self):
+        """Return whether the parameter requires weight shard for parallel optimizer."""
+        return self._param_info.parallel_optimizer
+
+    @parallel_optimizer.setter
+    def parallel_optimizer(self, value=True):
+        if not isinstance(value, bool):
+            raise TypeError("`parallel_optimizer` parameter must be bool type")
+        self._param_info.parallel_optimizer = value
+
     @property
     def requires_grad(self):
         """Return whether the parameter requires gradient."""
@@ -75,6 +75,12 @@ class ParamInfo {
     return clone;
   }
 
+  int32_t comm_fusion() const { return fusion_type_; }
+  void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; }
+
+  bool parallel_optimizer() const { return parallel_optimizer_; }
+  void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }
+
  private:
   std::string name_{"Parameter"};
   bool requires_grad_{true};
@@ -84,6 +90,8 @@ class ParamInfo {
   bool cloned_{false};
   std::vector<int32_t> be_cloned_index_;
   int32_t cloned_index_{0};
+  int32_t fusion_type_{1};
+  bool parallel_optimizer_{true};
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_IR_PARAM_INFO_H_
@@ -1075,6 +1075,12 @@ class Cell(Cell_):
         for param in params:
             param.set_param_ps(init_in_server)
+
+    def set_comm_fusion(self, fusion_type, recurse=True):
+        Validator.check_is_int(fusion_type)
+        for param in self.trainable_params(recurse):
+            param.comm_fusion = fusion_type
+        return self
 
 
 class GraphKernel(Cell):
     """
@@ -125,7 +125,7 @@ def get_bprop_all_gather(self):
            instance_name = "grad_" + self.instance_name
            reduce_scatter.set_prim_instance_name(instance_name)
     else:
-        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1)
+        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
         if self.instance_name:
             instance_name = "grad_" + self.instance_name
             all_reduce.set_prim_instance_name(instance_name)
@@ -240,9 +240,7 @@ def get_bprop_mirror_operator(self):
     mul = P.Mul()
     cast = P.Cast()
-    fusion = 1
-    if hasattr(self, 'fusion'):
-        fusion = self.fusion
+    fusion = self.get_attr_dict()["fusion"]
     all_reduce.add_prim_attr("fusion", fusion)
     if hasattr(self, 'parameter'):
         parameter = self.parameter
@@ -534,6 +534,7 @@ class _MirrorOperator(PrimitiveWithInfer):
         self.group = group
         self.dev_num = dev_num
         self.mean_flag = mean_flag
+        self.add_prim_attr("fusion", 1)
 
     def infer_shape(self, x_shape):
         return x_shape
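Registering fusion=1 in _MirrorOperator.__init__ guarantees the attribute is always present, so the bprop above can read it with get_attr_dict() instead of a hasattr fallback, and AddCommOpFusionType then overwrites it per parameter on the graph side. The same attribute round trip can be illustrated on the public AllReduce primitive (a sketch, assuming a build with these changes):

    from mindspore.ops.operations import AllReduce

    # add_prim_attr returns the primitive, which is why the bprop can chain it;
    # get_attr_dict exposes the attribute the backend uses to group the ops.
    all_reduce = AllReduce().add_prim_attr("fusion", 2)
    assert all_reduce.get_attr_dict()["fusion"] == 2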
@@ -25,6 +25,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train import Model
 from mindspore.context import ParallelMode
 from tests.dataset_mock import MindData
+import pytest
 
 
 class Dataset(MindData):
@@ -125,6 +126,7 @@ def train_common(net):
     return allreduce_fusion_dict
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion_parameters():
     cost_model_context.reset_cost_model_context()
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
@@ -181,6 +183,7 @@ def test_allreduce_fusion_parameters():
     assert computation_time_parameter == 0.1
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion1():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@@ -205,6 +208,7 @@ def test_allreduce_fusion1():
     cost_model_context.reset_cost_model_context()
 
 
+@pytest.mark.skip(reason="deprecated feature")
 # reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0, step_allreduce_fusion
 # is bypassed.
 def test_allreduce_fusion2():
@@ -220,6 +224,7 @@ def test_allreduce_fusion2():
     cost_model_context.reset_cost_model_context()
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion3():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
@@ -248,6 +253,7 @@ def test_allreduce_fusion3():
     cost_model_context.reset_cost_model_context()
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion4():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@@ -277,6 +283,7 @@ def test_allreduce_fusion4():
     cost_model_context.reset_cost_model_context()
 
 
+@pytest.mark.skip(reason="deprecated feature")
 def test_allreduce_fusion5():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
@@ -66,15 +66,30 @@ class Net2(nn.Cell):
         return x - y
 
 
-def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
+class Net3(nn.Cell):
+    """Net definition"""
+    def __init__(self, strategy1, strategy2):
+        super(Net3, self).__init__()
+        self.fc1 = P.MatMul().shard(strategy=strategy1)
+        self.fc2 = P.MatMul().shard(strategy=strategy2)
+        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
+        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)
+
+    def construct(self, x, y):
+        x = self.fc1(x, self.p1)
+        x = self.fc2(x, self.p2)
+        return x - y
+
+
+def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
     context.set_context(mode=context.GRAPH_MODE)
     context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
     inputs = Tensor(np.ones([32, 48]).astype(np.float32))
     label = Tensor(np.zeros([32, 16]).astype(np.float32))
-    net = Net2(strategy1, strategy2)
+    net = net(strategy1, strategy2)
     net = _VirtualDatasetCell(net)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
-    train_network = TrainOneStepCell(net, optimizer)
+    train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
     train_network.set_auto_parallel()
     train_network.set_train()
     _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
@@ -83,18 +98,18 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
 
 
 def test_auto_parallel_momentum_1():
-    auto_parallel_compile_net("auto_parallel", 8)
+    auto_parallel_compile_net("auto_parallel", 8, Net2)
 
 
 def test_auto_parallel_momentum_2():
     # data parallel case
-    auto_parallel_compile_net("auto_parallel", 8, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
+    auto_parallel_compile_net("auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
 
 
 def test_auto_parallel_momentum_3():
     # hybrid parallel case
     # weight1 could not be shard and weight2 is repeated
-    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
     param_dict = train_network.parameter_layout_dict
     # validate opt_shard_group
     assert not param_dict["weight1"][5]
@@ -104,7 +119,16 @@ def test_auto_parallel_momentum_3():
 def test_auto_parallel_momentum_4():
     # hybrid parallel cases
     # devices are repeatedly used
-    auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
+    auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
+
+
+def test_auto_parallel_momentum_5():
+    # test parallel optimizer filter
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert not param_dict["weight1"][5]
+    assert not param_dict["weight2"][5]
 
 
 def test_AdamWeightDecay():