@@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from ..cell import Cell
-from ...common import Tensor, ParameterTuple
+from ...common import Tensor
 from ...common.parameter import Parameter
 from ...ops import functional as F
 from ...ops import composite as C
@@ -201,7 +201,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.add_flags(defer_inline=True)
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
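
Note on the second hunk (a reviewer's reading, not part of the patch): MindSpore's Optimizer base class already stores the parameters it was given as a ParameterTuple, so taking self.weights from optimizer.parameters makes the gradient targets exactly the parameters the optimizer will update, even when the caller hands the optimizer a filtered list rather than all of network.trainable_params(). It also removes the file's last use of ParameterTuple, which is why the first hunk drops that import. A minimal usage sketch under that assumption (the network, loss, and hyperparameters below are illustrative, not from this patch):

import mindspore.nn as nn

# Hypothetical small network and loss, wrapped the way the train cell expects.
net = nn.Dense(16, 10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
net_with_loss = nn.WithLossCell(net, loss_fn)

# The optimizer is built over (possibly a subset of) the trainable parameters;
# Optimizer.__init__ keeps them as a ParameterTuple exposed via .parameters.
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# After this patch the cell derives its weights from the optimizer, so the
# gradients it computes and the updates the optimizer applies always refer
# to the same ParameterTuple.
train_cell = nn.TrainOneStepWithLossScaleCell(net_with_loss, opt)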