| @@ -14,15 +14,23 @@ | |||
| # ============================================================================ | |||
| """Cell_wrapper.""" | |||
| import copy | |||
| import numpy as np | |||
| from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, | |||
| _get_parallel_mode) | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean | |||
| from ...ops import composite as C, functional as F, operations as P | |||
| from ...common import Tensor, dtype as mstype | |||
| from ..cell import Cell | |||
| from ...common import Tensor | |||
| from ...common import dtype as mstype | |||
| from ...common.initializer import initializer | |||
| from ...common.parameter import Parameter, ParameterTuple | |||
| from ...ops import composite as C | |||
| from ...ops import functional as F | |||
| from ...ops import operations as P | |||
| from ...ops.composite.base import _mp_cast_helper | |||
| from ...ops.operations.comm_ops import _VirtualDataset | |||
| from ..cell import Cell | |||
| from .grad_reducer import DistributedGradReducer | |||
| @@ -310,8 +318,8 @@ class WithEvalCell(Cell): | |||
| def construct(self, data, label): | |||
| outputs = self._network(data) | |||
| loss = self._loss_fn(outputs, label) | |||
| label = _mp_cast_helper(mstype.float32, label) | |||
| loss = self._loss_fn(F.cast(outputs, mstype.float32), label) | |||
| return loss, outputs, label | |||
| @@ -24,7 +24,7 @@ from .. import context | |||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | |||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper | |||
| from ..nn.metrics import Loss | |||
| from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell | |||
| from .. import nn | |||
| from ..nn.wrap.cell_wrapper import _VirtualDatasetCell | |||
| from .parallel_utils import ParallelMode | |||
| from ..common import dtype as mstype | |||
| @@ -130,7 +130,7 @@ class Model: | |||
| self._loss_fn, | |||
| level=self._amp_level) | |||
| elif self._loss_fn: | |||
| network = WithLossCell(network, self._loss_fn) | |||
| network = nn.WithLossCell(network, self._loss_fn) | |||
| # If need to check if loss_fn is not None, but optimizer is None | |||
| return network | |||
| @@ -150,10 +150,7 @@ class Model: | |||
| else: | |||
| if self._loss_fn is None: | |||
| raise ValueError("loss_fn can not be None.") | |||
| if self._optimizer: | |||
| self._eval_network = self._train_network.network | |||
| else: | |||
| self._eval_network = WithEvalCell(self._network, self._loss_fn) | |||
| self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) | |||
| self._eval_indexes = [0, 1, 2] | |||
| def _clear_metrics(self): | |||
| @@ -263,7 +260,7 @@ class Model: | |||
| dataset_helper = DatasetHelper(train_dataset) | |||
| # remove later to deal with loop sink | |||
| if need_wrap: | |||
| self._train_network = DataWrapper(self._train_network, *(dataset_helper.types_shapes()), | |||
| self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()), | |||
| train_dataset.__ME_INITED__) | |||
| cb_params.train_network = self._train_network | |||
| self._train_network.set_train() | |||
| @@ -429,7 +426,7 @@ class Model: | |||
| # remove later to deal with loop sink | |||
| if need_wrap: | |||
| self._eval_network = DataWrapper(self._eval_network, *(dataset_helper.types_shapes()), | |||
| self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()), | |||
| valid_dataset.__ME_INITED__) | |||
| self._eval_network.set_train(mode=False) | |||
| self._eval_network.phase = 'eval' | |||
| @@ -14,12 +14,15 @@ | |||
| # ============================================================================ | |||
| """ auto mixed precision """ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import amp | |||
| from mindspore import nn | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| import mindspore.context as context | |||
| from mindspore.model_zoo.resnet import resnet50 | |||
| from mindspore.train import Model | |||
| from ....dataset_mock import MindData | |||
| def setup_module(module): | |||
| @@ -85,3 +88,52 @@ def test_amp_o0_loss(): | |||
| optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| train_network = amp.build_train_network(net, optimizer, loss) | |||
| output = train_network(inputs, label) | |||
| class MindDataSet(MindData): | |||
| def __init__(self, dataset_types, dataset_shapes): | |||
| super(MindDataSet, self).__init__(size=2, batch_size=32, | |||
| np_types=dataset_types, | |||
| output_shapes=dataset_shapes, | |||
| input_indexs=(0, 1)) | |||
| def __next__(self): | |||
| if self._size < self._iter_num: | |||
| raise StopIteration | |||
| self._iter_num += 1 | |||
| next = [] | |||
| for shape, type in zip(self._output_shapes, self._np_types): | |||
| next.append(Tensor(np.ones(shape).astype(type))) | |||
| return tuple(next) | |||
| def test_compile_model_train_O0(): | |||
| dataset_types = (np.float32, np.float32) | |||
| dataset_shapes = ((16, 16), (16, 16)) | |||
| dataset = MindDataSet(dataset_types, dataset_shapes) | |||
| net = NetNoLoss(16, 16) | |||
| loss = nn.MSELoss() | |||
| optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O0") | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||
| with pytest.raises(ValueError): | |||
| # not actual run, the metrics step will fail, check if compile ok. | |||
| model.eval(dataset) | |||
| def test_compile_model_train_O2(): | |||
| dataset_types = (np.float32, np.float32) | |||
| dataset_shapes = ((16, 16), (16, 16)) | |||
| dataset = MindDataSet(dataset_types, dataset_shapes) | |||
| net = NetNoLoss(16, 16) | |||
| loss = nn.MSELoss() | |||
| optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2") | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||
| with pytest.raises(ValueError): | |||
| # not actual run, the metrics step will fail, check if compile ok. | |||
| model.eval(dataset) | |||