Merge pull request !208 from vlne-v1/I1D33P-auto-mix-precision-eval-supporttags/v0.2.0-alpha
| @@ -636,6 +636,15 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| // Dealing with the RefKey case | |||
| auto refkeys = cnode_with_refkeys.second; | |||
| auto cnode = cnode_with_refkeys.first; | |||
| auto cnode_ptr = cnode->cast<CNodePtr>(); | |||
| if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) { | |||
| continue; | |||
| } | |||
| if (!IsAutoParallelCareNode(cnode_ptr)) { | |||
| continue; | |||
| } | |||
| if (refkeys.size() > 1) { | |||
| MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys."; | |||
| } | |||
| @@ -1235,10 +1235,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||
| Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`. | |||
| Examples: | |||
| >>> input_x = [1, 2, 3, 4] | |||
| >>> segment_ids = [0, 0, 1, 2] | |||
| >>> input_x = Tensor([1, 2, 3, 4], mindspore.float) | |||
| >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32) | |||
| >>> num_segments = 4 | |||
| >>> type = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments) | |||
| >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments) | |||
| [3, 3, 4, 0] | |||
| """ | |||
| @prim_attr_register | |||
| @@ -22,6 +22,8 @@ from functools import reduce | |||
| import numpy as np | |||
| from ... import context | |||
| from ..._c_expression import signature_rw as sig_rw | |||
| from ..._c_expression import signature_kind as sig_kind | |||
| from ..._checkparam import ParamValidator as validator | |||
| from ..._checkparam import Rel, check_bool, check_int_positive | |||
| from ...common import dtype as mstype | |||
| @@ -1297,29 +1299,31 @@ class ApplyMomentum(PrimitiveWithInfer): | |||
| filter(lambda x: x.requires_grad, net.get_parameters())) | |||
| >>> model = Model(net, loss, opt) | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD), | |||
| ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD), | |||
| ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD), | |||
| ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD), | |||
| ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD) | |||
| ) | |||
| @prim_attr_register | |||
| def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0): | |||
| self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'], | |||
| outputs=['output']) | |||
| def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape): | |||
| validator.check(f'variable shape {v_shape}', len(v_shape), '', 0, Rel.GT) | |||
| validator.check(f'accumulation shape {a_shape}', len(a_shape), '', 0, Rel.GT) | |||
| validator.check(f'learning rate shape {l_shape}', len(l_shape), '', 0, Rel.GE) | |||
| validator.check(f'gradient shape {g_shape}', len(g_shape), '', 0, Rel.GE) | |||
| validator.check(f'momentum shape {m_shape}', len(m_shape), '', 0, Rel.GE) | |||
| return v_shape | |||
| def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): | |||
| validator.check_subclass("v_dtype", v_dtype, mstype.tensor) | |||
| validator.check_subclass("a_dtype", a_dtype, mstype.tensor) | |||
| v_type = validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64]) | |||
| validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64]) | |||
| if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey: | |||
| validator.check_subclass("v_dtype", v_dtype, mstype.tensor) | |||
| validator.check_subclass("a_dtype", a_dtype, mstype.tensor) | |||
| validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64]) | |||
| validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64]) | |||
| validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64]) | |||
| validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64]) | |||
| validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64]) | |||
| return v_type | |||
| return g_dtype | |||
| class SmoothL1Loss(PrimitiveWithInfer): | |||
| @@ -82,6 +82,29 @@ def _check_kwargs(key_words): | |||
| if loss_scale_manager: | |||
| validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager) | |||
| def _add_loss_network(network, loss_fn, cast_model_type): | |||
| class WithLossCell(nn.Cell): | |||
| "Wrap loss for amp. Cast network output back to float32" | |||
| def __init__(self, backbone, loss_fn): | |||
| super(WithLossCell, self).__init__(auto_prefix=False) | |||
| self._backbone = backbone | |||
| self._loss_fn = loss_fn | |||
| def construct(self, data, label): | |||
| out = self._backbone(data) | |||
| label = _mp_cast_helper(mstype.float32, label) | |||
| return self._loss_fn(F.cast(out, mstype.float32), label) | |||
| validator.check_isinstance('loss_fn', loss_fn, nn.Cell) | |||
| if cast_model_type == mstype.float16: | |||
| network = WithLossCell(network, loss_fn) | |||
| else: | |||
| network = nn.WithLossCell(network, loss_fn) | |||
| return network | |||
| def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): | |||
| """ | |||
| Build the mixed precision training cell automatically. | |||
| @@ -117,24 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): | |||
| _do_keep_batchnorm_fp32(network) | |||
| if loss_fn: | |||
| class WithLossCell(nn.Cell): | |||
| "Wrap loss for amp. Cast network output back to float32" | |||
| def __init__(self, backbone, loss_fn): | |||
| super(WithLossCell, self).__init__(auto_prefix=False) | |||
| self._backbone = backbone | |||
| self._loss_fn = loss_fn | |||
| def construct(self, data, label): | |||
| out = self._backbone(data) | |||
| label = _mp_cast_helper(mstype.float32, label) | |||
| return self._loss_fn(F.cast(out, mstype.float32), label) | |||
| validator.check_isinstance('loss_fn', loss_fn, nn.Cell) | |||
| if config.cast_model_type == mstype.float16: | |||
| network = WithLossCell(network, loss_fn) | |||
| else: | |||
| network = nn.WithLossCell(network, loss_fn) | |||
| network = _add_loss_network(network, loss_fn, config.cast_model_type) | |||
| if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
| network = _VirtualDatasetCell(network) | |||
| @@ -24,8 +24,7 @@ from .. import context | |||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper | |||
| from ..nn.metrics import Loss | |||
| from ..nn.wrap import WithLossCell, WithEvalCell, \ | |||
| DataWrapper | |||
| from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell | |||
| from ..nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from .parallel_utils import ParallelMode | |||
| from ..common import dtype as mstype | |||
| @@ -151,7 +150,10 @@ class Model: | |||
| else: | |||
| if self._loss_fn is None: | |||
| raise ValueError("loss_fn can not be None.") | |||
| self._eval_network = WithEvalCell(self._network, self._loss_fn) | |||
| if self._optimizer: | |||
| self._eval_network = self._train_network.network | |||
| else: | |||
| self._eval_network = WithEvalCell(self._network, self._loss_fn) | |||
| self._eval_indexes = [0, 1, 2] | |||
| def _clear_metrics(self): | |||
| @@ -21,47 +21,6 @@ from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore import Parameter, ParameterTuple | |||
| run_opt = C.MultitypeFuncGraph("run_opt") | |||
| # pylint: disable=unused-argument | |||
| @run_opt.register("Function", "Int", "Number", "Number", | |||
| "Tensor", "Tensor", "Tensor") | |||
| def tensor_run_opt(opt, iterator, learning_rate, momentum, | |||
| gradient, variable, moment): | |||
| success = True | |||
| new_weight = opt(gradient, moment, variable, learning_rate, momentum) | |||
| success = F.depend(success, P.Assign()(variable, new_weight)) | |||
| return success | |||
| class OptimizerByMomentum(nn.Cell): | |||
| """ | |||
| OptimizerByMomentum definition | |||
| """ | |||
| # list of tensor | |||
| def __init__(self, weights): | |||
| super(OptimizerByMomentum, self).__init__() | |||
| self.learning_rate = Parameter(0.1, name="learning_rate") | |||
| self.momentum = Parameter(0.05, name="momentum") | |||
| self.iter = Parameter(0, name="iter") | |||
| self.weights = weights | |||
| self.moments = weights.clone(prefix="moments", init='zeros') | |||
| self.hyper_map = C.HyperMap() | |||
| self.opt = P.ApplyMomentum() | |||
| def construct(self, grads): | |||
| success = True | |||
| weights = self.weights | |||
| moments = self.moments | |||
| success = self.hyper_map( | |||
| F.partial(run_opt, self.opt, self.iter, | |||
| self.learning_rate, self.momentum), grads, weights, moments) | |||
| # self.learning_rate = updata_lr(self.learning_rate, self.momentum) | |||
| return success | |||
| class TrainStepWrap(nn.Cell): | |||
| """ | |||
| TrainStepWrap definition | |||
| @@ -71,7 +30,7 @@ class TrainStepWrap(nn.Cell): | |||
| self.network = network | |||
| self.network.set_train() | |||
| self.weights = ParameterTuple(network.trainable_params()) | |||
| self.optimizer = OptimizerByMomentum(self.weights) | |||
| self.optimizer = nn.Momentum(self.weights, 0.1, 0.9) | |||
| self.hyper_map = C.HyperMap() | |||
| self.grad = C.GradOperation('grad', get_by_list=True) | |||
| @@ -107,7 +66,7 @@ class TrainStepWrap2(nn.Cell): | |||
| self.network = network | |||
| self.network.set_train() | |||
| self.weights = ParameterTuple(network.get_parameters()) | |||
| self.optimizer = OptimizerByMomentum(self.weights) | |||
| self.optimizer = nn.Momentum(self.weights, 0.1, 0.9) | |||
| self.hyper_map = C.HyperMap() | |||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||