From a220adca3e9d34ef20cfba289525478fd2daedbf Mon Sep 17 00:00:00 2001 From: lilei Date: Fri, 15 Jan 2021 14:54:20 +0800 Subject: [PATCH] modify batch_normal --- mindspore/nn/layer/basic.py | 2 +- mindspore/nn/layer/normalization.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 2e6992fa07..d9c4ce7971 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -158,7 +158,7 @@ class Dropout(Cell): return out def extend_repr(self): - return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype) + return 'keep_prob={}'.format(self.keep_prob) class Flatten(Cell): diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index b54922ac9f..35bba54f79 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -109,7 +109,8 @@ class _BatchNorm(Cell): epsilon=self.eps, momentum=self.momentum) self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) - self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend)) + self.enable_global_sync = self.is_global and (self.is_ge_backend or\ + (self.is_graph_mode and self._target == "Ascend")) data_parallel_strategy = ((1,), (1,)) data_parallel_strategy_one = ((1,), ())