@@ -112,7 +112,11 @@ class Dropout(Cell):
             return x
 
         shape = self.get_shape(x)
-        keep_prob = self.cast(self.keep_prob, mstype.float32)
+        dtype = P.DType()(x)
+        if _is_float_dtype(dtype):
+            keep_prob = self.cast(self.keep_prob, dtype)
+        else:
+            keep_prob = self.cast(self.keep_prob, mstype.float16)
         output = self.dropout_gen_mask(shape, keep_prob)
         return self.dropout_do_mask(x, output, keep_prob)
 
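The first hunk replaces the unconditional float32 cast of keep_prob with dtype-aware casting: float inputs keep their own dtype, and non-float inputs fall back to float16, so DropoutGenMask and DropoutDoMask receive a keep_prob matching the mask computation. A minimal pure-Python sketch of that branching follows; FLOAT_DTYPES and resolve_keep_prob are illustrative stand-ins, not MindSpore API.

    # Sketch of the new keep_prob casting logic above. The string dtypes
    # stand in for mstype.float32/mstype.float16, and returning a dtype
    # stands in for P.Cast(); names here are hypothetical.

    FLOAT_DTYPES = ("float32", "float16")

    def _is_float_dtype(dtype):
        # Mirrors the @constexpr helper added in the second hunk.
        return dtype in FLOAT_DTYPES

    def resolve_keep_prob(keep_prob, input_dtype):
        # Float inputs keep their own dtype; anything else falls back
        # to float16 instead of the old unconditional float32 cast.
        if _is_float_dtype(input_dtype):
            return keep_prob, input_dtype
        return keep_prob, "float16"

    print(resolve_keep_prob(0.9, "float16"))  # (0.9, 'float16')
    print(resolve_keep_prob(0.9, "int32"))    # (0.9, 'float16')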
|
@@ -256,6 +260,12 @@ def _dtype_check(x_dtype):
     if x_dtype not in [mstype.float32, mstype.float16]:
         raise TypeError("The input type must be float32 or float16.")
 
+@constexpr
+def _is_float_dtype(dtype):
+    if dtype in [mstype.float32, mstype.float16]:
+        return True
+    return False
+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
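The second hunk ends inside the unchanged ClipByNorm docstring. As a reference for the operation that docstring names, here is a minimal NumPy sketch of L2-norm clipping under the usual definition (clip_by_norm is an illustrative name, not the MindSpore implementation): the tensor is scaled by clip_norm / ||x||_2 only when its norm exceeds the bound.

    import numpy as np

    def clip_by_norm(x, clip_norm):
        # Scale x so its L2 norm is at most clip_norm; leave it
        # untouched when the norm is already within the bound.
        norm = np.linalg.norm(x)
        if norm > clip_norm:
            return x * (clip_norm / norm)
        return x

    x = np.array([3.0, 4.0])      # L2 norm = 5.0
    print(clip_by_norm(x, 1.0))   # [0.6 0.8], norm clipped to 1.0
    print(clip_by_norm(x, 10.0))  # [3. 4.], unchanged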