You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset_numpy_slices.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import sys
  16. import pytest
  17. import numpy as np
  18. import pandas as pd
  19. import mindspore.dataset as de
  20. from mindspore import log as logger
  21. import mindspore.dataset.vision.c_transforms as vision
  22. def test_numpy_slices_list_1():
  23. logger.info("Test Slicing a 1D list.")
  24. np_data = [1, 2, 3]
  25. ds = de.NumpySlicesDataset(np_data, shuffle=False)
  26. for i, data in enumerate(ds):
  27. assert data[0].asnumpy() == np_data[i]
  28. def test_numpy_slices_list_2():
  29. logger.info("Test Slicing a 2D list into 1D list.")
  30. np_data = [[1, 2], [3, 4]]
  31. ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
  32. for i, data in enumerate(ds):
  33. assert np.equal(data[0].asnumpy(), np_data[i]).all()
  34. def test_numpy_slices_list_3():
  35. logger.info("Test Slicing list in the first dimension.")
  36. np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
  37. ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
  38. for i, data in enumerate(ds):
  39. assert np.equal(data[0].asnumpy(), np_data[i]).all()
  40. def test_numpy_slices_numpy():
  41. logger.info("Test NumPy structure data.")
  42. np_data = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]]])
  43. ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
  44. for i, data in enumerate(ds):
  45. assert np.equal(data[0].asnumpy(), np_data[i]).all()
  46. def test_numpy_slices_list_append():
  47. logger.info("Test reading data of image list.")
  48. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  49. resize_height, resize_width = 2, 2
  50. data1 = de.TFRecordDataset(DATA_DIR)
  51. resize_op = vision.Resize((resize_height, resize_width))
  52. data1 = data1.map(operations=[vision.Decode(True), resize_op], input_columns=["image"])
  53. res = []
  54. for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  55. res.append(data["image"])
  56. ds = de.NumpySlicesDataset(res, column_names=["col1"], shuffle=False)
  57. for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
  58. assert np.equal(data, res[i]).all()
  59. def test_numpy_slices_dict_1():
  60. logger.info("Test Dictionary structure data.")
  61. np_data = {"a": [1, 2], "b": [3, 4]}
  62. ds = de.NumpySlicesDataset(np_data, shuffle=False)
  63. res = [[1, 3], [2, 4]]
  64. for i, data in enumerate(ds):
  65. assert data[0].asnumpy() == res[i][0]
  66. assert data[1].asnumpy() == res[i][1]
  67. def test_numpy_slices_tuple_1():
  68. logger.info("Test slicing a list of tuple.")
  69. np_data = [([1, 2], [3, 4]), ([11, 12], [13, 14]), ([21, 22], [23, 24])]
  70. ds = de.NumpySlicesDataset(np_data, shuffle=False)
  71. for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
  72. assert np.equal(data, np_data[i]).all()
  73. assert sum([1 for _ in ds]) == 3
  74. def test_numpy_slices_tuple_2():
  75. logger.info("Test slicing a tuple of list.")
  76. np_data = ([1, 2], [3, 4], [5, 6])
  77. expected = [[1, 3, 5], [2, 4, 6]]
  78. ds = de.NumpySlicesDataset(np_data, shuffle=False)
  79. for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
  80. assert np.equal(data, expected[i]).all()
  81. assert sum([1 for _ in ds]) == 2
  82. def test_numpy_slices_tuple_3():
  83. logger.info("Test reading different dimension of tuple data.")
  84. features, labels = np.random.sample((5, 2)), np.random.sample((5, 1))
  85. data = (features, labels)
  86. ds = de.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
  87. for i, data in enumerate(ds):
  88. assert np.equal(data[0].asnumpy(), features[i]).all()
  89. assert data[1].asnumpy() == labels[i]
  90. def test_numpy_slices_csv_value():
  91. logger.info("Test loading value of csv file.")
  92. csv_file = "../data/dataset/testNumpySlicesDataset/heart.csv"
  93. df = pd.read_csv(csv_file)
  94. target = df.pop("target")
  95. df.pop("state")
  96. np_data = (df.values, target.values)
  97. ds = de.NumpySlicesDataset(np_data, column_names=["col1", "col2"], shuffle=False)
  98. for i, data in enumerate(ds):
  99. assert np.equal(np_data[0][i], data[0].asnumpy()).all()
  100. assert np.equal(np_data[1][i], data[1].asnumpy()).all()
  101. def test_numpy_slices_csv_dict():
  102. logger.info("Test loading csv file as dict.")
  103. csv_file = "../data/dataset/testNumpySlicesDataset/heart.csv"
  104. df = pd.read_csv(csv_file)
  105. df.pop("state")
  106. res = df.values
  107. ds = de.NumpySlicesDataset(dict(df), shuffle=False)
  108. for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True)):
  109. assert np.equal(data, res[i]).all()
  110. def test_numpy_slices_num_samplers():
  111. logger.info("Test num_samplers.")
  112. np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
  113. ds = de.NumpySlicesDataset(np_data, shuffle=False, num_samples=2)
  114. for i, data in enumerate(ds):
  115. assert np.equal(data[0].asnumpy(), np_data[i]).all()
  116. assert sum([1 for _ in ds]) == 2
  117. def test_numpy_slices_distributed_sampler():
  118. logger.info("Test distributed sampler.")
  119. np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
  120. ds = de.NumpySlicesDataset(np_data, shuffle=False, shard_id=0, num_shards=4)
  121. for i, data in enumerate(ds):
  122. assert np.equal(data[0].asnumpy(), np_data[i * 4]).all()
  123. assert sum([1 for _ in ds]) == 2
  124. def test_numpy_slices_distributed_shard_limit():
  125. logger.info("Test Slicing a 1D list.")
  126. np_data = [1, 2, 3]
  127. num = sys.maxsize
  128. with pytest.raises(ValueError) as err:
  129. de.NumpySlicesDataset(np_data, num_shards=num, shard_id=0, shuffle=False)
  130. assert "Input num_shards is not within the required interval of [1, 2147483647]." in str(err.value)
  131. def test_numpy_slices_distributed_zero_shard():
  132. logger.info("Test Slicing a 1D list.")
  133. np_data = [1, 2, 3]
  134. with pytest.raises(ValueError) as err:
  135. de.NumpySlicesDataset(np_data, num_shards=0, shard_id=0, shuffle=False)
  136. assert "Input num_shards is not within the required interval of [1, 2147483647]." in str(err.value)
  137. def test_numpy_slices_sequential_sampler():
  138. logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")
  139. np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
  140. ds = de.NumpySlicesDataset(np_data, sampler=de.SequentialSampler()).repeat(2)
  141. for i, data in enumerate(ds):
  142. assert np.equal(data[0].asnumpy(), np_data[i % 8]).all()
  143. def test_numpy_slices_invalid_column_names_type():
  144. logger.info("Test incorrect column_names input")
  145. np_data = [1, 2, 3]
  146. with pytest.raises(TypeError) as err:
  147. de.NumpySlicesDataset(np_data, column_names=[1], shuffle=False)
  148. assert "Argument column_names[0] with value 1 is not of type [<class 'str'>]" in str(err.value)
  149. def test_numpy_slices_invalid_column_names_string():
  150. logger.info("Test incorrect column_names input")
  151. np_data = [1, 2, 3]
  152. with pytest.raises(ValueError) as err:
  153. de.NumpySlicesDataset(np_data, column_names=[""], shuffle=False)
  154. assert "column_names[0] should not be empty" in str(err.value)
  155. def test_numpy_slices_invalid_empty_column_names():
  156. logger.info("Test incorrect column_names input")
  157. np_data = [1, 2, 3]
  158. with pytest.raises(ValueError) as err:
  159. de.NumpySlicesDataset(np_data, column_names=[], shuffle=False)
  160. assert "column_names should not be empty" in str(err.value)
  161. def test_numpy_slices_invalid_empty_data_column():
  162. logger.info("Test incorrect column_names input")
  163. np_data = []
  164. with pytest.raises(ValueError) as err:
  165. de.NumpySlicesDataset(np_data, shuffle=False)
  166. assert "Argument data cannot be empty" in str(err.value)
  167. def test_numpy_slice_empty_output_shape():
  168. logger.info("running test_numpy_slice_empty_output_shape")
  169. dataset = de.NumpySlicesDataset([[[1, 2], [3, 4]]], column_names=["col1"])
  170. dataset = dataset.batch(batch_size=3, drop_remainder=True)
  171. assert dataset.output_shapes() == []
  172. if __name__ == "__main__":
  173. test_numpy_slices_list_1()
  174. test_numpy_slices_list_2()
  175. test_numpy_slices_list_3()
  176. test_numpy_slices_list_append()
  177. test_numpy_slices_dict_1()
  178. test_numpy_slices_tuple_1()
  179. test_numpy_slices_tuple_2()
  180. test_numpy_slices_tuple_3()
  181. test_numpy_slices_csv_value()
  182. test_numpy_slices_csv_dict()
  183. test_numpy_slices_num_samplers()
  184. test_numpy_slices_distributed_sampler()
  185. test_numpy_slices_distributed_shard_limit()
  186. test_numpy_slices_distributed_zero_shard()
  187. test_numpy_slices_sequential_sampler()
  188. test_numpy_slices_invalid_column_names_type()
  189. test_numpy_slices_invalid_column_names_string()
  190. test_numpy_slices_invalid_empty_column_names()
  191. test_numpy_slices_invalid_empty_data_column()
  192. test_numpy_slice_empty_output_shape()