Browse Source

!8310 fix example of categorical and rnntloss

From: @yanzhenxiang2020
Reviewed-by: @c_34
Signed-off-by: @c_34
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d6f6269ff1
3 changed files with 15 additions and 5 deletions
  1. +4
    -3
      mindspore/ops/operations/array_ops.py
  2. +10
    -1
      mindspore/ops/operations/nn_ops.py
  3. +1
    -1
      tests/st/ops/ascend/test_aicpu_ops/test_ctc_greedy_decoder.py

+ 4
- 3
mindspore/ops/operations/array_ops.py View File

@@ -4154,8 +4154,8 @@ class Meshgrid(PrimitiveWithInfer):

Args:
indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
When the indexing argument is set to 'xy' (the default),
the broadcasting instructions for the first two dimensions are swapped.
When the indexing argument is set to 'xy' (the default), the broadcasting
instructions for the first two dimensions are swapped.

Inputs:
- **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
@@ -4170,7 +4170,8 @@ class Meshgrid(PrimitiveWithInfer):
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
>>> inputs = (x, y, z)
>>> meshgrid = ops.Meshgrid(indexing="xy")
>>> meshgrid(inputs)
>>> output = meshgrid(inputs)
>>> print(output)
(Tensor(shape=[3, 4, 6], dtype=UInt32, value=
[[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],


+ 10
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -2261,8 +2261,17 @@ class RNNTLoss(PrimitiveWithInfer):
>>> labels = np.array([[1, 2]]).astype(np.int32)
>>> input_length = np.array([T] * B).astype(np.int32)
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
>>> rnnt_loss = ops.RNNTLoss(blank_label=blank)
>>> rnnt_loss = ops.RNNTLoss(blank_label=0)
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
>>> print(costs)
[-3.5036912]
>>> print(grads)
[[[[-0.35275543 -0.64724463 0. 0. 0. ]
[-0.19174816 0. -0.45549652 0. 0. ]
[-0.45549664 0. 0. 0. 0. ]]
[[0. -0.35275543 0. 0. 0. ]
[0. 0. -0.5445037 0. 0. ]
[-1.00000002 0. 0. 0. 0. ]]]]
"""

@prim_attr_register


+ 1
- 1
tests/st/ops/ascend/test_aicpu_ops/test_ctc_greedy_decoder.py View File

@@ -52,7 +52,7 @@ def test_net_assert():
out_expect0 = np.array([0, 0, 0, 1, 1, 0]).reshape(3, 2)
out_expect1 = np.array([0, 1, 1])
out_expect2 = np.array([2, 2])
out_expect3 = np.array([-0.7443749, 0.18251707]).reshape(2, 1)
out_expect3 = np.array([-0.7443749, 0.18251707]).astype(np.float32).reshape(2, 1)
assert np.array_equal(output[0].asnumpy(), out_expect0)
assert np.array_equal(output[1].asnumpy(), out_expect1)
assert np.array_equal(output[2].asnumpy(), out_expect2)


Loading…
Cancel
Save