|
|
|
@@ -3227,7 +3227,8 @@ class Dropout(PrimitiveWithInfer): |
|
|
|
During training, randomly zeroes some of the elements of the input tensor with probability. |
|
|
|
|
|
|
|
Args: |
|
|
|
drop_prob (float): probability of an element to be zeroed. Default: 0. |
|
|
|
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9, |
|
|
|
means dropping out 10% of input units. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **shape** (tuple[int]) - The shape of target mask. |
|
|
|
@@ -3236,14 +3237,14 @@ class Dropout(PrimitiveWithInfer): |
|
|
|
Tensor, the value of generated mask for input shape. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> dropout = P.Dropout(drop_prob=0.5) |
|
|
|
>>> dropout = P.Dropout(keep_prob=0.5) |
|
|
|
>>> in = Tensor((20, 16, 50, 50)) |
|
|
|
>>> out = dropout(in) |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, drop_prob=0): |
|
|
|
self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name) |
|
|
|
def __init__(self, keep_prob=0.5): |
|
|
|
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) |
|
|
|
@@ -3262,7 +3263,8 @@ class DropoutGrad(PrimitiveWithInfer): |
|
|
|
of the input tensor with probability. |
|
|
|
|
|
|
|
Args: |
|
|
|
drop_prob (float): probability of an element to be zeroed. Default: 0. |
|
|
|
keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9, |
|
|
|
means dropping out 10% of input units. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **shape** (tuple[int]) - The shape of target mask. |
|
|
|
@@ -3271,14 +3273,14 @@ class DropoutGrad(PrimitiveWithInfer): |
|
|
|
Tensor, the value of generated mask for input shape. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> dropout_grad = P.DropoutGrad(drop_prob=0.5) |
|
|
|
>>> dropout_grad = P.DropoutGrad(keep_prob=0.5) |
|
|
|
>>> in = Tensor((20, 16, 50, 50)) |
|
|
|
>>> out = dropout_grad(in) |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, drop_prob=0): |
|
|
|
self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name) |
|
|
|
def __init__(self, keep_prob=0.5): |
|
|
|
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) |
|
|
|
|
|
|
|
def infer_shape(self, dy_shape, mask_shape): |
|
|
|
return dy_shape |
|
|
|
|