
!6285 ME move dropoutgrad to inner

Merge pull request !6285 from VectorSL/drop
tags/v1.0.0
mindspore-ci-bot committed 5 years ago
parent commit 353b8b81c3
4 changed files with 38 additions and 39 deletions
  1. +1 -1    mindspore/ops/_grad/grad_nn_ops.py
  2. +1 -2    mindspore/ops/operations/__init__.py
  3. +36 -0   mindspore/ops/operations/_grad_ops.py
  4. +0 -36   mindspore/ops/operations/nn_ops.py

+1 -1   mindspore/ops/_grad/grad_nn_ops.py

@@ -930,7 +930,7 @@ def get_bprop_kl_div_loss(self):
 @bprop_getters.register(P.Dropout)
 def get_bprop_dropout(self):
     """Grad definition for `Dropout` operation."""
-    grad = P.DropoutGrad(self.keep_prob)
+    grad = G.DropoutGrad(self.keep_prob)
 
     def bprop(x, out, dout):
         _, mask = out

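The hunk is cut off after `_, mask = out`. A minimal sketch of how the full getter plausibly continues after this change (assumptions: `bprop_getters` comes from the module's existing grad_base import, grad_nn_ops.py imports the inner module as `G`, and Dropout's forward output and incoming gradient are both (output, mask) pairs):

    from mindspore.ops import operations as P
    from mindspore.ops.operations import _grad_ops as G   # inner home of DropoutGrad after this PR

    @bprop_getters.register(P.Dropout)
    def get_bprop_dropout(self):
        """Grad definition for `Dropout` operation."""
        grad = G.DropoutGrad(self.keep_prob)               # built once from the forward op's keep_prob

        def bprop(x, out, dout):
            _, mask = out                                  # reuse the mask produced in the forward pass
            dx = grad(dout[0], mask)                       # scale the incoming gradient by the mask
            return (dx,)

        return bprop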

+1 -2   mindspore/ops/operations/__init__.py

@@ -61,7 +61,7 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
 from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, ApplyMomentum, BatchNorm,
                      BiasAdd, Conv2D,
                      DepthwiseConv2dNative,
-                     DropoutDoMask, DropoutGrad, Dropout,
+                     DropoutDoMask, Dropout,
                      DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
                      Gelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCGreedyDecoder,
@@ -211,7 +211,6 @@ __all__ = [
     'DynamicShape',
     'DropoutDoMask',
     'DropoutGenMask',
-    'DropoutGrad',
     'Dropout',
     'Neg',
     'InplaceAdd',

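With `DropoutGrad` dropped from the public exports, user code keeps importing `Dropout` from `mindspore.ops.operations`, while gradient code reaches the primitive through the inner module. A hypothetical snippet illustrating the two paths after this change:

    from mindspore.ops import operations as P             # public API: Dropout is still exported
    from mindspore.ops.operations import _grad_ops as G   # inner module: DropoutGrad now lives here

    dropout = P.Dropout(keep_prob=0.9)
    dropout_grad = G.DropoutGrad(keep_prob=0.9)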

+36 -0   mindspore/ops/operations/_grad_ops.py

@@ -462,6 +462,42 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
         return out
 
 
+class DropoutGrad(PrimitiveWithInfer):
+    """
+    The gradient of Dropout. During training, randomly zeroes some of the elements
+    of the input tensor with probability 1 - `keep_prob`.
+
+    Args:
+        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9
+            means that 10% of input units will be dropped.
+
+    Inputs:
+        - **dy** (Tensor) - The gradient flowing back from the succeeding layer.
+        - **mask** (Tensor) - The mask produced by the forward `Dropout` operation.
+
+    Outputs:
+        Tensor, with the same shape and dtype as `dy`.
+
+    Examples:
+        >>> dropout_grad = G.DropoutGrad(keep_prob=0.5)
+        >>> dx = dropout_grad(dy, mask)
+    """
+
+    @prim_attr_register
+    def __init__(self, keep_prob=0.5):
+        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
+
+    def infer_shape(self, dy_shape, mask_shape):
+        return dy_shape
+
+    def infer_dtype(self, dy_dtype, mask_dtype):
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
+        validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
+        validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
+        return dy_dtype
+
+
 class FlattenGrad(PrimitiveWithInfer):
     """Performs gradients of Flatten."""


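As the infer functions show, the relocated primitive takes two tensors, the backpropagated gradient `dy` and the forward `mask`, and returns a tensor shaped and typed like `dy`. A minimal, illustrative usage sketch (shapes, dtypes, and the availability of a matching backend kernel are assumptions, not taken from the diff):

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops.operations import _grad_ops as G

    dropout_grad = G.DropoutGrad(keep_prob=0.5)
    dy = Tensor(np.ones((20, 16, 50, 50)), mindspore.float32)    # gradient from the next layer
    mask = Tensor(np.ones((20, 16, 50, 50)), mindspore.float32)  # mask saved by the forward Dropout
    dx = dropout_grad(dy, mask)                                  # same shape and dtype as dy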

+0 -36   mindspore/ops/operations/nn_ops.py

@@ -5247,42 +5247,6 @@ class Dropout(PrimitiveWithInfer):
         return x_dtype, x_dtype
 
 
-class DropoutGrad(PrimitiveWithInfer):
-    """
-    The gradient of Dropout. During training, randomly zeroes some of the elements
-    of the input tensor with probability.
-
-    Args:
-        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
-            means dropping out 10% of input units.
-
-    Inputs:
-        - **shape** (tuple[int]) - The shape of target mask.
-
-    Outputs:
-        Tensor, the value of generated mask for input shape.
-
-    Examples:
-        >>> dropout_grad = P.DropoutGrad(keep_prob=0.5)
-        >>> in = Tensor((20, 16, 50, 50))
-        >>> out = dropout_grad(in)
-    """
-
-    @prim_attr_register
-    def __init__(self, keep_prob=0.5):
-        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name)
-
-    def infer_shape(self, dy_shape, mask_shape):
-        return dy_shape
-
-    def infer_dtype(self, dy_dtype, mask_dtype):
-        valid_types = (mstype.float16, mstype.float32)
-        validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name)
-        validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name)
-        validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name)
-        return dy_dtype
-
-
 class CTCLoss(PrimitiveWithInfer):
     """
     Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.

