Merge pull request !208 from vlne-v1/I1D33P-auto-mix-precision-eval-support
tags/v0.2.0-alpha
@@ -636,6 +636,15 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
     // Dealing with the RefKey case
     auto refkeys = cnode_with_refkeys.second;
     auto cnode = cnode_with_refkeys.first;
+
+    auto cnode_ptr = cnode->cast<CNodePtr>();
+    if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
+      continue;
+    }
+    if (!IsAutoParallelCareNode(cnode_ptr)) {
+      continue;
+    }
+
     if (refkeys.size() > 1) {
       MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys.";
     }
@@ -1235,10 +1235,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.

     Examples:
-        >>> input_x = [1, 2, 3, 4]
-        >>> segment_ids = [0, 0, 1, 2]
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
+        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
         >>> num_segments = 4
-        >>> type = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
+        >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
+        [3, 3, 4, 0]
     """

     @prim_attr_register
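The updated doctest corresponds to the following runnable check (a minimal sketch; the imports are assumed from the public API, and tensor print formatting may differ slightly):

import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

input_x = Tensor([1, 2, 3, 4], mindspore.float32)
segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
# Elements 0 and 1 map to segment 0 (1 + 2 = 3), element 2 to segment 1 (3),
# element 3 to segment 2 (4); segment 3 receives no elements and stays 0.
out = P.UnsortedSegmentSum()(input_x, segment_ids, 4)
print(out)  # [3. 3. 4. 0.]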
@@ -22,6 +22,8 @@ from functools import reduce
 import numpy as np

 from ... import context
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
 from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel, check_bool, check_int_positive
 from ...common import dtype as mstype
@@ -1297,29 +1299,31 @@ class ApplyMomentum(PrimitiveWithInfer):
                                filter(lambda x: x.requires_grad, net.get_parameters()))
         >>> model = Model(net, loss, opt)
     """
+    __mindspore_signature__ = (
+        ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
+    )

     @prim_attr_register
     def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])

     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
-        validator.check(f'variable shape {v_shape}', len(v_shape), '', 0, Rel.GT)
-        validator.check(f'accumulation shape {a_shape}', len(a_shape), '', 0, Rel.GT)
-        validator.check(f'learning rate shape {l_shape}', len(l_shape), '', 0, Rel.GE)
-        validator.check(f'gradient shape {g_shape}', len(g_shape), '', 0, Rel.GE)
-        validator.check(f'momentum shape {m_shape}', len(m_shape), '', 0, Rel.GE)
         return v_shape

     def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
-        validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
-        validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
-        v_type = validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
-        validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
+        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
+            validator.check_subclass("v_dtype", v_dtype, mstype.tensor)
+            validator.check_subclass("a_dtype", a_dtype, mstype.tensor)
+            validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64])
+            validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64])
         validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64])
-        return v_type
+        return g_dtype


 class SmoothL1Loss(PrimitiveWithInfer):
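The relaxed infer_dtype lets `variable` and `accumulation` arrive as RefKeys (write-through references to Parameters, matching the new RW_WRITE signature entries) instead of plain tensors, with the result dtype now following the gradient. For reference, ApplyMomentum performs the conventional momentum update; a NumPy sketch of the assumed semantics (use_nesterov and use_locking ignored):

import numpy as np

# accumulation <- accumulation * momentum + gradient
# variable     <- variable - learning_rate * accumulation
def apply_momentum_ref(variable, accumulation, learning_rate, gradient, momentum):
    accumulation *= momentum
    accumulation += gradient
    variable -= learning_rate * accumulation
    return variable

var = np.array([1.0, 2.0], dtype=np.float32)
acc = np.zeros_like(var)
grad = np.array([0.5, 0.5], dtype=np.float32)
print(apply_momentum_ref(var, acc, 0.1, grad, 0.9))  # [0.95 1.95]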
@@ -82,6 +82,29 @@ def _check_kwargs(key_words):
     if loss_scale_manager:
         validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager)


+def _add_loss_network(network, loss_fn, cast_model_type):
+    class WithLossCell(nn.Cell):
+        "Wrap loss for amp. Cast network output back to float32"
+        def __init__(self, backbone, loss_fn):
+            super(WithLossCell, self).__init__(auto_prefix=False)
+            self._backbone = backbone
+            self._loss_fn = loss_fn
+
+        def construct(self, data, label):
+            out = self._backbone(data)
+            label = _mp_cast_helper(mstype.float32, label)
+            return self._loss_fn(F.cast(out, mstype.float32), label)
+
+    validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
+    if cast_model_type == mstype.float16:
+        network = WithLossCell(network, loss_fn)
+    else:
+        network = nn.WithLossCell(network, loss_fn)
+    return network
+
+
 def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
     """
     Build the mixed precision training cell automatically.

@@ -117,24 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         _do_keep_batchnorm_fp32(network)

     if loss_fn:
-        class WithLossCell(nn.Cell):
-            "Wrap loss for amp. Cast network output back to float32"
-            def __init__(self, backbone, loss_fn):
-                super(WithLossCell, self).__init__(auto_prefix=False)
-                self._backbone = backbone
-                self._loss_fn = loss_fn
-
-            def construct(self, data, label):
-                out = self._backbone(data)
-                label = _mp_cast_helper(mstype.float32, label)
-                return self._loss_fn(F.cast(out, mstype.float32), label)
-
-        validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
-        if config.cast_model_type == mstype.float16:
-            network = WithLossCell(network, loss_fn)
-        else:
-            network = nn.WithLossCell(network, loss_fn)
+        network = _add_loss_network(network, loss_fn, config.cast_model_type)

     if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
         network = _VirtualDatasetCell(network)
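With the wrapper factored out into _add_loss_network, building a mixed-precision training cell looks like this (a usage sketch; LeNet5 is a stand-in for any user-defined Cell, and the import path of build_train_network is assumed from the file shown in this hunk):

import mindspore.nn as nn
from mindspore.train.amp import build_train_network

net = LeNet5()  # hypothetical network
loss = nn.SoftmaxCrossEntropyWithLogits()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
# level='O2' casts the network to float16 (keeping batchnorm in float32);
# _add_loss_network then keeps the loss computation itself in float32.
train_net = build_train_network(net, opt, loss, level='O2')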
@@ -24,8 +24,7 @@ from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
 from ..nn.metrics import Loss
-from ..nn.wrap import WithLossCell, WithEvalCell, \
-    DataWrapper
+from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from .parallel_utils import ParallelMode
 from ..common import dtype as mstype
@@ -151,7 +150,10 @@ class Model:
         else:
             if self._loss_fn is None:
                 raise ValueError("loss_fn can not be None.")
-            self._eval_network = WithEvalCell(self._network, self._loss_fn)
+            if self._optimizer:
+                self._eval_network = self._train_network.network
+            else:
+                self._eval_network = WithEvalCell(self._network, self._loss_fn)
             self._eval_indexes = [0, 1, 2]

     def _clear_metrics(self):
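This hunk is the eval-support core of the PR: when an optimizer is present, evaluation reuses the already-cast network inside the train wrapper rather than building a fresh float32 WithEvalCell, so eval runs under the same precision configuration as training. A sketch, with hypothetical dataset objects:

model = Model(net, loss_fn=loss, optimizer=opt, metrics={"loss"})
model.train(1, train_dataset)
model.eval(eval_dataset)  # now evaluates via self._train_network.network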
@@ -21,47 +21,6 @@ from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore import Parameter, ParameterTuple

-run_opt = C.MultitypeFuncGraph("run_opt")
-
-
-# pylint: disable=unused-argument
-@run_opt.register("Function", "Int", "Number", "Number",
-                  "Tensor", "Tensor", "Tensor")
-def tensor_run_opt(opt, iterator, learning_rate, momentum,
-                   gradient, variable, moment):
-    success = True
-    new_weight = opt(gradient, moment, variable, learning_rate, momentum)
-    success = F.depend(success, P.Assign()(variable, new_weight))
-    return success
-
-
-class OptimizerByMomentum(nn.Cell):
-    """
-    OptimizerByMomentum definition
-    """
-    # list of tensor
-    def __init__(self, weights):
-        super(OptimizerByMomentum, self).__init__()
-        self.learning_rate = Parameter(0.1, name="learning_rate")
-        self.momentum = Parameter(0.05, name="momentum")
-        self.iter = Parameter(0, name="iter")
-        self.weights = weights
-        self.moments = weights.clone(prefix="moments", init='zeros')
-        self.hyper_map = C.HyperMap()
-        self.opt = P.ApplyMomentum()
-
-    def construct(self, grads):
-        success = True
-        weights = self.weights
-        moments = self.moments
-        success = self.hyper_map(
-            F.partial(run_opt, self.opt, self.iter,
-                      self.learning_rate, self.momentum), grads, weights, moments)
-        # self.learning_rate = updata_lr(self.learning_rate, self.momentum)
-        return success
-

 class TrainStepWrap(nn.Cell):
     """
     TrainStepWrap definition
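The deleted block hand-rolled a momentum optimizer from P.ApplyMomentum plus a MultitypeFuncGraph; with the signature/RefKey support added above, the built-in cell now suffices. A sketch of the equivalent construction (network is a hypothetical Cell; the hyper-parameters match the replacement lines below):

from mindspore import nn, ParameterTuple

weights = ParameterTuple(network.trainable_params())
optimizer = nn.Momentum(weights, learning_rate=0.1, momentum=0.9)
# Calling optimizer(grads) applies the same ApplyMomentum update that the
# deleted run_opt / OptimizerByMomentum pair implemented by hand.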
@@ -71,7 +30,7 @@ class TrainStepWrap(nn.Cell):
         self.network = network
         self.network.set_train()
         self.weights = ParameterTuple(network.trainable_params())
-        self.optimizer = OptimizerByMomentum(self.weights)
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
         self.hyper_map = C.HyperMap()
         self.grad = C.GradOperation('grad', get_by_list=True)
@@ -107,7 +66,7 @@ class TrainStepWrap2(nn.Cell):
         self.network = network
         self.network.set_train()
         self.weights = ParameterTuple(network.get_parameters())
-        self.optimizer = OptimizerByMomentum(self.weights)
+        self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
         self.hyper_map = C.HyperMap()
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.sens = sens
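With sens_param=True the gradient seed is supplied by the caller instead of defaulting to ones; the construct of such a wrapper typically looks like the following sketch (assumed, the diff does not show it):

def construct(self, x, sens):
    weights = self.weights
    # 'sens' seeds the backward pass in place of an all-ones gradient,
    # which is how a loss-scale factor is injected during training.
    grads = self.grad(self.network, weights)(x, sens)
    return self.optimizer(grads)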