Browse Source

!5833 fix thor optimizer interface

Merge pull request !5833 from wangmin0104/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
5901be3ba8
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      model_zoo/official/cv/resnet_thor/src/thor.py

+ 3
- 3
model_zoo/official/cv/resnet_thor/src/thor.py View File

@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool from mindspore._checkparam import check_bool
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore.nn.optim.optimizer import Optimizer from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from src.grad_reducer_thor import DistributedGradReducerThor from src.grad_reducer_thor import DistributedGradReducerThor


_momentum_opt = C.MultitypeFuncGraph("momentum_opt") _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -85,7 +85,7 @@ class THOR_GPU(Optimizer):
self.assign = P.Assign() self.assign = P.Assign()
self.mul = P.Mul() self.mul = P.Mul()


mean = _get_mirror_mean()
mean = _get_gradients_mean()
degree = _get_device_num() degree = _get_device_num()


parameter_length = len(self.feature_map) parameter_length = len(self.feature_map)
@@ -193,7 +193,7 @@ class THOR(Optimizer):
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
1.0] 1.0]
mean = _get_mirror_mean()
mean = _get_gradients_mean()
degree = _get_device_num() degree = _get_device_num()
parameter_length = len(self.feature_map) parameter_length = len(self.feature_map)
self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree) self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)


Loading…
Cancel
Save