You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_slice_op.py 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing Slice op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.c_transforms as ops
  22. def slice_compare(array, indexing, expected_array):
  23. data = ds.NumpySlicesDataset([array])
  24. if isinstance(indexing, list) and indexing and not isinstance(indexing[0], int):
  25. data = data.map(operations=ops.Slice(*indexing))
  26. else:
  27. data = data.map(operations=ops.Slice(indexing))
  28. for d in data.create_dict_iterator(output_numpy=True):
  29. np.testing.assert_array_equal(expected_array, d['column_0'])
  30. def test_slice_all():
  31. slice_compare([1, 2, 3, 4, 5], None, [1, 2, 3, 4, 5])
  32. slice_compare([1, 2, 3, 4, 5], ..., [1, 2, 3, 4, 5])
  33. slice_compare([1, 2, 3, 4, 5], True, [1, 2, 3, 4, 5])
  34. def test_slice_single_index():
  35. slice_compare([1, 2, 3, 4, 5], 0, [1])
  36. slice_compare([1, 2, 3, 4, 5], -3, [3])
  37. slice_compare([1, 2, 3, 4, 5], [0], [1])
  38. def test_slice_indices_multidim():
  39. slice_compare([[1, 2, 3, 4, 5]], [[0], [0]], 1)
  40. slice_compare([[1, 2, 3, 4, 5]], [[0], [0, 3]], [[1, 4]])
  41. slice_compare([[1, 2, 3, 4, 5]], [0], [[1, 2, 3, 4, 5]])
  42. slice_compare([[1, 2, 3, 4, 5]], [[0], [0, -4]], [[1, 2]])
  43. def test_slice_list_index():
  44. slice_compare([1, 2, 3, 4, 5], [0, 1, 4], [1, 2, 5])
  45. slice_compare([1, 2, 3, 4, 5], [4, 1, 0], [5, 2, 1])
  46. slice_compare([1, 2, 3, 4, 5], [-1, 1, 0], [5, 2, 1])
  47. slice_compare([1, 2, 3, 4, 5], [-1, -4, -2], [5, 2, 4])
  48. slice_compare([1, 2, 3, 4, 5], [3, 3, 3], [4, 4, 4])
  49. def test_slice_index_and_slice():
  50. slice_compare([[1, 2, 3, 4, 5]], [slice(0, 1), [4]], [[5]])
  51. slice_compare([[1, 2, 3, 4, 5]], [[0], slice(0, 2)], [[1, 2]])
  52. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [[1], slice(2, 4, 1)], [[7, 8]])
  53. def test_slice_slice_obj_1s():
  54. slice_compare([1, 2, 3, 4, 5], slice(1), [1])
  55. slice_compare([1, 2, 3, 4, 5], slice(4), [1, 2, 3, 4])
  56. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(2), slice(2)], [[1, 2], [5, 6]])
  57. slice_compare([1, 2, 3, 4, 5], slice(10), [1, 2, 3, 4, 5])
  58. def test_slice_slice_obj_2s():
  59. slice_compare([1, 2, 3, 4, 5], slice(0, 2), [1, 2])
  60. slice_compare([1, 2, 3, 4, 5], slice(2, 4), [3, 4])
  61. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2), slice(1, 2)], [[2], [6]])
  62. slice_compare([1, 2, 3, 4, 5], slice(4, 10), [5])
  63. def test_slice_slice_obj_2s_multidim():
  64. slice_compare([[1, 2, 3, 4, 5]], [slice(0, 1)], [[1, 2, 3, 4, 5]])
  65. slice_compare([[1, 2, 3, 4, 5]], [slice(0, 1), slice(4)], [[1, 2, 3, 4]])
  66. slice_compare([[1, 2, 3, 4, 5]], [slice(0, 1), slice(0, 3)], [[1, 2, 3]])
  67. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 2), slice(2, 4, 1)], [[3, 4]])
  68. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(1, 0, -1), slice(1)], [[5]])
  69. def test_slice_slice_obj_3s():
  70. """
  71. Test passing in all parameters to the slice objects
  72. """
  73. slice_compare([1, 2, 3, 4, 5], slice(0, 2, 1), [1, 2])
  74. slice_compare([1, 2, 3, 4, 5], slice(0, 4, 1), [1, 2, 3, 4])
  75. slice_compare([1, 2, 3, 4, 5], slice(0, 10, 1), [1, 2, 3, 4, 5])
  76. slice_compare([1, 2, 3, 4, 5], slice(0, 5, 2), [1, 3, 5])
  77. slice_compare([1, 2, 3, 4, 5], slice(0, 2, 2), [1])
  78. slice_compare([1, 2, 3, 4, 5], slice(0, 1, 2), [1])
  79. slice_compare([1, 2, 3, 4, 5], slice(4, 5, 1), [5])
  80. slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3), [3])
  81. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 1)], [[1, 2, 3, 4], [5, 6, 7, 8]])
  82. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 3)], [[1, 2, 3, 4]])
  83. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 2), slice(0, 1, 2)], [[1]])
  84. slice_compare([[1, 2, 3, 4], [5, 6, 7, 8]], [slice(0, 2, 1), slice(0, 1, 2)], [[1], [5]])
  85. slice_compare([[[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3, 4], [5, 6, 7, 8]]],
  86. [slice(0, 2, 1), slice(0, 1, 1), slice(0, 4, 2)],
  87. [[[1, 3]], [[1, 3]]])
  88. def test_slice_obj_3s_double():
  89. slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1), [1., 2.])
  90. slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1), [1., 2., 3., 4.])
  91. slice_compare([1., 2., 3., 4., 5.], slice(0, 5, 2), [1., 3., 5.])
  92. slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 2), [1.])
  93. slice_compare([1., 2., 3., 4., 5.], slice(0, 1, 2), [1.])
  94. slice_compare([1., 2., 3., 4., 5.], slice(4, 5, 1), [5.])
  95. slice_compare([1., 2., 3., 4., 5.], slice(2, 5, 3), [3.])
  96. def test_out_of_bounds_slicing():
  97. """
  98. Test passing indices outside of the input to the slice objects
  99. """
  100. slice_compare([1, 2, 3, 4, 5], slice(-15, -1), [1, 2, 3, 4])
  101. slice_compare([1, 2, 3, 4, 5], slice(-15, 15), [1, 2, 3, 4, 5])
  102. slice_compare([1, 2, 3, 4], slice(-15, -7), [])
  103. def test_slice_multiple_rows():
  104. """
  105. Test passing in multiple rows
  106. """
  107. dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]]
  108. exp_dataset = [[], [4, 5], [2], [2, 3, 4]]
  109. def gen():
  110. for row in dataset:
  111. yield (np.array(row),)
  112. data = ds.GeneratorDataset(gen, column_names=["col"])
  113. indexing = slice(1, 4)
  114. data = data.map(operations=ops.Slice(indexing))
  115. for (d, exp_d) in zip(data.create_dict_iterator(output_numpy=True), exp_dataset):
  116. np.testing.assert_array_equal(exp_d, d['col'])
  117. def test_slice_obj_neg():
  118. slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -1), [5, 4, 3, 2])
  119. slice_compare([1, 2, 3, 4, 5], slice(-1), [1, 2, 3, 4])
  120. slice_compare([1, 2, 3, 4, 5], slice(-2), [1, 2, 3])
  121. slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -2), [5, 3])
  122. slice_compare([1, 2, 3, 4, 5], slice(-5, -1, 2), [1, 3])
  123. slice_compare([1, 2, 3, 4, 5], slice(-5, -1), [1, 2, 3, 4])
  124. def test_slice_all_str():
  125. slice_compare([b"1", b"2", b"3", b"4", b"5"], None, [b"1", b"2", b"3", b"4", b"5"])
  126. slice_compare([b"1", b"2", b"3", b"4", b"5"], ..., [b"1", b"2", b"3", b"4", b"5"])
  127. def test_slice_single_index_str():
  128. slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1], [b"1", b"2"])
  129. slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1], [b"1", b"2"])
  130. slice_compare([b"1", b"2", b"3", b"4", b"5"], [4], [b"5"])
  131. slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1], [b"5"])
  132. slice_compare([b"1", b"2", b"3", b"4", b"5"], [-5], [b"1"])
  133. def test_slice_indexes_multidim_str():
  134. slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [[0], 0], [[b"1"]])
  135. slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [[0], [0, 1]], [[b"1", b"2"]])
  136. def test_slice_list_index_str():
  137. slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1, 4], [b"1", b"2", b"5"])
  138. slice_compare([b"1", b"2", b"3", b"4", b"5"], [4, 1, 0], [b"5", b"2", b"1"])
  139. slice_compare([b"1", b"2", b"3", b"4", b"5"], [3, 3, 3], [b"4", b"4", b"4"])
  140. # test str index object here
  141. def test_slice_index_and_slice_str():
  142. slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1), 4], [[b"5"]])
  143. slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [[0], slice(0, 2)], [[b"1", b"2"]])
  144. slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]], [[1], slice(2, 4, 1)],
  145. [[b"7", b"8"]])
  146. def test_slice_slice_obj_1s_str():
  147. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(1), [b"1"])
  148. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4), [b"1", b"2", b"3", b"4"])
  149. slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
  150. [slice(2), slice(2)],
  151. [[b"1", b"2"], [b"5", b"6"]])
  152. def test_slice_slice_obj_2s_str():
  153. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2), [b"1", b"2"])
  154. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 4), [b"3", b"4"])
  155. slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
  156. [slice(0, 2), slice(1, 2)], [[b"2"], [b"6"]])
  157. def test_slice_slice_obj_2s_multidim_str():
  158. slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1)], [[b"1", b"2", b"3", b"4", b"5"]])
  159. slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1), slice(4)],
  160. [[b"1", b"2", b"3", b"4"]])
  161. slice_compare([[b"1", b"2", b"3", b"4", b"5"]], [slice(0, 1), slice(0, 3)],
  162. [[b"1", b"2", b"3"]])
  163. slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
  164. [slice(0, 2, 2), slice(2, 4, 1)],
  165. [[b"3", b"4"]])
  166. def test_slice_slice_obj_3s_str():
  167. """
  168. Test passing in all parameters to the slice objects
  169. """
  170. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 1), [b"1", b"2"])
  171. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 4, 1), [b"1", b"2", b"3", b"4"])
  172. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 5, 2), [b"1", b"3", b"5"])
  173. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 2), [b"1"])
  174. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 1, 2), [b"1"])
  175. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4, 5, 1), [b"5"])
  176. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 5, 3), [b"3"])
  177. slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]], [slice(0, 2, 1)],
  178. [[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]])
  179. slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]], slice(0, 2, 3), [[b"1", b"2", b"3", b"4"]])
  180. slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
  181. [slice(0, 2, 2), slice(0, 1, 2)], [[b"1"]])
  182. slice_compare([[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
  183. [slice(0, 2, 1), slice(0, 1, 2)],
  184. [[b"1"], [b"5"]])
  185. slice_compare([[[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]],
  186. [[b"1", b"2", b"3", b"4"], [b"5", b"6", b"7", b"8"]]],
  187. [slice(0, 2, 1), slice(0, 1, 1), slice(0, 4, 2)],
  188. [[[b"1", b"3"]], [[b"1", b"3"]]])
  189. def test_slice_obj_neg_str():
  190. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -1), [b"5", b"4", b"3", b"2"])
  191. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1), [b"1", b"2", b"3", b"4"])
  192. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-2), [b"1", b"2", b"3"])
  193. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -2), [b"5", b"3"])
  194. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1, 2), [b"1", b"3"])
  195. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1), [b"1", b"2", b"3", b"4"])
  196. def test_out_of_bounds_slicing_str():
  197. """
  198. Test passing indices outside of the input to the slice objects
  199. """
  200. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-15, -1), [b"1", b"2", b"3", b"4"])
  201. slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-15, 15), [b"1", b"2", b"3", b"4", b"5"])
  202. indexing = slice(-15, -7)
  203. expected_array = np.array([], dtype="S")
  204. data = [b"1", b"2", b"3", b"4", b"5"]
  205. data = ds.NumpySlicesDataset([data])
  206. data = data.map(operations=ops.Slice(indexing))
  207. for d in data.create_dict_iterator(output_numpy=True):
  208. np.testing.assert_array_equal(expected_array, d['column_0'])
  209. def test_slice_exceptions():
  210. """
  211. Test passing in invalid parameters
  212. """
  213. with pytest.raises(RuntimeError) as info:
  214. slice_compare([b"1", b"2", b"3", b"4", b"5"], [5], [b"1", b"2", b"3", b"4", b"5"])
  215. assert "Index 5 is out of bounds." in str(info.value)
  216. with pytest.raises(RuntimeError) as info:
  217. slice_compare([b"1", b"2", b"3", b"4", b"5"], [], [b"1", b"2", b"3", b"4", b"5"])
  218. assert "Both indices and slices can not be empty." in str(info.value)
  219. with pytest.raises(TypeError) as info:
  220. slice_compare([b"1", b"2", b"3", b"4", b"5"], [[[0, 1]]], [b"1", b"2", b"3", b"4", b"5"])
  221. assert "Argument slice_option[0] with value [0, 1] is not of type " \
  222. "(<class 'int'>,)." in str(info.value)
  223. with pytest.raises(TypeError) as info:
  224. slice_compare([b"1", b"2", b"3", b"4", b"5"], [[slice(3)]], [b"1", b"2", b"3", b"4", b"5"])
  225. assert "Argument slice_option[0] with value slice(None, 3, None) is not of type " \
  226. "(<class 'int'>,)." in str(info.value)
  227. if __name__ == "__main__":
  228. test_slice_all()
  229. test_slice_single_index()
  230. test_slice_indices_multidim()
  231. test_slice_list_index()
  232. test_slice_index_and_slice()
  233. test_slice_slice_obj_1s()
  234. test_slice_slice_obj_2s()
  235. test_slice_slice_obj_2s_multidim()
  236. test_slice_slice_obj_3s()
  237. test_slice_obj_3s_double()
  238. test_slice_multiple_rows()
  239. test_slice_obj_neg()
  240. test_slice_all_str()
  241. test_slice_single_index_str()
  242. test_slice_indexes_multidim_str()
  243. test_slice_list_index_str()
  244. test_slice_index_and_slice_str()
  245. test_slice_slice_obj_1s_str()
  246. test_slice_slice_obj_2s_str()
  247. test_slice_slice_obj_2s_multidim_str()
  248. test_slice_slice_obj_3s_str()
  249. test_slice_obj_neg_str()
  250. test_out_of_bounds_slicing_str()
  251. test_slice_exceptions()