|
|
|
@@ -65,18 +65,22 @@ class Dropout(Cell): |
|
|
|
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32. |
|
|
|
|
|
|
|
Raises: |
|
|
|
ValueError: If `keep_prob` is not in range (0, 1). |
|
|
|
ValueError: If `keep_prob` is not in range (0, 1]. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input** (Tensor) - An N-D Tensor. |
|
|
|
- **input** (Tensor) - The input tensor. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, output tensor with the same shape as the input. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> x = Tensor(np.ones([20, 16, 50]), mindspore.float32) |
|
|
|
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32) |
|
|
|
>>> net = nn.Dropout(keep_prob=0.8) |
|
|
|
>>> net(x) |
|
|
|
[[[1.0, 1.0, 1.0], |
|
|
|
[1.0, 1.0, 1.0]], |
|
|
|
[[1.0, 1.0, 1.0], |
|
|
|
[1.0, 1.0, 1.0]]] |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, keep_prob=0.5, seed0=0, seed1=0, dtype=mstype.float32): |
|
|
|
@@ -84,6 +88,7 @@ class Dropout(Cell): |
|
|
|
if keep_prob <= 0 or keep_prob > 1: |
|
|
|
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) |
|
|
|
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) |
|
|
|
validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) |
|
|
|
self.keep_prob = keep_prob |
|
|
|
self.seed0 = seed0 |
|
|
|
self.seed1 = seed1 |
|
|
|
@@ -107,8 +112,7 @@ class Dropout(Cell): |
|
|
|
return x |
|
|
|
|
|
|
|
shape = self.get_shape(x) |
|
|
|
dtype = P.DType()(x) |
|
|
|
keep_prob = self.cast(self.keep_prob, dtype) |
|
|
|
keep_prob = self.cast(self.keep_prob, mstype.float32) |
|
|
|
output = self.dropout_gen_mask(shape, keep_prob) |
|
|
|
return self.dropout_do_mask(x, output, keep_prob) |
|
|
|
|
|
|
|
|