
make the parameters used for gradient computation the same as the optimizer's parameters

tags/v0.3.0-alpha
guohongzilong committed 5 years ago
parent commit: e0c0c8bc86
2 changed files with 3 additions and 3 deletions:
  1. mindspore/nn/wrap/grad_reducer.py (+1, -1)
  2. mindspore/nn/wrap/loss_scale.py (+2, -2)

mindspore/nn/wrap/grad_reducer.py (+1, -1)

@@ -141,7 +141,7 @@ class DistributedGradReducer(Cell):
         >>> super(TrainingWrapper, self).__init__(auto_prefix=False)
         >>> self.network = network
         >>> self.network.add_flags(defer_inline=True)
-        >>> self.weights = ParameterTuple(network.trainable_params())
+        >>> self.weights = optimizer.parameters
         >>> self.optimizer = optimizer
         >>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         >>> self.sens = sens
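The docstring change mirrors the code change: the wrapper's weights now come from optimizer.parameters, so gradients are taken over exactly the parameter list the optimizer updates, instead of rebuilding a ParameterTuple from network.trainable_params(), which can disagree with the optimizer if it was constructed with a filtered or reordered parameter list. Below is a minimal sketch of such a wrapper, assuming the 0.x-era MindSpore API shown in this diff; TrainingWrapper is just the example name from the docstring, and construct is a simplified guess at the usual pattern.

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class TrainingWrapper(nn.Cell):
    """Sketch: fuses a loss network and an optimizer into one training step."""
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        # After this commit: take the weights from the optimizer so the
        # gradient list and the optimizer's parameter list are identical.
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, data, label):
        loss = self.network(data, label)
        # Seed the backward pass with a tensor matching the loss's shape/dtype.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(data, label, sens)
        # Make the returned loss depend on the optimizer update.
        return F.depend(loss, self.optimizer(grads))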


mindspore/nn/wrap/loss_scale.py (+2, -2)

@@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from ..cell import Cell
-from ...common import Tensor, ParameterTuple
+from ...common import Tensor
 from ...common.parameter import Parameter
 from ...ops import functional as F
 from ...ops import composite as C
@@ -201,7 +201,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.add_flags(defer_inline=True)
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
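With self.weights sourced from optimizer.parameters (already a ParameterTuple built when the optimizer was constructed), loss_scale.py no longer builds one itself, which is why the ParameterTuple import is dropped. A hedged usage sketch, assuming the cell signature of this era (network, optimizer, scale_update_cell=None); the network and loss function here are illustrative placeholders:

import mindspore.nn as nn

# Placeholder network and loss; any Cell/loss pair would do.
net = nn.Dense(16, 10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
loss_net = nn.WithLossCell(net, loss_fn)
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
# The cell reads its weights from optimizer.parameters internally.
train_net = nn.TrainOneStepWithLossScaleCell(loss_net, optimizer)
train_net.set_train()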

