
add RNNTLoss and RandomCategorical op for aicpu

Tag: v0.6.0-beta
Author: yanzhenxiang2020, 5 years ago
Commit: 2ae6dfe95a
9 changed files with 294 additions and 1 deletion

  1. mindspore/ops/_grad/grad_nn_ops.py  (+10 -0)
  2. mindspore/ops/_op_impl/aicpu/__init__.py  (+2 -0)
  3. mindspore/ops/_op_impl/aicpu/random_categorical.py  (+48 -0)
  4. mindspore/ops/_op_impl/aicpu/rnnt_loss.py  (+37 -0)
  5. mindspore/ops/operations/__init__.py  (+4 -1)
  6. mindspore/ops/operations/nn_ops.py  (+56 -0)
  7. mindspore/ops/operations/random_ops.py  (+58 -0)
  8. tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py  (+38 -0)
  9. tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py  (+41 -0)

mindspore/ops/_grad/grad_nn_ops.py  (+10 -0)

@@ -567,6 +567,16 @@ def get_bprop_l2_loss(self):
    return bprop




@bprop_getters.register(P.RNNTLoss)
def get_bprop_rnnt_loss(self):
"""Grad definition for `RNNTLoss` operation."""

def bprop(acts, labels, act_lens, label_lens, out, dout):
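        # The RNNTLoss forward pass returns (costs, grads), so the gradient
        # w.r.t. acts is simply the precomputed second output; the integer
        # label/length inputs receive zero gradients.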
grad = out[1]
return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
return bprop


@bprop_getters.register(P.PReLU)
def get_bprop_prelu(self):
    """Grad definition for `PReLU` operation."""


mindspore/ops/_op_impl/aicpu/__init__.py  (+2 -0)

@@ -30,3 +30,5 @@ from .ctcloss import _ctcloss_aicpu
from .reverse_sequence import _reverse_sequence_aicpu
from .crop_and_resize import _crop_and_resize_aicpu
from .end_of_sequence import _end_of_sequence_aicpu
from .rnnt_loss import _rnnt_loss_aicpu
from .random_categorical import _random_categorical_aicpu

mindspore/ops/_op_impl/aicpu/random_categorical.py  (+48 -0)

@@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""RandomCategorical op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

random_categorical_op_info = AiCPURegOp("RandomCategorical") \
.fusion_type("OPAQUE") \
.input(0, "logits", "required") \
.input(1, "num_sample", "required") \
.input(2, "seed", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()

@op_info_register(random_categorical_op_info)
def _random_categorical_aicpu():
"""RandomCategorical AiCPU register"""
return
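
The eighteen dtype_format rows above enumerate a full cross product: three logits dtypes (float16/float32/float64), two matching integer dtypes for num_sample and seed (int32/int64), and three output dtypes (int16/int32/int64). A hypothetical loop-based construction of the same table (a sketch only; the repository lists the combinations explicitly, as above):

from itertools import product

from mindspore.ops.op_info_register import AiCPURegOp, DataType

op = AiCPURegOp("RandomCategorical") \
    .fusion_type("OPAQUE") \
    .input(0, "logits", "required") \
    .input(1, "num_sample", "required") \
    .input(2, "seed", "required") \
    .output(0, "output", "required")

logits_types = (DataType.F16_Default, DataType.F32_Default, DataType.F64_Default)
index_types = (DataType.I32_Default, DataType.I64_Default)  # num_sample and seed share a dtype
output_types = (DataType.I16_Default, DataType.I32_Default, DataType.I64_Default)

# 2 x 3 x 3 = 18 registered signatures, matching the explicit table above.
for idx_t, out_t, logits_t in product(index_types, output_types, logits_types):
    op = op.dtype_format(logits_t, idx_t, idx_t, out_t)

random_categorical_op_info = op.get_op_info()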

mindspore/ops/_op_impl/aicpu/rnnt_loss.py  (+37 -0)

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""RNNTLoss op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

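# The registration below accepts float32 acts with int32 labels/lengths, in
# both NCHW and default formats; both outputs (costs and grads) are float32.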
rnnt_loss_op_info = AiCPURegOp("RNNTLoss") \
.fusion_type("OPAQUE") \
.input(0, "acts", "required") \
.input(1, "labels", "required") \
.input(2, "input_lengths", "required") \
.input(3, "label_lengths", "required") \
.output(0, "costs", "required") \
.output(1, "grads", "required") \
.attr("blank_label", "int") \
.dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()

@op_info_register(rnnt_loss_op_info)
def _rnnt_loss_aicpu():
"""RNNTLoss AiCPU register"""
return

mindspore/ops/operations/__init__.py  (+4 -1)

@@ -55,7 +55,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
                        Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)


-from .random_ops import (RandomChoiceWithMask, Normal)
+from .random_ops import (RandomChoiceWithMask, Normal, RandomCategorical)
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
                     BiasAdd, Conv2D,
                     DepthwiseConv2dNative,
@@ -70,6 +70,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
                     ResizeBilinear, Sigmoid,
                     SigmoidCrossEntropyWithLogits,
                     SmoothL1Loss, Softmax, Softsign, Softplus, LRN,
                     RNNTLoss,
                     SoftmaxCrossEntropyWithLogits, ROIAlign,
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
@@ -171,6 +172,7 @@ __all__ = [
    'Tanh',
    'RandomChoiceWithMask',
    'Normal',
    'RandomCategorical',
    'ResizeBilinear',
    'ScalarSummary',
    'ImageSummary',
@@ -202,6 +204,7 @@ __all__ = [
    'SmoothL1Loss',
    'L2Loss',
    'CTCLoss',
    'RNNTLoss',
    'ReduceAll',
    'ScalarToArray',
    'ScalarToTensor',


mindspore/ops/operations/nn_ops.py  (+56 -0)

@@ -1735,6 +1735,62 @@ class DataFormatDimMap(PrimitiveWithInfer):
        return x_type




class RNNTLoss(PrimitiveWithInfer):
"""
Computes the RNNTLoss and its gradient with respect to the softmax outputs.

Args:
        blank_label (int): The label value reserved for the blank symbol. Default: 0.

Inputs:
- **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`.
- **labels** (Tensor[int32]) - Tensor of shape :math:`(B, U-1)`.
- **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
        - **label_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.

Outputs:
        - **costs** (Tensor[float32]) - Tensor of shape :math:`(B,)`.
        - **grads** (Tensor[float32]) - Has the same shape as `acts`.

Examples:
>>> B, T, U, V = 1, 2, 3, 5
>>> acts = np.random.random((B, T, U, V)).astype(np.float32)
>>> labels = np.array([[1, 2]]).astype(np.int32)
>>> input_length = np.array([T] * B).astype(np.int32)
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
        >>> rnnt_loss = P.RNNTLoss(blank_label=0)
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
"""
@prim_attr_register
def __init__(self, blank_label=0):
validator.check_value_type('blank_label', blank_label, [int], self.name)
self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'],
outputs=['costs', 'grads'])

def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape):
validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name)
validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name)
validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name)
validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name)
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name)
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
costs_shape = (acts_shape[0],)
return (costs_shape, acts_shape)

def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type):
validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name)
validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name)
validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name)
validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32], self.name)
validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name)
validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name)
validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name)
return (acts_type, acts_type)


class SGD(PrimitiveWithInfer):
    """
    Computes stochastic gradient descent (optionally with momentum).


mindspore/ops/operations/random_ops.py  (+58 -0)

@@ -108,3 +108,61 @@ class Normal(PrimitiveWithInfer):
"dtype": mstype.float32, "dtype": mstype.float32,
"value": None} "value": None}
return out return out


class RandomCategorical(PrimitiveWithInfer):
"""
Generates random samples from a given categorical distribution tensor.

Args:
dtype (mindspore.dtype): The type of output. Its value should be one of [mindspore.int16,
mindspore.int32, mindspore.int64]. Default: mindspore.int64.

Inputs:
- **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes].
        - **num_sample** (int) - Number of samples to be drawn. Only constant values are allowed.
        - **seed** (int) - Random seed. Default: 0. Only constant values are allowed.

Outputs:
- **output** (Tensor) - The output Tensor with shape [batch_size, num_samples].

Examples:
>>> class Net(nn.Cell):
>>> def __init__(self, num_sample):
>>> super(Net, self).__init__()
>>> self.random_categorical = P.RandomCategorical(mindspore.int64)
>>> self.num_sample = num_sample
>>> def construct(self, logits, seed=0):
>>> return self.random_categorical(logits, self.num_sample, seed)
>>>
>>> x = np.random.random((10, 5)).astype(np.float32)
>>> net = Net(8)
>>> output = net(Tensor(x))
"""
@prim_attr_register
def __init__(self, dtype=mstype.int64):
"""Init RandomCategorical"""
self.dtype = dtype

valid_values = (mstype.int32, mstype.int16, mstype.int64)
validator.check_type_name("dtype", dtype, valid_values, self.name)
self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
outputs=['output'])

def __infer__(self, logits, num_samples, seed):
logits_dtype = logits['dtype']
valid_types = (mstype.float32, mstype.float16, mstype.float64)
validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
num_samples_v = num_samples['value']
seed_v = seed['value']
validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
validator.check_value_type('seed', seed_v, (int,), self.name)
validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name)
x_shape = list(logits['shape'])
        if len(x_shape) != 2:
            raise ValueError("RandomCategorical logits should be 2-dimensional.")
ndim = len(x_shape) - 1
x_shape[ndim] = num_samples_v
        return {'shape': x_shape,
                'dtype': self.dtype,
                'value': None}

tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py  (+38 -0)

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
def __init__(self, num_sample):
super(Net, self).__init__()
self.random_categorical = P.RandomCategorical(mindspore.int64)
self.num_sample = num_sample

def construct(self, logits, seed=0):
return self.random_categorical(logits, self.num_sample, seed)


def test_net():
x = np.random.random((10, 5)).astype(np.float32)
net = Net(8)
output = net(Tensor(x))
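    # Expect an int64 tensor of shape (10, 8): eight sampled class indices per row.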
    print(x)
    print(output.asnumpy())

tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py  (+41 -0)

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.rnnt_loss = P.RNNTLoss(blank_label=0)

def construct(self, acts, labels, act_lens, label_lens):
return self.rnnt_loss(acts, labels, act_lens, label_lens)


def test_net():
B, T, U, V = 1, 2, 3, 5
acts = np.random.random((B, T, U, V)).astype(np.float32)
labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32)
input_length = np.array([T] * B).astype(np.int32)
label_length = np.array([len(l) for l in labels]).astype(np.int32)
rnnt_loss = Net()
costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
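    # costs is float32 with shape (B,); grads is float32 and matches acts' shape (B, T, U, V).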
print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
print(costs.asnumpy())
print(grads.asnumpy())
