@@ -25,6 +25,7 @@ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
+from mindspore import context
 from ..cell import Cell
 from .activation import get_activation
 from ..._checkparam import Validator as validator
@@ -84,8 +85,19 @@ class Dropout(Cell):
         self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed0, Seed1=seed1)
         self.dropout_do_mask = P.DropoutDoMask()
         self.cast = P.Cast()
+        self.is_gpu = context.get_context('device_target') in ["GPU"]
+        if self.is_gpu:
+            self.dropout = P.Dropout(keep_prob)

     def construct(self, x):
         if not self.training:
             return x
+
+        if self.is_gpu:
+            out, _ = self.dropout(x)
+            return out
+
         shape = self.get_shape(x)
         dtype = P.DType()(x)
         keep_prob = self.cast(self.keep_prob, dtype)
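On GPU the cell now short-circuits to the fused `P.Dropout` primitive, which produces the dropped output and its mask in a single kernel, instead of the two-step `DropoutGenMask`/`DropoutDoMask` sequence used on other backends. A rough usage sketch of the new path, assuming a GPU build of MindSpore (the scaling comment reflects the usual inverted-dropout convention, not something stated in this diff):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import dtype as mstype

context.set_context(device_target="GPU")  # assumption: a GPU backend is available

net = nn.Dropout(keep_prob=0.5)
net.set_train()  # dropout is the identity unless the cell is in training mode
x = Tensor(np.ones([20, 16, 50]), mstype.float32)
y = net(x)       # internally: out, _ = P.Dropout(keep_prob)(x)
# With inverted dropout, surviving elements are scaled by 1/keep_prob.
```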
@@ -643,3 +643,17 @@ def get_bprop_binary_cross_entropy(self):
         return dx, zeros_like(y), zeros_like(weight)

     return bprop
+
+
+@bprop_getters.register(P.Dropout)
+def get_bprop_dropout(self):
+    """Grad definition for `Dropout` operation."""
+    grad = P.DropoutGrad(self.drop_prob)
+
+    def bprop(x, out, dout):
+        _, mask = out
+        dy, _ = dout
+        dx = grad(dy, mask)
+        return (dx,)
+
+    return bprop
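The registered bprop threads the mask saved by the forward op into `DropoutGrad`: because the forward returns the pair `(output, mask)`, `dout` arrives as a pair as well, and only the gradient of the output is propagated. A hedged sketch of differentiating through the new primitive (note that `C.GradOperation`'s constructor signature has varied across MindSpore versions):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(device_target="GPU")  # the fused primitive targets GPU here

class DropNet(nn.Cell):
    def __init__(self):
        super(DropNet, self).__init__()
        self.dropout = P.Dropout(0.5)

    def construct(self, x):
        out, _ = self.dropout(x)  # keep the output, discard the mask
        return out

x = Tensor(np.ones([4, 4]).astype(np.float32))
dx = C.GradOperation()(DropNet())(x)
# Via the bprop above: dx = DropoutGrad(drop_prob)(dy, mask), so positions
# zeroed in the forward pass receive zero gradient.
```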
@@ -52,7 +52,7 @@ from .random_ops import (RandomChoiceWithMask)
 from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      BiasAdd, Conv2D,
                      DepthwiseConv2dNative,
-                     DropoutDoMask,
+                     DropoutDoMask, DropoutGrad, Dropout,
                      DropoutGenMask, Flatten, FusedBatchNorm,
                      Gelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss,
@@ -157,6 +157,8 @@ __all__ = [
     'Shape',
     'DropoutDoMask',
     'DropoutGenMask',
+    'DropoutGrad',
+    'Dropout',
     'Neg',
     'Slice',
     'DType',
@@ -2762,3 +2762,68 @@ class ConfusionMulGrad(PrimitiveWithInfer):
         validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
         validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
         return input0_dtype, input1_dtype
+
+
+class Dropout(PrimitiveWithInfer):
+    """
+    During training, randomly zeroes some of the elements of the input tensor
+    with probability `drop_prob`.
+
+    Args:
+        drop_prob (float): Probability of an element to be zeroed. Default: 0.
+
+    Inputs:
+        - **input** (Tensor) - The input tensor, with float16 or float32 data type.
+
+    Outputs:
+        - **output** (Tensor) - The result after dropout, with the same shape and dtype as the input.
+        - **mask** (Tensor) - The generated mask, with the same shape as the input.
+
+    Examples:
+        >>> dropout = P.Dropout(drop_prob=0.5)
+        >>> x = Tensor(np.ones([20, 16, 50, 50]), mindspore.float32)
+        >>> output, mask = dropout(x)
+    """
+
+    @prim_attr_register
+    def __init__(self, drop_prob=0):
+        self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
+
+    def infer_shape(self, x_shape):
+        validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name)
+        mask_shape = x_shape
+        return x_shape, mask_shape
+
+    def infer_dtype(self, x_dtype):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
+        return x_dtype, x_dtype
+
+
+class DropoutGrad(PrimitiveWithInfer):
+    """
+    The gradient of Dropout, scaling `dy` by the mask saved from the forward pass.
+
+    Args:
+        drop_prob (float): Probability of an element to be zeroed. Default: 0.
+
+    Inputs:
+        - **dy** (Tensor) - The gradient of the Dropout output, with float16 or float32 data type.
+        - **mask** (Tensor) - The mask produced by the forward Dropout.
+
+    Outputs:
+        Tensor, the gradient of the Dropout input, with the same shape and dtype as `dy`.
+
+    Examples:
+        >>> dropout_grad = P.DropoutGrad(drop_prob=0.5)
+        >>> dy = Tensor(np.ones([20, 16, 50, 50]), mindspore.float32)
+        >>> dx = dropout_grad(dy, mask)
+    """
+
+    @prim_attr_register
+    def __init__(self, drop_prob=0):
+        self.drop_prob = validator.check_number_range("drop_prob", drop_prob, 0, 1, Rel.INC_BOTH, self.name)
+
+    def infer_shape(self, dy_shape, mask_shape):
+        return dy_shape
+
+    def infer_dtype(self, dy_dtype, mask_dtype):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
+        return dy_dtype
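Both primitives' infer functions simply echo their input shape and dtype: `Dropout` maps `x -> (output, mask)` with both results shaped like `x`, and `DropoutGrad` maps `(dy, mask) -> dx` shaped like `dy`. A quick sketch of calling them directly, assuming a backend where the kernels are implemented:

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.ones([20, 16, 50]).astype(np.float32))

dropout = P.Dropout(drop_prob=0.5)
out, mask = dropout(x)        # both outputs share x's shape and dtype

dy = Tensor(np.ones([20, 16, 50]).astype(np.float32))
dropout_grad = P.DropoutGrad(drop_prob=0.5)
dx = dropout_grad(dy, mask)   # reuses the forward mask; dx has dy's shape
```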
@@ -17,7 +17,9 @@ import numpy as np
 import pytest
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import context
+
+context.set_context(device_target="Ascend")


 def test_check_dropout_3():
     Tensor(np.ones([20, 16, 50]).astype(np.int32))
@@ -19,26 +19,26 @@ from mindspore.common.api import _executor
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import dtype as mstype
+from mindspore import context
+
+context.set_context(device_target="Ascend")


 def test_check_dropout_1():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.8)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)


 def test_check_dropout_2():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.3, seed0=1)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)


 def test_check_dropout_3():
     x = Tensor(np.ones([20, 16, 50]), mstype.float32)
     m = nn.Dropout(0.3, seed0=1, seed1=1)
-    with pytest.raises(NotImplementedError):
-        m(x)
+    m(x)


 class Net_Dropout(nn.Cell):
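The hunk cuts off at the start of `Net_Dropout`; wrapper cells like this in the test suite typically just embed the layer under test so `_executor` can compile it. A hypothetical completion under that assumption (the `keep_prob` value here is illustrative, not taken from the diff):

```python
class Net_Dropout(nn.Cell):
    def __init__(self):
        super(Net_Dropout, self).__init__()
        self.dropout = nn.Dropout(0.5)  # layer under test

    def construct(self, x):
        return self.dropout(x)
```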