
fix nn.dropout performance problem

tags/v1.0.0
liangchenghui committed 5 years ago
commit d57c5ce010

1 changed file with 11 additions and 1 deletion:
  1. mindspore/nn/layer/basic.py (+11, −1)

mindspore/nn/layer/basic.py

@@ -112,7 +112,11 @@ class Dropout(Cell):
             return x
 
         shape = self.get_shape(x)
-        keep_prob = self.cast(self.keep_prob, mstype.float32)
+        dtype = P.DType()(x)
+        if _is_float_dtype(dtype):
+            keep_prob = self.cast(self.keep_prob, dtype)
+        else:
+            keep_prob = self.cast(self.keep_prob, mstype.float16)
         output = self.dropout_gen_mask(shape, keep_prob)
         return self.dropout_do_mask(x, output, keep_prob)
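
For context, a minimal usage sketch of the affected path (the input shape and keep_prob value here are illustrative, not from the commit): with a float16 input, keep_prob is now cast to the input's own dtype instead of always being upcast to float32, so the DropoutGenMask/DropoutDoMask path no longer pays for the wider dtype.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, dtype as mstype

    net = nn.Dropout(keep_prob=0.9)
    net.set_train()                              # dropout only applies in training mode
    x = Tensor(np.ones((2, 3)), mstype.float16)
    out = net(x)                                 # keep_prob is cast to float16, matching x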


@@ -256,6 +260,12 @@ def _dtype_check(x_dtype):
     if x_dtype not in [mstype.float32, mstype.float16]:
         raise TypeError("The input type must be float32 or float16.")
 
 
+@constexpr
+def _is_float_dtype(dtype):
+    if dtype in [mstype.float32, mstype.float16]:
+        return True
+    return False
+
 class ClipByNorm(Cell):
     r"""
     Clips tensor values to a maximum :math:`L_2`-norm.
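
Note: since _is_float_dtype is decorated with @constexpr, it is evaluated once during graph compilation, so the dtype dispatch should add no run-time branch; inputs with a non-float dtype fall back to a float16 keep_prob for mask generation.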

