Browse Source

modify batch_normal

tags/v1.2.0-rc1
lilei 5 years ago
parent
commit
8037ef77f5
2 changed files with 3 additions and 2 deletions
  1. +1
    -1
      mindspore/nn/layer/basic.py
  2. +2
    -1
      mindspore/nn/layer/normalization.py

+ 1
- 1
mindspore/nn/layer/basic.py View File

@@ -161,7 +161,7 @@ class Dropout(Cell):
return out

def extend_repr(self):
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
return 'keep_prob={}'.format(self.keep_prob)


class Flatten(Cell):


+ 2
- 1
mindspore/nn/layer/normalization.py View File

@@ -109,7 +109,8 @@ class _BatchNorm(Cell):
epsilon=self.eps,
momentum=self.momentum)
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend))
self.enable_global_sync = self.is_global and (self.is_ge_backend or\
(self.is_graph_mode and self._target == "Ascend"))

data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ())


Loading…
Cancel
Save