Browse Source

!13489 insert depend for opt weight grouping

From: @gong_zi_yan
Reviewed-by: @hwhewei,@kisnwang
Signed-off-by: @hwhewei
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
a37e157697
2 changed files with 2 additions and 1 deletions
  1. +0
    -1
      mindspore/common/parameter.py
  2. +2
    -0
      mindspore/nn/optim/optimizer.py

+ 0
- 1
mindspore/common/parameter.py View File

@@ -149,7 +149,6 @@ class Parameter(Tensor_):
self._is_init = False self._is_init = False
self._inited_param = None self._inited_param = None
self._sliced = False self._sliced = False
self.comm_fusion = 1
self.is_param_ps = False self.is_param_ps = False
self._cast_type = None self._cast_type = None
self._unique = False self._unique = False


+ 2
- 0
mindspore/nn/optim/optimizer.py View File

@@ -585,6 +585,8 @@ class Optimizer(Cell):
ops = P.Broadcast(root) ops = P.Broadcast(root)
if root > 0: if root > 0:
param_group[root] = F.depend(param_group[root], new_param_group[root-1]) param_group[root] = F.depend(param_group[root], new_param_group[root-1])
else:
param_group[root] = F.depend(param_group[root], optim_result)
next_params = ops(param_group[root]) next_params = ops(param_group[root])
new_param_group.append(next_params) new_param_group.append(next_params)
for i in range(F.tuple_len(next_params)): for i in range(F.tuple_len(next_params)):


Loading…
Cancel
Save