From d57c5ce0100b71569e973498cd997546ffe9f66d Mon Sep 17 00:00:00 2001
From: liangchenghui
Date: Fri, 11 Sep 2020 20:56:36 +0800
Subject: [PATCH] fix nn.dropout performance problem

---
 mindspore/nn/layer/basic.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index ba95566558..17caa5165d 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -112,7 +112,11 @@ class Dropout(Cell):
             return x
 
         shape = self.get_shape(x)
-        keep_prob = self.cast(self.keep_prob, mstype.float32)
+        dtype = P.DType()(x)
+        if _is_float_dtype(dtype):
+            keep_prob = self.cast(self.keep_prob, dtype)
+        else:
+            keep_prob = self.cast(self.keep_prob, mstype.float16)
         output = self.dropout_gen_mask(shape, keep_prob)
         return self.dropout_do_mask(x, output, keep_prob)
 
@@ -256,6 +260,12 @@ def _dtype_check(x_dtype):
     if x_dtype not in [mstype.float32, mstype.float16]:
         raise TypeError("The input type must be float32 or float16.")
 
+@constexpr
+def _is_float_dtype(dtype):
+    if dtype in [mstype.float32, mstype.float16]:
+        return True
+    return False
+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
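
Note: the snippet below is a minimal, illustrative sketch (not part of the patch) of the user-visible effect. It assumes the MindSpore 1.x nn API (nn.Dropout with a keep_prob argument); with this change, keep_prob is cast to the input's floating dtype instead of always to float32, so a float16 network no longer pays for extra casts in the dropout mask path.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, dtype as mstype

    # After this patch, a float16 input keeps the whole dropout path in float16:
    # keep_prob follows the input dtype rather than being forced to float32.
    net = nn.Dropout(keep_prob=0.9)
    net.set_train()

    x_fp16 = Tensor(np.ones((2, 3)), mstype.float16)
    y = net(x_fp16)    # mask generated and applied with a float16 keep_prob
    print(y.dtype)     # Float16, no round trip through float32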