
add Dropout2D and rename Dropout3d to Dropout3D

tags/v1.2.0-rc1
yanzhenxiang2020, 4 years ago
commit 95dbfe0636
9 changed files with 278 additions and 57 deletions
1. +4 -2   mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
2. +40 -0  mindspore/ops/_grad/grad_nn_ops.py
3. +1 -0   mindspore/ops/_op_impl/aicpu/__init__.py
4. +42 -0  mindspore/ops/_op_impl/aicpu/dropout2d.py
5. +16 -16 mindspore/ops/_op_impl/aicpu/dropout3d.py
6. +3 -1   mindspore/ops/operations/__init__.py
7. +74 -14 mindspore/ops/operations/nn_ops.py
8. +69 -0  tests/st/ops/ascend/test_aicpu_ops/test_dropout2d.py
9. +29 -24 tests/st/ops/ascend/test_aicpu_ops/test_dropout3d.py

mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h (+4 -2)

@@ -52,9 +52,11 @@ constexpr auto kCacheSwapTable = "CacheSwapTable";
 constexpr auto kSubAndFilter = "SubAndFilter";
 constexpr auto kPadAndShift = "PadAndShift";
 constexpr auto kCustRunApi = "RunCpuKernel";
-constexpr auto kDropout3d = "Dropout3d";
+constexpr auto kDropout2D = "Dropout2D";
+constexpr auto kDropout3D = "Dropout3D";
 const std::set<std::string> kCustAiCpuKernelOps{kIdentity};
-const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3d};
+const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
+                                            kPadAndShift, kDropout3D, kDropout2D};

 struct AicpuParamHead {
   uint32_t length;  // Total length: include custom message


mindspore/ops/_grad/grad_nn_ops.py (+40 -0)

@@ -1263,6 +1263,46 @@ def get_bprop_dropout(self):
     return bprop


+@bprop_getters.register(P.Dropout2D)
+def get_bprop_dropout2d(self):
+    """Grad definition for `Dropout2D` operation."""
+    dtype = P.DType()
+    cast = P.Cast()
+    mul = P.Mul()
+    keep_prob = self.keep_prob
+
+    def bprop(x, out, dout):
+        _, mask = dout
+        y = cast(mask, mstype.float32)
+        if keep_prob != 0:
+            y = y * (1 / keep_prob)
+        y = mul(x, y)
+        y = cast(y, dtype(x))
+        return (y,)
+
+    return bprop
+
+
+@bprop_getters.register(P.Dropout3D)
+def get_bprop_dropout3d(self):
+    """Grad definition for `Dropout3D` operation."""
+    dtype = P.DType()
+    cast = P.Cast()
+    mul = P.Mul()
+    keep_prob = self.keep_prob
+
+    def bprop(x, out, dout):
+        _, mask = dout
+        y = cast(mask, mstype.float32)
+        if keep_prob != 0:
+            y = y * (1 / keep_prob)
+        y = mul(x, y)
+        y = cast(y, dtype(x))
+        return (y,)
+
+    return bprop
+
+
 @bprop_getters.register(P.CTCLoss)
 def get_bprop_ctc_loss(self):
     """Grad definition for `CTCLoss` operation"""


mindspore/ops/_op_impl/aicpu/__init__.py (+1 -0)

@@ -27,6 +27,7 @@ from .unique_with_pad import _unique_with_pad_aicpu
 from .sub_and_filter import _sub_and_filter_aicpu
 from .pad_and_shift import _pad_and_shift_aicpu
 from .dropout_genmask import _dropout_genmask_aicpu
+from .dropout2d import _dropout2d_aicpu
 from .dropout3d import _dropout3d_aicpu
 from .get_next import _get_next_aicpu
 from .print_tensor import _print_aicpu


mindspore/ops/_op_impl/aicpu/dropout2d.py (+42 -0)

@@ -0,0 +1,42 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Dropout2D op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+dropout2d_op_info = AiCPURegOp("Dropout2D") \
+    .fusion_type("OPAQUE") \
+    .input(0, "x", "required") \
+    .output(0, "y", "required") \
+    .output(1, "mask", "required") \
+    .attr("keep_prob", "float") \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \
+    .get_op_info()
+
+
+@op_info_register(dropout2d_op_info)
+def _dropout2d_aicpu():
+    """Dropout2D AiCPU register"""
+    return
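
The registration above fixes the op's contract: one input x, two required outputs (y with x's dtype, mask always bool), and a float keep_prob attribute. For the intended forward semantics, here is a hedged numpy reference (dropout2d_reference is an illustrative helper, not MindSpore API): each (H, W) feature map is kept or dropped as a whole, and kept maps are rescaled by 1/keep_prob.

import numpy as np

def dropout2d_reference(x, keep_prob=0.5, rng=None):
    """Per-channel dropout on an NCHW tensor: whole H x W maps are zeroed."""
    rng = rng or np.random.default_rng()
    n, c, _, _ = x.shape
    keep = rng.random((n, c, 1, 1)) < keep_prob   # one Bernoulli draw per channel
    mask = np.broadcast_to(keep, x.shape)         # bool, same shape as x
    scale = 1.0 / keep_prob if keep_prob > 0 else 0.0
    return (x * mask * scale).astype(x.dtype), mask

y, mask = dropout2d_reference(np.ones((2, 3, 4, 4), np.float32), keep_prob=0.8)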

mindspore/ops/_op_impl/aicpu/dropout3d.py (+16 -16)

@@ -13,30 +13,30 @@
 # limitations under the License.
 # ============================================================================

-"""Dropout3d op"""
+"""Dropout3D op"""
 from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
-dropout3d_op_info = AiCPURegOp("Dropout3d") \
+dropout3d_op_info = AiCPURegOp("Dropout3D") \
     .fusion_type("OPAQUE") \
     .input(0, "x", "required") \
     .output(0, "y", "required") \
+    .output(1, "mask", "required") \
     .attr("keep_prob", "float") \
-    .attr("inplace", "bool") \
-    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
-    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
-    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
-    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
-    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
-    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
-    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.BOOL_Default) \
     .get_op_info()


 @op_info_register(dropout3d_op_info)
 def _dropout3d_aicpu():
-    """Dropout3d AiCPU register"""
+    """Dropout3D AiCPU register"""
     return

mindspore/ops/operations/__init__.py (+3 -1)

@@ -64,7 +64,7 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
 from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,
                      BiasAdd, Conv2D,
                      DepthwiseConv2dNative,
-                     DropoutDoMask, Dropout, Dropout3d, DropoutGenMask, Flatten,
+                     DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
                      FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
                      GeLU, Gelu, FastGeLU, FastGelu, Elu,

@@ -243,6 +243,8 @@ __all__ = [
     'DropoutDoMask',
     'DropoutGenMask',
     'Dropout',
+    'Dropout2D',
+    'Dropout3D',
     'Neg',
     'InplaceAdd',
     'InplaceSub',
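
With these exports in place, both primitives are reachable from the public operations namespace. A minimal usage sketch (shapes and values are arbitrary):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

drop2d = P.Dropout2D(keep_prob=0.5)
output, mask = drop2d(Tensor(np.random.randn(2, 1, 2, 3), mindspore.float32))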


mindspore/ops/operations/nn_ops.py (+74 -14)

@@ -6657,22 +6657,77 @@ class Dropout(PrimitiveWithCheck):
         validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)


-class Dropout3d(PrimitiveWithInfer):
+class Dropout2D(PrimitiveWithInfer):
     """
     During training, randomly zeroes some of the channels of the input tensor
-    with probability keep_prob from a Bernoulli distribution.
+    with probability 1-`keep_prob` from a Bernoulli distribution.

     Args:
         keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8,
-            means dropping out %20 of channels. Default: 0.5.
-        inplace (bool): When `inplace` is True, this operation will be done in-place. Default: False.
+            means dropping out 20% of channels. Default: 0.5.
+
+    Inputs:
+        - **input** (Tensor) - A 4-D tensor with shape :math:`(N, C, H, W)`.
+
+    Outputs:
+        - **output** (Tensor) - with the same shape and data type as the input tensor.
+        - **mask** (Tensor[bool]) - with the same shape as the input tensor.
+
+    Raises:
+        TypeError: If the data type of `keep_prob` is not float.
+        ValueError: If `keep_prob` is out of the range [0.0, 1.0];
+            or if the dim of input is not 4-D.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> dropout = ops.Dropout2D(keep_prob=0.5)
+        >>> x = Tensor(np.random.randn(2, 1, 2, 3), mindspore.float32)
+        >>> output, mask = dropout(x)
+        >>> print(output)
+        [[[[0. 0. 0.]
+           [0. 0. 0.]]]
+         [[[0.88 -2.98 -0.01]
+           [2.16 -0.34 1.57]]]]
+        >>> print(mask)
+        [[[[False False False]
+           [False False False]]]
+         [[[True True True]
+           [True True True]]]]
+    """
+
+    @prim_attr_register
+    def __init__(self, keep_prob=0.5):
+        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
+        self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
+
+    def infer_shape(self, x_shape):
+        validator.check_int(len(x_shape), 4, Rel.EQ, "dim of input", self.name)
+        return x_shape, x_shape
+
+    def infer_dtype(self, x_dtype):
+        valid_dtypes = mstype.int_type + (mstype.float16, mstype.float32)
+        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
+        mask_dtype = mstype.tensor_type(mstype.bool_)
+        return x_dtype, mask_dtype
+
+
+class Dropout3D(PrimitiveWithInfer):
+    """
+    During training, randomly zeroes some of the channels of the input tensor
+    with probability 1-`keep_prob` from a Bernoulli distribution.
+
+    Args:
+        keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8,
+            means dropping out 20% of channels. Default: 0.5.

     Inputs:
         - **input** (Tensor) - A 5-D tensor with shape :math:`(N, C, D, H, W)`.
-          When `inplace` is True, `input` should be Parameter.

     Outputs:
-        - **output** (Tensor) - with the same shape as the input tensor.
+        - **output** (Tensor) - with the same shape and data type as the input tensor.
+        - **mask** (Tensor[bool]) - with the same shape as the input tensor.

     Raises:
         TypeError: If the data type of `keep_prob` is not float.
@@ -6683,30 +6738,35 @@ class Dropout3d(PrimitiveWithInfer):
         ``Ascend``

     Examples:
-        >>> dropout = ops.Dropout3d(keep_prob=0.5)
+        >>> dropout = ops.Dropout3D(keep_prob=0.5)
         >>> x = Tensor(np.random.randn(2, 1, 2, 1, 2), mindspore.float32)
-        >>> output = dropout(x)
+        >>> output, mask = dropout(x)
         >>> print(output)
         [[[[[0. 0.]]
           [[0. 0.]]]]
         [[[[-2.98 -0.01]]
          [[-0.34 1.57]]]]]
+        >>> print(mask)
+        [[[[[False False]]
+          [[False False]]]]
+        [[[[True True]]
+         [[True True]]]]]
     """

     @prim_attr_register
-    def __init__(self, keep_prob=0.5, inplace=False):
-        self.inplace = validator.check_value_type("inplace", inplace, [bool], self.name)
+    def __init__(self, keep_prob=0.5):
         self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
         self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)

     def infer_shape(self, x_shape):
-        validator.check_int(len(x_shape), 5, Rel.GE, "dim of input", self.name)
-        return x_shape
+        validator.check_int(len(x_shape), 5, Rel.EQ, "dim of input", self.name)
+        return x_shape, x_shape

     def infer_dtype(self, x_dtype):
-        valid_dtypes = mstype.number_type + (mstype.bool_,)
+        valid_dtypes = mstype.int_type + (mstype.float16, mstype.float32)
         validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
-        return x_dtype
+        mask_dtype = mstype.tensor_type(mstype.bool_)
+        return x_dtype, mask_dtype


 class CTCLoss(PrimitiveWithInfer):
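
The only functional difference between the two primitives is masking granularity: Dropout2D draws one Bernoulli sample per channel of a 4-D (N, C, H, W) input and zeroes whole H x W feature maps, while Dropout3D does the same per channel of a 5-D (N, C, D, H, W) input and zeroes whole D x H x W volumes. A small numpy sketch of that shared pattern (channel_mask is an illustrative helper, not part of this commit):

import numpy as np

def channel_mask(shape, keep_prob, rng=None):
    """Boolean keep-mask with one Bernoulli draw per (N, C) channel."""
    rng = rng or np.random.default_rng()
    keep = rng.random(shape[:2] + (1,) * (len(shape) - 2)) < keep_prob
    return np.broadcast_to(keep, shape)

mask4d = channel_mask((2, 3, 4, 4), 0.5)     # Dropout2D: per H x W map
mask5d = channel_mask((2, 3, 2, 4, 4), 0.5)  # Dropout3D: per D x H x W volume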


tests/st/ops/ascend/test_aicpu_ops/test_dropout2d.py (+69 -0)

@@ -0,0 +1,69 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.composite import GradOperation
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+dtype = np.float16
+x0 = Tensor(np.random.randn(3, 4, 3, 3).astype(dtype))
+x1 = Tensor(np.random.randn(3, 4, 3, 3).astype(dtype))
+
+
+class Net(nn.Cell):
+    def __init__(self, keep_prob):
+        super(Net, self).__init__()
+        self.drop = P.Dropout2D(keep_prob=keep_prob)
+
+    def construct(self, x):
+        return self.drop(x)
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(get_all=True, sens_param=True)
+        self.network = network
+        self.network.set_train()
+
+    def construct(self, x, y):
+        return self.grad(self.network)(x, y)
+
+
+def test_net_float32():
+    net = Net(0.7)
+    output, mask = net(x0)
+    print(x0)
+    print(output)
+
+    y = (output.asnumpy() == (x0.asnumpy()/0.7).astype(dtype)).reshape(3*4, 3*3)
+    output_reshape = output.asnumpy().reshape(3*4, 3*3)
+    for i in range(3*4):
+        if not y[i].all():
+            assert output_reshape[i].sum() == 0
+    return output, mask
+
+
+def test_net_grad():
+    net = Grad(Net(0.7))
+    y = test_net_float32()
+    output = net(x1, y)
+    print("input: ", x1)
+    print("forward output: ", y)
+    print("backward output: ", output)

tests/st/ops/ascend/test_aicpu_ops/test_dropout3d.py (+29 -24)

@@ -13,52 +13,57 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
-
-import mindspore
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops.composite import GradOperation

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

+dtype = np.float16
+x0 = Tensor(np.random.randn(3, 4, 3, 3, 3).astype(dtype))
+x1 = Tensor(np.random.randn(3, 4, 3, 3, 3).astype(dtype))
+

 class Net(nn.Cell):
-    def __init__(self, keep_prob, inplace):
+    def __init__(self, keep_prob):
         super(Net, self).__init__()
-        self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace)
+        self.drop = P.Dropout3D(keep_prob=keep_prob)

     def construct(self, x):
         return self.drop(x)


-class NetInplace(nn.Cell):
-    def __init__(self, keep_prob, inplace, x):
-        super(NetInplace, self).__init__()
-        self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace)
-        self.x = x
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(get_all=True, sens_param=True)
+        self.network = network
+        self.network.set_train()

-    def construct(self):
-        return self.drop(self.x)
+    def construct(self, x, y):
+        return self.grad(self.network)(x, y)


 def test_net_float32():
-    x = Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32)
-    net = Net(0.7, False)
-    output = net(x)
-    print(x)
+    net = Net(0.7)
+    output, mask = net(x0)
+    print(x0)
     print(output)

-    y = (output.asnumpy() == x.asnumpy()/0.7).reshape(3*4, 3*3*3)
+    y = (output.asnumpy() == (x0.asnumpy()/0.7).astype(dtype)).reshape(3*4, 3*3*3)
+    output_reshape = output.asnumpy().reshape(3*4, 3*3*3)
     for i in range(3*4):
         if not y[i].all():
-            assert y[i].sum() == 0
+            assert output_reshape[i].sum() == 0
+    return output, mask


-def test_net_float32_inplace():
-    x = mindspore.Parameter(Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32))
-    net = NetInplace(0.7, True, x)
-    output = net()
-    print(Tensor(x))
-    print(output)
-    assert np.array_equal(x.asnumpy(), output.asnumpy())
+def test_net_grad():
+    net = Grad(Net(0.7))
+    y = test_net_float32()
+    output = net(x1, y)
+    print("input: ", x1)
+    print("forward output: ", y)
+    print("backward output: ", output)
