Browse Source

make optimizer parameter same as gradient

tags/v0.3.0-alpha
guohongzilong 5 years ago
parent
commit
e0c0c8bc86
2 changed files with 3 additions and 3 deletions
  1. +1
    -1
      mindspore/nn/wrap/grad_reducer.py
  2. +2
    -2
      mindspore/nn/wrap/loss_scale.py

+ 1
- 1
mindspore/nn/wrap/grad_reducer.py View File

@@ -141,7 +141,7 @@ class DistributedGradReducer(Cell):
 >>> super(TrainingWrapper, self).__init__(auto_prefix=False)
 >>> self.network = network
 >>> self.network.add_flags(defer_inline=True)
->>> self.weights = ParameterTuple(network.trainable_params())
+>>> self.weights = optimizer.parameters
 >>> self.optimizer = optimizer
 >>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
 >>> self.sens = sens


+ 2
- 2
mindspore/nn/wrap/loss_scale.py View File

@@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from ..cell import Cell
-from ...common import Tensor, ParameterTuple
+from ...common import Tensor
 from ...common.parameter import Parameter
 from ...ops import functional as F
 from ...ops import composite as C
@@ -201,7 +201,7 @@ class TrainOneStepWithLossScaleCell(Cell):
 super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
 self.network = network
 self.network.add_flags(defer_inline=True)
-self.weights = ParameterTuple(network.trainable_params())
+self.weights = optimizer.parameters
 self.optimizer = optimizer
 self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
 self.hyper_map = C.HyperMap()


Loading…
Cancel
Save