| @@ -112,7 +112,11 @@ class Dropout(Cell): | |||||
| return x | return x | ||||
| shape = self.get_shape(x) | shape = self.get_shape(x) | ||||
| keep_prob = self.cast(self.keep_prob, mstype.float32) | |||||
| dtype = P.DType()(x) | |||||
| if _is_float_dtype(dtype): | |||||
| keep_prob = self.cast(self.keep_prob, dtype) | |||||
| else: | |||||
| keep_prob = self.cast(self.keep_prob, mstype.float16) | |||||
| output = self.dropout_gen_mask(shape, keep_prob) | output = self.dropout_gen_mask(shape, keep_prob) | ||||
| return self.dropout_do_mask(x, output, keep_prob) | return self.dropout_do_mask(x, output, keep_prob) | ||||
| @@ -256,6 +260,12 @@ def _dtype_check(x_dtype): | |||||
| if x_dtype not in [mstype.float32, mstype.float16]: | if x_dtype not in [mstype.float32, mstype.float16]: | ||||
| raise TypeError("The input type must be float32 or float16.") | raise TypeError("The input type must be float32 or float16.") | ||||
| @constexpr | |||||
| def _is_float_dtype(dtype): | |||||
| if dtype in [mstype.float32, mstype.float16]: | |||||
| return True | |||||
| return False | |||||
| class ClipByNorm(Cell): | class ClipByNorm(Cell): | ||||
| r""" | r""" | ||||
| Clips tensor values to a maximum :math:`L_2`-norm. | Clips tensor values to a maximum :math:`L_2`-norm. | ||||