You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_reverse_sequence.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import pytest
  2. import numpy as np
  3. import mindspore.context as context
  4. import mindspore.nn as nn
  5. from mindspore import Tensor
  6. from mindspore.common.api import ms_function
  7. from mindspore.ops import operations as P
  8. class Net(nn.Cell):
  9. def __init__(self, seq_dim, batch_dim):
  10. super(Net, self).__init__()
  11. self.reverse_sequence = P.ReverseSequence(
  12. seq_dim=seq_dim, batch_dim=batch_dim)
  13. @ms_function
  14. def construct(self, x, seq_lengths):
  15. return self.reverse_sequence(x, seq_lengths)
  16. @pytest.mark.level0
  17. @pytest.mark.platform_x86_gpu_training
  18. @pytest.mark.env_onecard
  19. def test_net_int8():
  20. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  21. x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int8)
  22. seq_lengths = np.array([1, 2, 3]).astype(np.int32)
  23. seq_dim = 0
  24. batch_dim = 1
  25. net = Net(seq_dim, batch_dim)
  26. output = net(Tensor(x), Tensor(seq_lengths))
  27. expected = np.array([[1, 5, 9], [4, 2, 6], [7, 8, 3]]).astype(np.int8)
  28. assert np.array_equal(output.asnumpy(), expected)
  29. @pytest.mark.level0
  30. @pytest.mark.platform_x86_gpu_training
  31. @pytest.mark.env_onecard
  32. def test_net_int32():
  33. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  34. x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int32)
  35. seq_lengths = np.array([1, 2, 3]).astype(np.int64)
  36. seq_dim = 1
  37. batch_dim = 0
  38. net = Net(seq_dim, batch_dim)
  39. output = net(Tensor(x), Tensor(seq_lengths))
  40. expected = np.array([[1, 2, 3], [5, 4, 6], [9, 8, 7]]).astype(np.int32)
  41. assert np.array_equal(output.asnumpy(), expected)
  42. @pytest.mark.level0
  43. @pytest.mark.platform_x86_gpu_training
  44. @pytest.mark.env_onecard
  45. def test_net_float32():
  46. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  47. x = np.array([[[[1, 2], [3, 4]],
  48. [[5, 6], [7, 8]],
  49. [[9, 10], [11, 12]],
  50. [[13, 14], [15, 16]]],
  51. [[[17, 18], [19, 20]],
  52. [[21, 22], [23, 24]],
  53. [[25, 26], [27, 28]],
  54. [[29, 30], [31, 21]]]]).astype(np.float32)
  55. seq_lengths = np.array([2, 2, 2, 2]).astype(np.int64)
  56. seq_dim = 0
  57. batch_dim = 1
  58. net = Net(seq_dim, batch_dim)
  59. output = net(Tensor(x), Tensor(seq_lengths))
  60. expected = np.array([[[[17., 18.], [19., 20.]],
  61. [[21., 22.], [23., 24.]],
  62. [[25., 26.], [27., 28.]],
  63. [[29., 30.], [31., 21.]]],
  64. [[[1., 2.], [3., 4.]],
  65. [[5., 6.], [7., 8.]],
  66. [[9., 10.], [11., 12.]],
  67. [[13., 14.], [15., 16.]]]]).astype(np.float32)
  68. assert np.array_equal(output.asnumpy(), expected)
  69. @pytest.mark.level0
  70. @pytest.mark.platform_x86_gpu_training
  71. @pytest.mark.env_onecard
  72. def test_net_float64_0_dim():
  73. """
  74. Test added to test for 0 seq len edge case
  75. """
  76. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  77. x = np.array([[[[1, 2], [3, 4]],
  78. [[5, 6], [7, 8]],
  79. [[9, 10], [11, 12]],
  80. [[13, 14], [15, 16]]],
  81. [[[17, 18], [19, 20]],
  82. [[21, 22], [23, 24]],
  83. [[25, 26], [27, 28]],
  84. [[29, 30], [31, 21]]]]).astype(np.float32)
  85. seq_lengths = np.array([2, 2, 0, 0]).astype(np.int64)
  86. seq_dim = 2
  87. batch_dim = 1
  88. net = Net(seq_dim, batch_dim)
  89. output = net(Tensor(x), Tensor(seq_lengths))
  90. expected = np.array([[[[3., 4.], [1., 2.]],
  91. [[7., 8.], [5., 6.]],
  92. [[9., 10.], [11., 12.]],
  93. [[13., 14.], [15., 16.]]],
  94. [[[19., 20.], [17., 18.]],
  95. [[23., 24.], [21., 22.]],
  96. [[25., 26.], [27., 28.]],
  97. [[29., 30.], [31., 21.]]]]).astype(np.float32)
  98. assert np.array_equal(output.asnumpy(), expected)