Browse Source

fix example of categorical and rnntloss

tags/v1.1.0
yanzhenxiang2020 5 years ago
parent
commit
dca109c9a5
3 changed files with 15 additions and 5 deletions
  1. +4
    -3
      mindspore/ops/operations/array_ops.py
  2. +10
    -1
      mindspore/ops/operations/nn_ops.py
  3. +1
    -1
      tests/st/ops/ascend/test_aicpu_ops/test_ctc_greedy_decoder.py

+ 4
- 3
mindspore/ops/operations/array_ops.py View File

@@ -4151,8 +4151,8 @@ class Meshgrid(PrimitiveWithInfer):

Args:
indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
When the indexing argument is set to 'xy' (the default),
the broadcasting instructions for the first two dimensions are swapped.
When the indexing argument is set to 'xy' (the default), the broadcasting
instructions for the first two dimensions are swapped.

Inputs:
- **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
@@ -4167,7 +4167,8 @@ class Meshgrid(PrimitiveWithInfer):
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
>>> inputs = (x, y, z)
>>> meshgrid = ops.Meshgrid(indexing="xy")
>>> meshgrid(inputs)
>>> output = meshgrid(inputs)
>>> print(output)
(Tensor(shape=[3, 4, 6], dtype=UInt32, value=
[[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],


+ 10
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -2261,8 +2261,17 @@ class RNNTLoss(PrimitiveWithInfer):
>>> labels = np.array([[1, 2]]).astype(np.int32)
>>> input_length = np.array([T] * B).astype(np.int32)
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
>>> rnnt_loss = ops.RNNTLoss(blank_label=blank)
>>> rnnt_loss = ops.RNNTLoss(blank_label=0)
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
>>> print(costs)
[-3.5036912]
>>> print(grads)
[[[[-0.35275543 -0.64724463 0. 0. 0. ]
[-0.19174816 0. -0.45549652 0. 0. ]
[-0.45549664 0. 0. 0. 0. ]]
[[0. -0.35275543 0. 0. 0. ]
[0. 0. -0.5445037 0. 0. ]
[-1.00000002 0. 0. 0. 0. ]]]]
"""

@prim_attr_register


+ 1
- 1
tests/st/ops/ascend/test_aicpu_ops/test_ctc_greedy_decoder.py View File

@@ -52,7 +52,7 @@ def test_net_assert():
out_expect0 = np.array([0, 0, 0, 1, 1, 0]).reshape(3, 2)
out_expect1 = np.array([0, 1, 1])
out_expect2 = np.array([2, 2])
out_expect3 = np.array([-0.7443749, 0.18251707]).reshape(2, 1)
out_expect3 = np.array([-0.7443749, 0.18251707]).astype(np.float32).reshape(2, 1)
assert np.array_equal(output[0].asnumpy(), out_expect0)
assert np.array_equal(output[1].asnumpy(), out_expect1)
assert np.array_equal(output[2].asnumpy(), out_expect2)


Loading…
Cancel
Save