| @@ -567,6 +567,16 @@ def get_bprop_l2_loss(self): | |||
| return bprop | |||
| @bprop_getters.register(P.RNNTLoss) | |||
| def get_bprop_rnnt_loss(self): | |||
| """Grad definition for `RNNTLoss` operation.""" | |||
| def bprop(acts, labels, act_lens, label_lens, out, dout): | |||
| grad = out[1] | |||
| return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens) | |||
| return bprop | |||
| @bprop_getters.register(P.PReLU) | |||
| def get_bprop_prelu(self): | |||
| """Grad definition for `PReLU` operation.""" | |||
| @@ -30,3 +30,5 @@ from .ctcloss import _ctcloss_aicpu | |||
| from .reverse_sequence import _reverse_sequence_aicpu | |||
| from .crop_and_resize import _crop_and_resize_aicpu | |||
| from .end_of_sequence import _end_of_sequence_aicpu | |||
| from .rnnt_loss import _rnnt_loss_aicpu | |||
| from .random_categorical import _random_categorical_aicpu | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """RandomCategorical op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| random_categorical_op_info = AiCPURegOp("RandomCategorical") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "logits", "required") \ | |||
| .input(1, "num_sample", "required") \ | |||
| .input(2, "seed", "required") \ | |||
| .output(0, "output", "required") \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(random_categorical_op_info) | |||
| def _random_categorical_aicpu(): | |||
| """RandomCategorical AiCPU register""" | |||
| return | |||
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """RNNTLoss op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
| rnnt_loss_op_info = AiCPURegOp("RNNTLoss") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .input(0, "acts", "required") \ | |||
| .input(1, "labels", "required") \ | |||
| .input(2, "input_lengths", "required") \ | |||
| .input(3, "label_lengths", "required") \ | |||
| .output(0, "costs", "required") \ | |||
| .output(1, "grads", "required") \ | |||
| .attr("blank_label", "int") \ | |||
| .dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW, | |||
| DataType.F32_NCHW) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(rnnt_loss_op_info) | |||
| def _rnnt_loss_aicpu(): | |||
| """RNNTLoss AiCPU register""" | |||
| return | |||
| @@ -55,7 +55,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | |||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | |||
| from .random_ops import (RandomChoiceWithMask, Normal) | |||
| from .random_ops import (RandomChoiceWithMask, Normal, RandomCategorical) | |||
| from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm, | |||
| BiasAdd, Conv2D, | |||
| DepthwiseConv2dNative, | |||
| @@ -70,6 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl | |||
| ResizeBilinear, Sigmoid, | |||
| SigmoidCrossEntropyWithLogits, | |||
| SmoothL1Loss, Softmax, Softsign, Softplus, LRN, | |||
| RNNTLoss, | |||
| SoftmaxCrossEntropyWithLogits, ROIAlign, | |||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | |||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | |||
| @@ -171,6 +172,7 @@ __all__ = [ | |||
| 'Tanh', | |||
| 'RandomChoiceWithMask', | |||
| 'Normal', | |||
| 'RandomCategorical', | |||
| 'ResizeBilinear', | |||
| 'ScalarSummary', | |||
| 'ImageSummary', | |||
| @@ -202,6 +204,7 @@ __all__ = [ | |||
| 'SmoothL1Loss', | |||
| 'L2Loss', | |||
| 'CTCLoss', | |||
| 'RNNTLoss', | |||
| 'ReduceAll', | |||
| 'ScalarToArray', | |||
| 'ScalarToTensor', | |||
| @@ -1735,6 +1735,62 @@ class DataFormatDimMap(PrimitiveWithInfer): | |||
| return x_type | |||
| class RNNTLoss(PrimitiveWithInfer): | |||
| """ | |||
| Computes the RNNTLoss and its gradient with respect to the softmax outputs. | |||
| Args: | |||
| blank_label (int): blank label. Default: 0. | |||
| Inputs: | |||
| - **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`. | |||
| - **labels** (Tensor[int32]) - Tensor of shape :math:`(B, U-1)`. | |||
| - **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. | |||
| - **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. | |||
| Outputs: | |||
| - **costs** (Tensor[int32]) - Tensor of shape :math:`(B,)`. | |||
| - **grads** (Tensor[int32]) - Has the same shape as `acts`. | |||
| Examples: | |||
| >>> B, T, U, V = 1, 2, 3, 5 | |||
| >>> acts = np.random.random((B, T, U, V)).astype(np.float32) | |||
| >>> labels = np.array([[1, 2]]).astype(np.int32) | |||
| >>> input_length = np.array([T] * B).astype(np.int32) | |||
| >>> label_length = np.array([len(l) for l in labels]).astype(np.int32) | |||
| >>> rnnt_loss = P.RNNTLoss(blank_label=blank) | |||
| >>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, blank_label=0): | |||
| validator.check_value_type('blank_label', blank_label, [int], self.name) | |||
| self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'], | |||
| outputs=['costs', 'grads']) | |||
| def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape): | |||
| validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name) | |||
| validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name) | |||
| validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name) | |||
| validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) | |||
| validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name) | |||
| validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) | |||
| validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) | |||
| costs_shape = (acts_shape[0],) | |||
| return (costs_shape, acts_shape) | |||
| def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type): | |||
| validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name) | |||
| validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name) | |||
| validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name) | |||
| validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32], self.name) | |||
| validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name) | |||
| validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name) | |||
| validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name) | |||
| return (acts_type, acts_type) | |||
| class SGD(PrimitiveWithInfer): | |||
| """ | |||
| Computes stochastic gradient descent (optionally with momentum). | |||
| @@ -108,3 +108,61 @@ class Normal(PrimitiveWithInfer): | |||
| "dtype": mstype.float32, | |||
| "value": None} | |||
| return out | |||
| class RandomCategorical(PrimitiveWithInfer): | |||
| """ | |||
| Generates random samples from a given categorical distribution tensor. | |||
| Args: | |||
| dtype (mindspore.dtype): The type of output. Its value should be one of [mindspore.int16, | |||
| mindspore.int32, mindspore.int64]. Default: mindspore.int64. | |||
| Inputs: | |||
| - **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes]. | |||
| - **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed. | |||
| - **seed** (int) - Random seed. Default: 0. Only constant values is allowed. | |||
| Outputs: | |||
| - **output** (Tensor) - The output Tensor with shape [batch_size, num_samples]. | |||
| Examples: | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self, num_sample): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.random_categorical = P.RandomCategorical(mindspore.int64) | |||
| >>> self.num_sample = num_sample | |||
| >>> def construct(self, logits, seed=0): | |||
| >>> return self.random_categorical(logits, self.num_sample, seed) | |||
| >>> | |||
| >>> x = np.random.random((10, 5)).astype(np.float32) | |||
| >>> net = Net(8) | |||
| >>> output = net(Tensor(x)) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, dtype=mstype.int64): | |||
| """Init RandomCategorical""" | |||
| self.dtype = dtype | |||
| valid_values = (mstype.int32, mstype.int16, mstype.int64) | |||
| validator.check_type_name("dtype", dtype, valid_values, self.name) | |||
| self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'], | |||
| outputs=['output']) | |||
| def __infer__(self, logits, num_samples, seed): | |||
| logits_dtype = logits['dtype'] | |||
| valid_types = (mstype.float32, mstype.float16, mstype.float64) | |||
| validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) | |||
| num_samples_v = num_samples['value'] | |||
| seed_v = seed['value'] | |||
| validator.check_value_type('num_samples', num_samples_v, (int,), self.name) | |||
| validator.check_value_type('seed', seed_v, (int,), self.name) | |||
| validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name) | |||
| x_shape = list(logits['shape']) | |||
| if len(x_shape) != 2: | |||
| raise ValueError("RandomCategorical shape should be 2-dimension.") | |||
| ndim = len(x_shape) - 1 | |||
| x_shape[ndim] = num_samples_v | |||
| return {'shape': (x_shape), | |||
| 'dtype': (self.dtype), | |||
| 'value': None} | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| def __init__(self, num_sample): | |||
| super(Net, self).__init__() | |||
| self.random_categorical = P.RandomCategorical(mindspore.int64) | |||
| self.num_sample = num_sample | |||
| def construct(self, logits, seed=0): | |||
| return self.random_categorical(logits, self.num_sample, seed) | |||
| def test_net(): | |||
| x = np.random.random((10, 5)).astype(np.float32) | |||
| net = Net(8) | |||
| output = net(Tensor(x)) | |||
| print(x) | |||
| print(output.asnumpy()) | |||
| #print(output.dtype()) | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.rnnt_loss = P.RNNTLoss(blank_label=0) | |||
| def construct(self, acts, labels, act_lens, label_lens): | |||
| return self.rnnt_loss(acts, labels, act_lens, label_lens) | |||
| def test_net(): | |||
| B, T, U, V = 1, 2, 3, 5 | |||
| acts = np.random.random((B, T, U, V)).astype(np.float32) | |||
| labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32) | |||
| input_length = np.array([T] * B).astype(np.int32) | |||
| label_length = np.array([len(l) for l in labels]).astype(np.int32) | |||
| rnnt_loss = Net() | |||
| costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) | |||
| print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) | |||
| print(costs.asnumpy()) | |||
| print(grads.asnumpy()) | |||