|
|
|
@@ -79,7 +79,7 @@ class _BatchNorm(Cell): |
|
|
|
if self.rank_id in self.rank_list[i] and self.group != 1: |
|
|
|
self.is_global = True |
|
|
|
management.create_group('group' + str(i), self.rank_list[i]) |
|
|
|
self.all_reduce = _GlobalBNHelper('group' + str(i)) |
|
|
|
self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1) |
|
|
|
self.shape = P.Shape() |
|
|
|
self.reduce_mean = P.ReduceMean() |
|
|
|
self.square = P.Square() |
|
|
|
|