diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index ba95566558..17caa5165d 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -112,7 +112,11 @@ class Dropout(Cell): return x shape = self.get_shape(x) - keep_prob = self.cast(self.keep_prob, mstype.float32) + dtype = P.DType()(x) + if _is_float_dtype(dtype): + keep_prob = self.cast(self.keep_prob, dtype) + else: + keep_prob = self.cast(self.keep_prob, mstype.float16) output = self.dropout_gen_mask(shape, keep_prob) return self.dropout_do_mask(x, output, keep_prob) @@ -256,6 +260,12 @@ def _dtype_check(x_dtype): if x_dtype not in [mstype.float32, mstype.float16]: raise TypeError("The input type must be float32 or float16.") +@constexpr +def _is_float_dtype(dtype): + if dtype in [mstype.float32, mstype.float16]: + return True + return False + class ClipByNorm(Cell): r""" Clips tensor values to a maximum :math:`L_2`-norm.