Merge pull request !2286 from gziyan/optimizer_parallel
@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
   enable_all_reduce_fusion_ = false;
   strategy_ckpt_load_file_ = "";
   strategy_ckpt_save_file_ = "";
+  enable_parallel_optimizer_ = false;
 }

 void ParallelContext::set_device_num(int32_t device_num) {
@@ -100,6 +100,11 @@ class ParallelContext {
   void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
   std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
+  void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
+    enable_parallel_optimizer_ = enable_parallel_optimizer;
+  }
+  bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
+
   void Reset();

  private:
@@ -123,6 +128,7 @@ class ParallelContext {
   std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
   std::string strategy_ckpt_load_file_;
   std::string strategy_ckpt_save_file_;
+  bool enable_parallel_optimizer_;
 };

 void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
     .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
     .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
+    .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
+         "Set enable/disable parallel optimizer.")
+    .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
+         "Get enable/disable parallel optimizer.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
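Note: with these bindings in place, the flag round-trips from Python into the C++ ParallelContext singleton. A minimal usage sketch, using the Python wrapper added further down in this PR:

    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    # The setter/getter call the set_enable_parallel_optimizer /
    # get_enable_parallel_optimizer bindings registered above.
    auto_parallel_context().set_enable_parallel_optimizer(True)
    assert auto_parallel_context().get_enable_parallel_optimizer() is True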
@@ -29,8 +29,9 @@ from .optimizer import Optimizer

 _adam_opt = C.MultitypeFuncGraph("adam_opt")

-@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
-def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
+@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Bool", "Bool")
+def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter):
     """
     Update parameters.
@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
         m (Tensor): m value of parameters.
         v (Tensor): v value of parameters.
         gradient (Tensor): Gradient of parameters.
         decay_flag (bool): Applies weight decay or not.
+        optim_filter (bool): Whether to apply the parameter update on this device.

     Returns:
-        Tensor, the new value of v after updating.
+        Tensor, the updated parameter, or the original gradient if optim_filter is False.
     """
-    op_mul = P.Mul()
-    op_square = P.Square()
-    op_sqrt = P.Sqrt()
-    op_cast = P.Cast()
-    op_reshape = P.Reshape()
-    op_shape = P.Shape()
+    if optim_filter:
+        op_mul = P.Mul()
+        op_square = P.Square()
+        op_sqrt = P.Sqrt()
+        op_cast = P.Cast()
+        op_reshape = P.Reshape()
+        op_shape = P.Shape()

-    param_fp32 = op_cast(param, mstype.float32)
-    m_fp32 = op_cast(m, mstype.float32)
-    v_fp32 = op_cast(v, mstype.float32)
-    gradient_fp32 = op_cast(gradient, mstype.float32)
+        param_fp32 = op_cast(param, mstype.float32)
+        m_fp32 = op_cast(m, mstype.float32)
+        v_fp32 = op_cast(v, mstype.float32)
+        gradient_fp32 = op_cast(gradient, mstype.float32)

-    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
+        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                                - beta1, gradient_fp32)

-    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
-                                            - beta2, op_square(gradient_fp32))
+        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                                - beta2, op_square(gradient_fp32))

-    update = next_m / (eps + op_sqrt(next_v))
-    if decay_flag:
-        update = op_mul(weight_decay_tensor, param_fp32) + update
-    update_with_lr = op_mul(lr, update)
-    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
+        update = next_m / (eps + op_sqrt(next_v))
+        if decay_flag:
+            update = op_mul(weight_decay_tensor, param_fp32) + update

-    next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
-    next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
-    next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
-    return next_v
+        update_with_lr = op_mul(lr, update)
+        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

+        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
+        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
+        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
+        return next_param
+    return gradient

 def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
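Note: the graph above is a plain AdamWeightDecay step without bias correction, now guarded by optim_filter. A minimal NumPy sketch of the same arithmetic (standalone illustration, not part of the PR):

    import numpy as np

    def adam_weight_decay_step(param, m, v, grad, beta1, beta2, eps, lr, weight_decay, decay_flag):
        """Mirrors _update_run_op: moment updates, optional decoupled weight decay, then the step."""
        next_m = beta1 * m + (1.0 - beta1) * grad
        next_v = beta2 * v + (1.0 - beta2) * np.square(grad)
        update = next_m / (eps + np.sqrt(next_v))
        if decay_flag:
            update = weight_decay * param + update
        next_param = param - lr * update
        return next_param, next_m, next_v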
@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.

     Examples:
         >>> net = Net()
@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):

     def construct(self, gradients):
         lr = self.get_lr()
-        updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
-                                                    self.weight_decay_tensor),
-                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
-
-        return updated_velocity
+        optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
+                                                self.weight_decay_tensor),
+                                      self.params, self.moments1, self.moments2, gradients,
+                                      self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)
+        return optim_result


 class AdamWeightDecayDynamicLR(Optimizer):
@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.

     Examples:
         >>> net = Net()
@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
         warmup_lr = self.start_learning_rate * warmup_percent
         is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
         lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
-        updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
-                                                    self.weight_decay_tensor),
-                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+        optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
+                                                self.weight_decay_tensor),
+                                      self.params, self.moments1, self.moments2, gradients,
+                                      self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)
         added_global_step = self.global_step + self.one
         F.control_depend(lr, added_global_step)
         self.global_step = added_global_step
-        return updated_velocity
+        return optim_result
@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)

 _lamb_opt = C.MultitypeFuncGraph("lamb_opt")

-@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                    "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                    "Tensor", "Bool", "Bool")
 def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
-                   gradient, decay_flag):
+                   gradient, decay_flag, optim_filter):
     """
     Update parameters.
@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
         v (Tensor): v value of parameters.
         gradient (Tensor): Gradient of parameters.
         decay_flag (bool): Specifies whether param update with weight decay.
+        optim_filter (bool): Whether to apply the parameter update on this device.

     Returns:
-        Tensor, the new value of v after updating.
+        Tensor, the updated parameter, or the original gradient if optim_filter is False.
     """
-    op_mul = P.Mul()
-    op_sqrt = P.Sqrt()
-    op_rsqrt = P.Rsqrt()
-    op_square = P.Square()
-    op_cast = P.Cast()
-    op_reshape = P.Reshape()
-    op_shape = P.Shape()
-    op_pow = P.Pow()
-    op_norm = layer.Norm()
-    op_select = P.Select()
-    op_greater = P.Greater()
-    op_fill = P.Fill()
-    op_dtype = P.DType()
-
-    param_fp32 = op_cast(param, mstype.float32)
-    m_fp32 = op_cast(m, mstype.float32)
-    v_fp32 = op_cast(v, mstype.float32)
-    gradient_fp32 = op_cast(gradient, mstype.float32)
-
-    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one,
-                                                    mstype.float32) - beta1, gradient_fp32)
-
-    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one,
-                                                    mstype.float32) - beta2, op_square(gradient_fp32))
-
-    next_mm = next_m / (op_cast(num_one, mstype.float32)
-                        - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
-    next_vv = next_v / (op_cast(num_one, mstype.float32) -
-                        op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
-    w_norm = op_norm(param_fp32)
-    g_norm = op_norm(gradient_fp32)
-
-    g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(
-        next_vv + eps)) + weight_decay_tensor * param_fp32)
-    zeros = F.zeros_like(w_norm)
-    ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
-    trust_ratio = op_select(
-        op_greater(w_norm, zeros),
-        op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
-        ones)
-    tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
-    trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
-    update = next_mm / (op_sqrt(next_vv) + eps)
-
-    if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param_fp32)
-
-    update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
-    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
-
-    next_v = F.depend(next_v, F.assign(param, next_param))
-    next_v = F.depend(next_v, F.assign(m, next_m))
-    next_v = F.depend(next_v, F.assign(v, next_v))
-
-    return next_v
+    if optim_filter:
+        op_mul = P.Mul()
+        op_sqrt = P.Sqrt()
+        op_rsqrt = P.Rsqrt()
+        op_square = P.Square()
+        op_cast = P.Cast()
+        op_reshape = P.Reshape()
+        op_shape = P.Shape()
+        op_pow = P.Pow()
+        op_norm = layer.Norm()
+        op_select = P.Select()
+        op_greater = P.Greater()
+        op_fill = P.Fill()
+        op_dtype = P.DType()
+
+        param_fp32 = op_cast(param, mstype.float32)
+        m_fp32 = op_cast(m, mstype.float32)
+        v_fp32 = op_cast(v, mstype.float32)
+        gradient_fp32 = op_cast(gradient, mstype.float32)
+
+        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
+        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
+
+        next_mm = next_m / (op_cast(num_one, mstype.float32)
+                            - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
+        next_vv = next_v / (op_cast(num_one, mstype.float32) -
+                            op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
+        w_norm = op_norm(param_fp32)
+        g_norm = op_norm(gradient_fp32)
+
+        g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
+        zeros = F.zeros_like(w_norm)
+        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
+        trust_ratio = op_select(
+            op_greater(w_norm, zeros),
+            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
+            ones)
+        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
+        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
+        update = next_mm / (op_sqrt(next_vv) + eps)
+
+        if decay_flag:
+            update = update + op_mul(weight_decay_tensor, param_fp32)
+
+        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
+        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
+
+        next_param = F.depend(next_param, F.assign(param, next_param))
+        next_param = F.depend(next_param, F.assign(m, next_m))
+        next_param = F.depend(next_param, F.assign(v, next_v))
+
+        return next_param
+    return gradient


 lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
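Note: the heart of the update above is the LAMB trust ratio. A standalone NumPy sketch of that piece (illustration only; m_hat and v_hat stand for the bias-corrected moments next_mm and next_vv above):

    import numpy as np

    def lamb_trust_ratio(param, grad, m_hat, v_hat, eps, weight_decay):
        """||w|| / ||adam-style update||, guarded against zero norms and clipped to [0, 10]."""
        w_norm = np.linalg.norm(param)
        g_norm = np.linalg.norm(grad)
        g_norm_hat = np.linalg.norm(m_hat / np.sqrt(v_hat + eps) + weight_decay * param)
        ratio = w_norm / g_norm_hat if w_norm > 0.0 and g_norm > 0.0 else 1.0
        return float(np.clip(ratio, 0.0, 10.0))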
@@ -238,7 +237,7 @@ class Lamb(Optimizer):
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

     Outputs:
-        tuple[Parameter], the updated velocity value, the shape is the same as `params`.
+        tuple[bool], all elements are True.

     Examples:
         >>> net = Net()
@@ -311,18 +310,21 @@ class Lamb(Optimizer):
                                           self.warmup_steps, self.global_step), mstype.float32)
             lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
         if self.enable_graph_kernel:
-            updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel,
-                                                        self.beta1, self.beta2, self.eps, lr,
-                                                        self.weight_decay_tensor, self.global_step),
-                                              self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+            optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
+                                                    self.beta1, self.beta2, self.eps, lr,
+                                                    self.weight_decay_tensor, self.global_step),
+                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
         else:
-            updated_velocity = self.hyper_map(F.partial(_lamb_opt,
-                                                        self.beta1, self.beta2, self.eps, lr,
-                                                        self.weight_decay_tensor, self.global_step),
-                                              self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+            optim_result = self.hyper_map(F.partial(_lamb_opt,
+                                                    self.beta1, self.beta2, self.eps, lr,
+                                                    self.weight_decay_tensor, self.global_step),
+                                          self.params, self.moments1, self.moments2, gradients,
+                                          self.decay_flag, self.optim_filter)
+        if self.use_parallel:
+            optim_result = self.broadcast_params(optim_result)

         added_global_step = self.global_step + self.one
         F.control_depend(lr, added_global_step)
         self.global_step = added_global_step
-        return updated_velocity
+        return optim_result
@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
+from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
-from mindspore.common.tensor import Tensor
 from mindspore import log as logger
+from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.train.parallel_utils import ParallelMode

 __all__ = ['Optimizer']
@@ -155,6 +158,27 @@ class Optimizer(Cell):
         self.param_length = len(self.parameters)
         self.map_ = C.Map()

+        use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
+        self.use_parallel = use_parallel
+        if use_parallel:
+            if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
+                raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
+            if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
+                                            ParallelMode.AUTO_PARALLEL]:
+                raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format(
+                    _get_parallel_mode()))
+            self.dev_num = _get_device_num()
+            if self.dev_num > self.param_length:
+                raise RuntimeError("Optimizer segmentation cannot be applied when the number of parameters {} is"
+                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
+            self.param_rank = self._get_parameter_group_id()
+            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
+            self.param_names = []
+            for param in self.parameters:
+                self.param_names.append(param.name)
+        else:
+            self.optim_filter = (True,) * self.param_length
+
     def decay_weight(self, gradients):
         """
         Weight decay.
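Note: the ownership layout built here is easiest to see with concrete numbers. A sketch with hypothetical values (six parameters, four devices, viewed from rank 1):

    param_length, dev_num, global_rank = 6, 4, 1   # hypothetical values

    # Round-robin assignment, matching _get_parameter_group_id() below.
    param_rank = tuple(i % dev_num for i in range(param_length))   # (0, 1, 2, 3, 0, 1)

    # Each device runs the optimizer only for the parameters it owns; for the
    # rest, _adam_opt/_lamb_opt fall through and return the gradient unchanged.
    optim_filter = tuple(r == global_rank for r in param_rank)     # (False, True, False, False, False, True)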
@@ -401,6 +425,51 @@ class Optimizer(Cell):
             lr = self.learning_rate
         return lr

+    def _get_parameter_group_id(self):
+        """
+        Get the parameter partition group id of each parameter; every id is less than the number of devices.
+
+        Returns:
+            tuple, the group id tuple of parameters.
+        """
+        rank_list = ()
+        count = 0
+        for _ in range(self.param_length):
+            rank_list = rank_list + (count,)
+            count = count + 1
+            if count == self.dev_num:
+                count = 0
+        return rank_list
+
+    def broadcast_params(self, optim_result):
+        """
+        Apply Broadcast operations in the sequential order of parameter groups.
+
+        Args:
+            optim_result (tuple): The results returned by the per-parameter optimizer updates.
+
+        Returns:
+            bool, the status flag.
+        """
+        param_group = []
+        key_group = []
+        for _ in range(self.dev_num):
+            param_group.append(F.make_tuple())
+            key_group.append(F.make_tuple())
+        for i in range(self.param_length):
+            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
+            key = P.MakeRefKey(self.param_names[i])()
+            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
+        new_param_group = []
+        for root in range(self.dev_num):
+            ops = P.Broadcast(root)
+            next_params = ops(param_group[root])
+            new_param_group.append(next_params)
+            for i in range(F.tuple_len(next_params)):
+                F.assign(key_group[root][i], next_params[i])
+        status = True
+        for i in range(self.dev_num - 1):
+            status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
+        return status
+
     def construct(self, *hyper_params):
         raise NotImplementedError
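Note: the regrouping done by broadcast_params can be sketched in plain Python (hypothetical values; the actual P.Broadcast launch and RefKey assignment are omitted). Each rank's updated parameters are gathered into one tuple and broadcast with that rank as root, so every device ends up with the full updated parameter set:

    param_rank = (0, 1, 0, 1, 0)                   # five parameters on two devices
    optim_result = ["p0", "p1", "p2", "p3", "p4"]  # per-parameter update results
    dev_num = 2

    param_group = [[] for _ in range(dev_num)]
    for i, result in enumerate(optim_result):
        param_group[param_rank[i]].append(result)
    # param_group[0] == ['p0', 'p2', 'p4']  -> P.Broadcast(root=0)
    # param_group[1] == ['p1', 'p3']        -> P.Broadcast(root=1)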
@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
     return F.list_setitem(data, number_index, value)


+@setitem.register("List", "Number", "Tuple")
+def _list_setitem_with_Tuple(data, number_index, value):
+    """
+    Assigns value to list.
+
+    Inputs:
+        data (list): Data of type list.
+        number_index (Number): Index of data.
+        value (tuple): Value given.
+
+    Outputs:
+        list, type is the same as the element type of data.
+    """
+    return F.list_setitem(data, number_index, value)
+
+
 @setitem.register("Dictionary", "String", "Tensor")
 def _dict_setitem_with_tensor(data, key, value):
     """
@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
         self.op = op
         self.add_prim_attr('group', _get_group(group))
         self.add_prim_attr('fusion', 0)
+        self.add_prim_attr('index', 0)

     def vm_impl(self, x):
         """Implement by vm mode."""
@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
         return variable

     def infer_dtype(self, variable, value):
-        args = {"variable": variable, "value": value}
-        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
+        # Add a type validation later when we don't have to assign a value to RefKey.
         return variable
@@ -400,6 +400,23 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_global_rank_is_set()

+    def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
+        """
+        Set enable/disable parallel optimizer.
+
+        Args:
+            enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
+        """
+        self.check_context_handle()
+        if not isinstance(enable_parallel_optimizer, bool):
+            raise TypeError("enable_parallel_optimizer must be a bool")
+        self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
+
+    def get_enable_parallel_optimizer(self):
+        """Get parallel optimizer flag."""
+        self.check_context_handle()
+        return self._context_handle.get_enable_parallel_optimizer()
+
     def reset(self):
         """Reset all settings."""
         self.check_context_handle()
@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
     "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
-    "full_batch": auto_parallel_context().set_full_batch}
+    "full_batch": auto_parallel_context().set_full_batch,
+    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}

 _get_auto_parallel_context_func_map = {
@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
     "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
-    "full_batch": auto_parallel_context().get_full_batch}
+    "full_batch": auto_parallel_context().get_full_batch,
+    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}


 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
-                 strategy_ckpt_save_file=str, full_batch=bool)
+                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
 def _set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
+        enable_parallel_optimizer (bool): Whether to enable optimizer segmentation. Default: False.

     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ""
     - strategy_ckpt_save_file: ""
+    - enable_parallel_optimizer: False.
     """
     auto_parallel_context().reset()
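Note: with the key maps and the type check extended, the flag should also be reachable through the kwargs interface. A short sketch (assuming the public context.set_auto_parallel_context forwards these kwargs to _set_auto_parallel_context, as it does for the other keys):

    from mindspore import context

    context.set_auto_parallel_context(enable_parallel_optimizer=True)
    assert context.get_auto_parallel_context("enable_parallel_optimizer") is True
    context.reset_auto_parallel_context()   # restores the False default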
@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
     } else if (name == "instance_name") {
       parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
       ASSERT_EQ(converted_ret->ToString(), "test");
+    } else if (name == "index") {
+      parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
+      ASSERT_EQ(converted_ret->ToString(), "0");
     } else {
       MS_LOG(EXCEPTION) << "Test failed";
     }
@@ -16,7 +16,6 @@
 test assign sub
 """
 import numpy as np
-import pytest

 import mindspore.context as context
 import mindspore.nn as nn
@@ -36,27 +35,6 @@ class AssignW(nn.Cell):
         return x


-class Net(nn.Cell):
-    def __init__(self):
-        super(Net, self).__init__()
-        self.b = Parameter(initializer('ones', [5]), name='b')
-        self.assign = AssignW()
-
-    def construct(self, value):
-        return self.assign(self.b, value)
-
-
-def test_assign_through_cell():
-    context.set_context(mode=context.GRAPH_MODE)
-    net = Net()
-    net.to_float(ms.float16)
-    net.add_flags_recursive(fp16=False)
-    input_data = Tensor(np.ones([5]).astype(np.float32))
-    net(input_data)
-    with pytest.raises(TypeError):
-        net(None)
-
-
 class AssignOp(nn.Cell):
     def __init__(self):
         super(AssignOp, self).__init__()
@@ -0,0 +1,114 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test adam """
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import _executor
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
+from mindspore.ops import operations as P
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore import context
+
+
+class Net(nn.Cell):
+    """Net definition"""
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.fc1 = nn.Dense(128, 768, activation='relu')
+        self.fc2 = nn.Dense(128, 768, activation='relu')
+        self.fc3 = nn.Dense(128, 768, activation='relu')
+        self.fc4 = nn.Dense(768, 768, activation='relu')
+        self.relu4 = nn.ReLU()
+        self.relu5 = nn.ReLU()
+        self.transpose = P.Transpose()
+        self.matmul1 = P.MatMul()
+        self.matmul2 = P.MatMul()
+
+    def construct(self, x):
+        q = self.fc1(x)
+        k = self.fc2(x)
+        v = self.fc3(x)
+        k = self.transpose(k, (1, 0))
+        c = self.relu4(self.matmul1(q, k))
+        s = self.relu5(self.matmul2(c, v))
+        s = self.fc4(s)
+        return s
+
+
+def test_AdamWeightDecayDynamicLR():
+    """ test_AdamWeightDecayDynamicLR """
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_AdamWeightDecay():
+    """ test_AdamWeightDecay """
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_lamb_compile():
+    """ test_lamb_compile """
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Lamb(net.trainable_params(), decay_steps=10)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+
+
+def test_edge_case():
+    """ test_edge_case """
+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    net = Net()
+    with pytest.raises(RuntimeError):
+        context.set_auto_parallel_context(parallel_mode="stand_alone")
+        Lamb(net.trainable_params(), decay_steps=10)
+    with pytest.raises(RuntimeError):
+        Adam(net.trainable_params(), learning_rate=0.1)
+    with pytest.raises(RuntimeError):
+        context.set_auto_parallel_context(device_num=16)
+        Lamb(net.trainable_params(), decay_steps=10)
@@ -81,6 +81,10 @@ def test_set_auto_parallel_context():
     with pytest.raises(ValueError):
         set_algo_parameters(tensor_slice_align_size=1025)

+    auto_parallel_context().set_enable_parallel_optimizer(True)
+    assert auto_parallel_context().get_enable_parallel_optimizer() is True
+    assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
+

 def test_reset_auto_parallel_context():
     context.reset_auto_parallel_context()