| @@ -51,8 +51,9 @@ constexpr auto kCacheSwapTable = "CacheSwapTable"; | |||||
| constexpr auto kSubAndFilter = "SubAndFilter"; | constexpr auto kSubAndFilter = "SubAndFilter"; | ||||
| constexpr auto kPadAndShift = "PadAndShift"; | constexpr auto kPadAndShift = "PadAndShift"; | ||||
| constexpr auto kCustRunApi = "RunCpuKernel"; | constexpr auto kCustRunApi = "RunCpuKernel"; | ||||
| constexpr auto kDropout3d = "Dropout3d"; | |||||
| const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity}; | const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity}; | ||||
| const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift}; | |||||
| const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3d}; | |||||
| struct AicpuParamHead { | struct AicpuParamHead { | ||||
| uint32_t length; // Total length: include cunstom message | uint32_t length; // Total length: include cunstom message | ||||
| @@ -27,6 +27,7 @@ from .unique_with_pad import _unique_with_pad_aicpu | |||||
| from .sub_and_filter import _sub_and_filter_aicpu | from .sub_and_filter import _sub_and_filter_aicpu | ||||
| from .pad_and_shift import _pad_and_shift_aicpu | from .pad_and_shift import _pad_and_shift_aicpu | ||||
| from .dropout_genmask import _dropout_genmask_aicpu | from .dropout_genmask import _dropout_genmask_aicpu | ||||
| from .dropout3d import _dropout3d_aicpu | |||||
| from .get_next import _get_next_aicpu | from .get_next import _get_next_aicpu | ||||
| from .print_tensor import _print_aicpu | from .print_tensor import _print_aicpu | ||||
| from .topk import _top_k_aicpu | from .topk import _top_k_aicpu | ||||
| @@ -0,0 +1,42 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Dropout3d op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||||
| dropout3d_op_info = AiCPURegOp("Dropout3d") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .input(0, "x", "required") \ | |||||
| .output(0, "y", "required") \ | |||||
| .attr("keep_prob", "float") \ | |||||
| .attr("inplace", "bool") \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F64_Default, DataType.F64_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(dropout3d_op_info) | |||||
| def _dropout3d_aicpu(): | |||||
| """Dropout3d AiCPU register""" | |||||
| return | |||||
| @@ -63,7 +63,7 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U | |||||
| from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm, | from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm, | ||||
| BiasAdd, Conv2D, | BiasAdd, Conv2D, | ||||
| DepthwiseConv2dNative, | DepthwiseConv2dNative, | ||||
| DropoutDoMask, Dropout, DropoutGenMask, Flatten, | |||||
| DropoutDoMask, Dropout, Dropout3d, DropoutGenMask, Flatten, | |||||
| FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate, | FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate, | ||||
| Gelu, FastGelu, Elu, | Gelu, FastGelu, Elu, | ||||
| GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder, | GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder, | ||||
| @@ -6242,6 +6242,58 @@ class Dropout(PrimitiveWithInfer): | |||||
| return x_dtype, x_dtype | return x_dtype, x_dtype | ||||
| class Dropout3d(PrimitiveWithInfer): | |||||
| """ | |||||
| During training, randomly zeroes some of the channels of the input tensor | |||||
| with probability keep_prob from a Bernoulli distribution. | |||||
| Args: | |||||
| keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8, | |||||
| means dropping out %20 of channels. Default: 0.5. | |||||
| inplace (bool): When `inplace` is True, this operation will be done in-place. Default: False. | |||||
| Inputs: | |||||
| - **input** (Tensor) - A 5-D tensor with shape :math:`(N, C, D, H, W)`. | |||||
| When `inplace` is True, `input` should be Parameter. | |||||
| Outputs: | |||||
| - **output** (Tensor) - with the same shape as the input tensor. | |||||
| Raises: | |||||
| TypeError: If the data type of `keep_prob` is not float. | |||||
| ValueError: If `keep_prob` is out of the range [0.0, 1.0]; | |||||
| or if the dim of input is not 5-D. | |||||
| Supported Platforms: | |||||
| ``Ascend`` | |||||
| Examples: | |||||
| >>> dropout = ops.Dropout3d(keep_prob=0.5) | |||||
| >>> x = Tensor(np.random.randn(2, 1, 2, 1, 2), mindspore.float32) | |||||
| >>> output = dropout(x) | |||||
| >>> print(output) | |||||
| [[[[[0. 0.]] | |||||
| [[0. 0.]]]] | |||||
| [[[[-2.98 -0.01]] | |||||
| [[-0.34 1.57]]]]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, keep_prob=0.5, inplace=False): | |||||
| self.inplace = validator.check_value_type("inplace", inplace, [bool], self.name) | |||||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||||
| self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name) | |||||
| def infer_shape(self, x_shape): | |||||
| validator.check_int(len(x_shape), 5, Rel.GE, "dim of input", self.name) | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype): | |||||
| valid_dtypes = mstype.number_type + (mstype.bool_,) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||||
| return x_dtype | |||||
| class CTCLoss(PrimitiveWithInfer): | class CTCLoss(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Calculates the CTC (Connectionist Temporal Classification) loss and the gradient. | Calculates the CTC (Connectionist Temporal Classification) loss and the gradient. | ||||
| @@ -0,0 +1,64 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, keep_prob, inplace): | |||||
| super(Net, self).__init__() | |||||
| self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace) | |||||
| def construct(self, x): | |||||
| return self.drop(x) | |||||
| class NetInplace(nn.Cell): | |||||
| def __init__(self, keep_prob, inplace, x): | |||||
| super(NetInplace, self).__init__() | |||||
| self.drop = P.Dropout3d(keep_prob=keep_prob, inplace=inplace) | |||||
| self.x = x | |||||
| def construct(self): | |||||
| return self.drop(self.x) | |||||
| def test_net_float32(): | |||||
| x = Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32) | |||||
| net = Net(0.7, False) | |||||
| output = net(x) | |||||
| print(x) | |||||
| print(output) | |||||
| y = (output.asnumpy() == x.asnumpy()/0.7).reshape(3*4, 3*3*3) | |||||
| for i in range(3*4): | |||||
| if not y[i].all(): | |||||
| assert y[i].sum() == 0 | |||||
| def test_net_float32_inplace(): | |||||
| x = mindspore.Parameter(Tensor(np.random.randn(3, 4, 3, 3, 3), mindspore.float32)) | |||||
| net = NetInplace(0.7, True, x) | |||||
| output = net() | |||||
| print(Tensor(x)) | |||||
| print(output) | |||||
| assert np.array_equal(x.asnumpy(), output.asnumpy()) | |||||