|
|
|
@@ -74,7 +74,7 @@ class _BatchNorm(Cell): |
|
|
|
management.create_group('group' + str(i), self.rank_list[i]) |
|
|
|
self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1) |
|
|
|
self.shape = P.Shape() |
|
|
|
self.reduce_mean = P.ReduceMean() |
|
|
|
self.reduce_mean = P.ReduceMean(keep_dims=True) |
|
|
|
self.square = P.Square() |
|
|
|
self.sqrt = P.Sqrt() |
|
|
|
self.cast = P.Cast() |
|
|
|
|