
Fix the output not being a tuple during eval

tags/v0.2.0-alpha
Wei Luning, 5 years ago
parent commit e1c8f248e0
3 changed files with 71 additions and 14 deletions:
  1. mindspore/nn/wrap/cell_wrapper.py   +14 -6
  2. mindspore/train/model.py            +5  -8
  3. tests/ut/python/train/test_amp.py   +52 -0

mindspore/nn/wrap/cell_wrapper.py   +14 -6

@@ -14,15 +14,23 @@
 # ============================================================================
 """Cell_wrapper."""
 import copy

 import numpy as np

+from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
+                                       _get_parallel_mode)
 from mindspore.train.parallel_utils import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
-from ...ops import composite as C, functional as F, operations as P
-from ...common import Tensor, dtype as mstype
-from ..cell import Cell
+
+from ...common import Tensor
+from ...common import dtype as mstype
 from ...common.initializer import initializer
 from ...common.parameter import Parameter, ParameterTuple
+from ...ops import composite as C
+from ...ops import functional as F
+from ...ops import operations as P
+from ...ops.composite.base import _mp_cast_helper
 from ...ops.operations.comm_ops import _VirtualDataset
+from ..cell import Cell
 from .grad_reducer import DistributedGradReducer

@@ -310,8 +318,8 @@ class WithEvalCell(Cell):

     def construct(self, data, label):
         outputs = self._network(data)
-        loss = self._loss_fn(outputs, label)
+        label = _mp_cast_helper(mstype.float32, label)
+        loss = self._loss_fn(F.cast(outputs, mstype.float32), label)
         return loss, outputs, label
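The substance of the fix is in the second hunk: WithEvalCell now casts the network outputs (and, via _mp_cast_helper, the label) to float32 before computing the loss, and keeps returning the full (loss, outputs, label) tuple. Below is a minimal pure-Python sketch of that contract using numpy stand-ins rather than real MindSpore cells; all names in it (eval_step, net, mse) are illustrative, not part of the commit.

import numpy as np

def eval_step(network, loss_fn, data, label):
    # Contract of the fixed WithEvalCell: cast to float32 before the
    # loss so fp16 (mixed-precision) outputs do not skew the value,
    # and always return a (loss, outputs, label) tuple.
    outputs = network(data)
    loss = loss_fn(outputs.astype(np.float32), label.astype(np.float32))
    return loss, outputs, label

# Toy stand-ins: an fp16 "network" and an MSE loss.
net = lambda x: (x * 2).astype(np.float16)
mse = lambda y, t: float(np.mean((y - t) ** 2))

result = eval_step(net, mse,
                   np.ones((2, 3), np.float16),
                   np.full((2, 3), 2.0, np.float16))
print(isinstance(result, tuple), len(result))  # True 3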






mindspore/train/model.py   +5 -8

@@ -24,7 +24,7 @@ from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
from ..nn.metrics import Loss from ..nn.metrics import Loss
from ..nn.wrap import WithLossCell, DataWrapper, WithEvalCell
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode from .parallel_utils import ParallelMode
from ..common import dtype as mstype from ..common import dtype as mstype
@@ -130,7 +130,7 @@ class Model:
self._loss_fn, self._loss_fn,
level=self._amp_level) level=self._amp_level)
elif self._loss_fn: elif self._loss_fn:
network = WithLossCell(network, self._loss_fn)
network = nn.WithLossCell(network, self._loss_fn)
# If need to check if loss_fn is not None, but optimizer is None # If need to check if loss_fn is not None, but optimizer is None
return network return network


@@ -150,10 +150,7 @@ class Model:
else: else:
if self._loss_fn is None: if self._loss_fn is None:
raise ValueError("loss_fn can not be None.") raise ValueError("loss_fn can not be None.")
if self._optimizer:
self._eval_network = self._train_network.network
else:
self._eval_network = WithEvalCell(self._network, self._loss_fn)
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn)
self._eval_indexes = [0, 1, 2] self._eval_indexes = [0, 1, 2]


def _clear_metrics(self): def _clear_metrics(self):
@@ -263,7 +260,7 @@ class Model:
dataset_helper = DatasetHelper(train_dataset) dataset_helper = DatasetHelper(train_dataset)
# remove later to deal with loop sink # remove later to deal with loop sink
if need_wrap: if need_wrap:
self._train_network = DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
train_dataset.__ME_INITED__) train_dataset.__ME_INITED__)
cb_params.train_network = self._train_network cb_params.train_network = self._train_network
self._train_network.set_train() self._train_network.set_train()
@@ -429,7 +426,7 @@ class Model:


# remove later to deal with loop sink # remove later to deal with loop sink
if need_wrap: if need_wrap:
self._eval_network = DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
valid_dataset.__ME_INITED__) valid_dataset.__ME_INITED__)
self._eval_network.set_train(mode=False) self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval' self._eval_network.phase = 'eval'
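The third hunk is the behavioral fix the commit title points at: previously, when an optimizer was present, eval reused self._train_network.network, which (being a loss-wrapped cell) appears to return a bare scalar loss rather than a tuple, so reading it through _eval_indexes = [0, 1, 2] broke. Always wrapping with nn.WithEvalCell guarantees three slots. A rough sketch of why the fixed-length tuple matters; the routing function and names below are illustrative, not the real Model internals:

def route_to_metrics(eval_output, eval_indexes=(0, 1, 2)):
    # Slot 0 feeds loss-style metrics; slots 1 and 2 feed the
    # prediction/label pair used by accuracy-style metrics.
    if not isinstance(eval_output, tuple) or len(eval_output) <= max(eval_indexes):
        raise ValueError("eval network must return (loss, outputs, label)")
    loss, outputs, label = (eval_output[i] for i in eval_indexes)
    return {"loss": loss, "outputs": outputs, "label": label}

print(route_to_metrics((0.25, [0.9, 0.1], [1, 0])))  # ok: three slots
# route_to_metrics(0.25) would raise: the pre-fix failure mode.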


tests/ut/python/train/test_amp.py   +52 -0

@@ -14,12 +14,15 @@
 # ============================================================================
 """ auto mixed precision """
 import numpy as np
+import pytest
 from mindspore import amp
 from mindspore import nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 import mindspore.context as context
 from mindspore.model_zoo.resnet import resnet50
+from mindspore.train import Model
+from ....dataset_mock import MindData


 def setup_module(module):

@@ -85,3 +88,52 @@ def test_amp_o0_loss():
     optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_network = amp.build_train_network(net, optimizer, loss)
     output = train_network(inputs, label)
+
+
+class MindDataSet(MindData):
+    def __init__(self, dataset_types, dataset_shapes):
+        super(MindDataSet, self).__init__(size=2, batch_size=32,
+                                          np_types=dataset_types,
+                                          output_shapes=dataset_shapes,
+                                          input_indexs=(0, 1))
+
+    def __next__(self):
+        if self._size < self._iter_num:
+            raise StopIteration
+        self._iter_num += 1
+        batch = []
+        for shape, np_type in zip(self._output_shapes, self._np_types):
+            batch.append(Tensor(np.ones(shape).astype(np_type)))
+        return tuple(batch)
+
+
+def test_compile_model_train_O0():
+    dataset_types = (np.float32, np.float32)
+    dataset_shapes = ((16, 16), (16, 16))
+
+    dataset = MindDataSet(dataset_types, dataset_shapes)
+
+    net = NetNoLoss(16, 16)
+    loss = nn.MSELoss()
+    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+
+    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O0")
+    model.train(2, dataset, dataset_sink_mode=False)
+    with pytest.raises(ValueError):
+        # Not an actual run: the metrics step fails; this only checks that compilation is ok.
+        model.eval(dataset)
+
+
+def test_compile_model_train_O2():
+    dataset_types = (np.float32, np.float32)
+    dataset_shapes = ((16, 16), (16, 16))
+
+    dataset = MindDataSet(dataset_types, dataset_shapes)
+
+    net = NetNoLoss(16, 16)
+    loss = nn.MSELoss()
+    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+
+    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
+    model.train(2, dataset, dataset_sink_mode=False)
+    with pytest.raises(ValueError):
+        # Not an actual run: the metrics step fails; this only checks that compilation is ok.
+        model.eval(dataset)
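The new tests drive Model.train through two mock batches and then call model.eval inside pytest.raises(ValueError), presumably because the all-ones mock data makes the metrics step fail after the eval graph has already compiled; compilation succeeding is the point of the test. The mock itself can be mimicked without MindSpore. A self-contained sketch of the same two-batch iterator follows; the class name and defaults are illustrative, not the real MindData mock.

import numpy as np

class TwoBatchDataset:
    # Pure-numpy stand-in for the MindDataSet mock above: yields
    # `size` batches of all-ones arrays with the given types/shapes.
    def __init__(self, types, shapes, size=2):
        self._types, self._shapes, self._size = types, shapes, size
        self._iter_num = 0

    def __iter__(self):
        self._iter_num = 0
        return self

    def __next__(self):
        if self._iter_num >= self._size:
            raise StopIteration
        self._iter_num += 1
        return tuple(np.ones(s).astype(t)
                     for s, t in zip(self._shapes, self._types))

for data, label in TwoBatchDataset((np.float32, np.float32), ((16, 16), (16, 16))):
    print(data.shape, label.dtype)  # (16, 16) float32, printed twice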
