|
|
|
@@ -109,7 +109,8 @@ class _BatchNorm(Cell): |
|
|
|
epsilon=self.eps, |
|
|
|
momentum=self.momentum) |
|
|
|
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) |
|
|
|
self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend)) |
|
|
|
self.enable_global_sync = self.is_global and (self.is_ge_backend or\ |
|
|
|
(self.is_graph_mode and self._target == "Ascend")) |
|
|
|
|
|
|
|
data_parallel_strategy = ((1,), (1,)) |
|
|
|
data_parallel_strategy_one = ((1,), ()) |
|
|
|
|