|
|
|
@@ -583,6 +583,8 @@ class Optimizer(Cell): |
|
|
|
new_param_group = [] |
|
|
|
for root in range(self.dev_num): |
|
|
|
ops = P.Broadcast(root) |
|
|
|
if root > 0: |
|
|
|
param_group[root] = F.depend(param_group[root], new_param_group[root-1]) |
|
|
|
next_params = ops(param_group[root]) |
|
|
|
new_param_group.append(next_params) |
|
|
|
for i in range(F.tuple_len(next_params)): |
|
|
|
|