You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_dataset_numpy_slices.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as de
  17. from mindspore import log as logger
  18. import mindspore.dataset.transforms.vision.c_transforms as vision
  19. import pandas as pd
  20. def test_numpy_slices_list_1():
  21. logger.info("Test Slicing a 1D list.")
  22. np_data = [1, 2, 3]
  23. ds = de.NumpySlicesDataset(np_data, shuffle=False)
  24. for i, data in enumerate(ds):
  25. assert data[0] == np_data[i]
  26. def test_numpy_slices_list_2():
  27. logger.info("Test Slicing a 2D list into 1D list.")
  28. np_data = [[1, 2], [3, 4]]
  29. ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
  30. for i, data in enumerate(ds):
  31. assert np.equal(data[0], np_data[i]).all()
  32. def test_numpy_slices_list_3():
  33. logger.info("Test Slicing list in the first dimension.")
  34. np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
  35. ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
  36. for i, data in enumerate(ds):
  37. assert np.equal(data[0], np_data[i]).all()
  38. def test_numpy_slices_list_append():
  39. logger.info("Test reading data of image list.")
  40. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  41. resize_height, resize_width = 2, 2
  42. data1 = de.TFRecordDataset(DATA_DIR)
  43. resize_op = vision.Resize((resize_height, resize_width))
  44. data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True), resize_op])
  45. res = []
  46. for data in data1.create_dict_iterator():
  47. res.append(data["image"])
  48. ds = de.NumpySlicesDataset(res, column_names=["col1"], shuffle=False)
  49. for i, data in enumerate(ds):
  50. assert np.equal(data, res[i]).all()
  51. def test_numpy_slices_dict_1():
  52. logger.info("Test Dictionary structure data.")
  53. np_data = {"a": [1, 2], "b": [3, 4]}
  54. ds = de.NumpySlicesDataset(np_data, shuffle=False)
  55. res = [[1, 3], [2, 4]]
  56. for i, data in enumerate(ds):
  57. assert data[0] == res[i][0]
  58. assert data[1] == res[i][1]
  59. def test_numpy_slices_dict_2():
  60. logger.info("Test input data is a tuple of Dictionary structure data.")
  61. data1, data2 = {"a": [1, 2]}, {"b": [3, 4]}
  62. ds = de.NumpySlicesDataset((data1, data2), column_names=["col1", "col2"], shuffle=False)
  63. res = [[1, 3], [2, 4]]
  64. for i, data in enumerate(ds):
  65. assert data[0] == res[i][0]
  66. assert data[1] == res[i][1]
  67. def test_numpy_slices_tuple_1():
  68. logger.info("Test slicing a list of tuple.")
  69. np_data = [([1, 2], [3, 4]), ([11, 12], [13, 14]), ([21, 22], [23, 24])]
  70. res = [[[1, 2], [11, 12], [21, 22]], [[3, 4], [13, 14], [23, 24]]]
  71. ds = de.NumpySlicesDataset(np_data, shuffle=False)
  72. for i, data in enumerate(ds):
  73. assert np.equal(data[0], res[i][0]).all()
  74. assert np.equal(data[1], res[i][1]).all()
  75. assert np.equal(data[2], res[i][2]).all()
  76. assert sum([1 for _ in ds]) == 2
  77. def test_numpy_slices_tuple_2():
  78. logger.info("Test reading different dimension of tuple data.")
  79. features, labels = np.random.sample((5, 2)), np.random.sample((5, 1))
  80. data = (features, labels)
  81. ds = de.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
  82. for i, data in enumerate(ds):
  83. assert np.equal(data[0], features[i]).all()
  84. assert data[1] == labels[i]
  85. def test_numpy_slices_csv_value():
  86. logger.info("Test loading value of csv file.")
  87. csv_file = "../data/dataset/testNumpySlicesDataset/heart.csv"
  88. df = pd.read_csv(csv_file)
  89. target = df.pop("target")
  90. df.pop("state")
  91. np_data = (df.values, target.values)
  92. ds = de.NumpySlicesDataset(np_data, column_names=["col1", "col2"], shuffle=False)
  93. for i, data in enumerate(ds):
  94. assert np.equal(np_data[0][i], data[0]).all()
  95. assert np.equal(np_data[1][i], data[1]).all()
  96. def test_numpy_slices_csv_dict():
  97. logger.info("Test loading csv file as dict.")
  98. csv_file = "../data/dataset/testNumpySlicesDataset/heart.csv"
  99. df = pd.read_csv(csv_file)
  100. df.pop("state")
  101. res = df.values
  102. ds = de.NumpySlicesDataset(dict(df), shuffle=False)
  103. for i, data in enumerate(ds):
  104. assert np.equal(data, res[i]).all()
  105. def test_numpy_slices_num_samplers():
  106. logger.info("Test num_samplers.")
  107. np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
  108. ds = de.NumpySlicesDataset(np_data, shuffle=False, num_samples=2)
  109. for i, data in enumerate(ds):
  110. assert np.equal(data[0], np_data[i]).all()
  111. assert sum([1 for _ in ds]) == 2
  112. def test_numpy_slices_distributed_sampler():
  113. logger.info("Test distributed sampler.")
  114. np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
  115. ds = de.NumpySlicesDataset(np_data, shuffle=False, shard_id=0, num_shards=4)
  116. for i, data in enumerate(ds):
  117. assert np.equal(data[0], np_data[i * 4]).all()
  118. assert sum([1 for _ in ds]) == 2
  119. def test_numpy_slices_sequential_sampler():
  120. logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")
  121. np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
  122. ds = de.NumpySlicesDataset(np_data, sampler=de.SequentialSampler()).repeat(2)
  123. for i, data in enumerate(ds):
  124. assert np.equal(data[0], np_data[i % 8]).all()
  125. if __name__ == "__main__":
  126. test_numpy_slices_list_1()
  127. test_numpy_slices_list_2()
  128. test_numpy_slices_list_3()
  129. test_numpy_slices_list_append()
  130. test_numpy_slices_dict_1()
  131. test_numpy_slices_dict_2()
  132. test_numpy_slices_tuple_1()
  133. test_numpy_slices_tuple_2()
  134. test_numpy_slices_csv_value()
  135. test_numpy_slices_csv_dict()
  136. test_numpy_slices_num_samplers()
  137. test_numpy_slices_distributed_sampler()
  138. test_numpy_slices_sequential_sampler()