Browse Source

!11306 modify normal_batch bug

From: @Somnus2020
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
tags/v1.1.1
mindspore-ci-bot Gitee 5 years ago
parent
commit
c7efae3bbe
2 changed files with 3 additions and 2 deletions
  1. +1
    -1
      mindspore/nn/layer/basic.py
  2. +2
    -1
      mindspore/nn/layer/normalization.py

+ 1
- 1
mindspore/nn/layer/basic.py View File

@@ -158,7 +158,7 @@ class Dropout(Cell):
return out return out


def extend_repr(self): def extend_repr(self):
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
return 'keep_prob={}'.format(self.keep_prob)




class Flatten(Cell): class Flatten(Cell):


+ 2
- 1
mindspore/nn/layer/normalization.py View File

@@ -109,7 +109,8 @@ class _BatchNorm(Cell):
epsilon=self.eps, epsilon=self.eps,
momentum=self.momentum) momentum=self.momentum)
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend))
self.enable_global_sync = self.is_global and (self.is_ge_backend or\
(self.is_graph_mode and self._target == "Ascend"))


data_parallel_strategy = ((1,), (1,)) data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ()) data_parallel_strategy_one = ((1,), ())


Loading…
Cancel
Save