|
|
|
@@ -82,6 +82,7 @@ class _BatchNorm(Cell): |
|
|
|
self.dtype = P.DType() |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.is_ascend = context.get_context("device_target") == "Ascend" |
|
|
|
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE |
|
|
|
|
|
|
|
if context.get_context("enable_ge"): |
|
|
|
self.is_ge_backend = True |
|
|
|
@@ -89,7 +90,7 @@ class _BatchNorm(Cell): |
|
|
|
else: |
|
|
|
self.is_ge_backend = False |
|
|
|
self.momentum = 1.0 - momentum |
|
|
|
if self.is_ge_backend or self.is_ascend: |
|
|
|
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend): |
|
|
|
self.bn_train = P.BatchNorm(is_training=True, |
|
|
|
epsilon=self.eps) |
|
|
|
else: |
|
|
|
@@ -147,7 +148,7 @@ class _BatchNorm(Cell): |
|
|
|
if self.is_ge_backend and self.is_global: |
|
|
|
axes, re_shape = _shape_infer(F.shape(x), self.num_features) |
|
|
|
y = self._global_sync(x, axes, re_shape) |
|
|
|
elif self.is_ge_backend or self.is_ascend: |
|
|
|
elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend): |
|
|
|
y, batch_mean, batch_var, _, _ = \ |
|
|
|
self.bn_train(x, |
|
|
|
self.gamma, |
|
|
|
|