You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sliding_window.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing SlidingWindow in mindspore.dataset
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.text as text
  def test_sliding_window_callable():
      """
      Test calling the SlidingWindow op directly, outside of a dataset pipeline
      """
      sliding_window = text.SlidingWindow(2, 0)
      tokens = ["大", "家", "早", "上", "好"]

      # A flat list of tokens is expected to produce every adjacent pair.
      expected_pairs = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])
      assert np.array_equal(sliding_window(tokens), expected_pairs)

      # A nested (2D) list is expected to be rejected.
      nested_tokens = [["大", "家", "早", "上", "好"]]
      with pytest.raises(RuntimeError) as exc_info:
          _ = sliding_window(nested_tokens)
      assert "SlidingWindow: SlidingWindow supports 1D input only for now." in str(exc_info.value)

      # Passing more than one tensor to the op is expected to be rejected.
      with pytest.raises(RuntimeError) as exc_info:
          _ = sliding_window(tokens, tokens)
      assert "The op is OneToOne, can only accept one tensor as input." in str(exc_info.value)
  def test_sliding_window_string():
      """
      Test SlidingWindow on a string column inside a dataset pipeline
      """
      expected_pairs = np.array([['大', '家'], ['家', '早'], ['早', '上'], ['上', '好']])

      source = ds.NumpySlicesDataset([["大", "家", "早", "上", "好"]], column_names=["text"], shuffle=False)
      windowed = source.map(operations=text.SlidingWindow(2, 0), input_columns=["text"])

      decoded = []
      for row in windowed.create_dict_iterator(num_epochs=1, output_numpy=True):
          windows = row['text']
          n_windows, width = windows.shape[0], windows.shape[1]
          # Elements are decoded from UTF-8 bytes so they can be compared with str values.
          decoded.extend(
              [windows[r][c].decode('utf8') for c in range(width)]
              for r in range(n_windows)
          )

      result = np.array(decoded)
      np.testing.assert_array_equal(result, expected_pairs)
  def test_sliding_window_number():
      """
      Test SlidingWindow with width 1 along axis -1 on a single-element numeric column
      """
      def single_row(values):
          # Yields exactly one row: a 1-tuple holding the values as a NumPy array.
          yield (np.array(values),)

      expected = np.array([[1]])

      source = ds.GeneratorDataset(single_row([1]), column_names=["number"])
      windowed = source.map(operations=text.SlidingWindow(1, -1), input_columns=["number"])

      rows = windowed.create_dict_iterator(num_epochs=1, output_numpy=True)
      for row in rows:
          np.testing.assert_array_equal(row['number'], expected)
  def test_sliding_window_big_width():
      """
      Test SlidingWindow with a width (30) larger than the 5-element input row;
      the result is compared against an empty array
      """
      source = ds.NumpySlicesDataset([[1, 2, 3, 4, 5]], column_names=["number"], shuffle=False)
      windowed = source.map(operations=text.SlidingWindow(30, 0), input_columns=["number"])

      empty = np.array([])
      rows = windowed.create_dict_iterator(num_epochs=1, output_numpy=True)
      for row in rows:
          np.testing.assert_array_equal(row['number'], empty)
  def test_sliding_window_exception():
      """
      Test that SlidingWindow rejects invalid arguments and unsupported inputs

      Uses pytest.raises rather than try / assert False / except, so the
      "nothing was raised" case still fails when Python runs with -O
      (which strips assert statements).
      """
      # Invalid constructor arguments: a width of 0 is expected to raise
      # ValueError, non-integer width or axis is expected to raise TypeError.
      with pytest.raises(ValueError):
          _ = text.SlidingWindow(0, 0)
      with pytest.raises(TypeError):
          _ = text.SlidingWindow("1", 0)
      with pytest.raises(TypeError):
          _ = text.SlidingWindow(1, "0")

      # An axis of -100 is expected to fail with RuntimeError once the pipeline is
      # built and iterated; the whole pipeline stays inside the context manager so
      # the error is caught wherever it surfaces.
      with pytest.raises(RuntimeError) as info:
          inputs = [[1, 2, 3, 4, 5]]
          dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
          dataset = dataset.map(operations=text.SlidingWindow(3, -100), input_columns=["text"])
          for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
              pass
      assert "axis supports 0 or -1 only for now." in str(info.value)

      # A flat list of strings given to NumpySlicesDataset is expected to be
      # rejected with the "1D input only" error.
      with pytest.raises(RuntimeError) as info:
          inputs = ["aa", "bb", "cc"]
          dataset = ds.NumpySlicesDataset(inputs, column_names=["text"], shuffle=False)
          dataset = dataset.map(operations=text.SlidingWindow(2, 0), input_columns=["text"])
          for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
              pass
      assert "SlidingWindow supports 1D input only for now." in str(info.value)
  if __name__ == '__main__':
      # Run every test in this module, in definition order, when executed as a script.
      for _test_case in (
              test_sliding_window_callable,
              test_sliding_window_string,
              test_sliding_window_number,
              test_sliding_window_big_width,
              test_sliding_window_exception,
      ):
          _test_case()