You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_concat.py 11 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.common.dtype as mstype
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.transforms.c_transforms as C
  19. import mindspore.dataset.transforms.py_transforms
  20. import mindspore.dataset.vision.py_transforms as F
  21. from mindspore import log as logger
  22. # In generator dataset: Number of rows is 3; its values are 0, 1, 2
  23. def generator():
  24. for i in range(3):
  25. yield (np.array([i]),)
  26. # In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9
  27. def generator_10():
  28. for i in range(3, 10):
  29. yield (np.array([i]),)
  30. # In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
  31. def generator_20():
  32. for i in range(10, 20):
  33. yield (np.array([i]),)
  34. def test_concat_01():
  35. """
  36. Test concat: test concat 2 datasets that have the same column name and data type
  37. """
  38. logger.info("test_concat_01")
  39. data1 = ds.GeneratorDataset(generator, ["col1"])
  40. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  41. data3 = data1 + data2
  42. # Here i refers to index, d refers to data element
  43. for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
  44. t = d
  45. logger.info("data: %i", t[0][0])
  46. assert i == t[0][0]
  47. assert sum([1 for _ in data3]) == 10
  48. def test_concat_02():
  49. """
  50. Test concat: test concat 2 datasets using concat operation not "+" operation
  51. """
  52. logger.info("test_concat_02")
  53. data1 = ds.GeneratorDataset(generator, ["col1"])
  54. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  55. data3 = data1.concat(data2)
  56. # Here i refers to index, d refers to data element
  57. for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
  58. t = d
  59. logger.info("data: %i", t[0][0])
  60. assert i == t[0][0]
  61. assert sum([1 for _ in data3]) == 10
  62. def test_concat_03():
  63. """
  64. Test concat: test concat dataset that has different column
  65. """
  66. logger.info("test_concat_03")
  67. data1 = ds.GeneratorDataset(generator, ["col1"])
  68. data2 = ds.GeneratorDataset(generator_10, ["col2"])
  69. data3 = data1 + data2
  70. try:
  71. for _, _ in enumerate(data3):
  72. pass
  73. assert False
  74. except RuntimeError:
  75. pass
  76. def test_concat_04():
  77. """
  78. Test concat: test concat dataset that has different rank
  79. """
  80. logger.info("test_concat_04")
  81. data1 = ds.GeneratorDataset(generator, ["col1"])
  82. data2 = ds.GeneratorDataset(generator_10, ["col2"])
  83. data2 = data2.batch(3)
  84. data3 = data1 + data2
  85. try:
  86. for _, _ in enumerate(data3):
  87. pass
  88. assert False
  89. except RuntimeError:
  90. pass
  91. def test_concat_05():
  92. """
  93. Test concat: test concat dataset that has different data type
  94. """
  95. logger.info("test_concat_05")
  96. data1 = ds.GeneratorDataset(generator, ["col1"])
  97. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  98. type_cast_op = C.TypeCast(mstype.float32)
  99. data1 = data1.map(operations=type_cast_op, input_columns=["col1"])
  100. data3 = data1 + data2
  101. try:
  102. for _, _ in enumerate(data3):
  103. pass
  104. assert False
  105. except RuntimeError:
  106. pass
  107. def test_concat_06():
  108. """
  109. Test concat: test concat multi datasets in one time
  110. """
  111. logger.info("test_concat_06")
  112. data1 = ds.GeneratorDataset(generator, ["col1"])
  113. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  114. data3 = ds.GeneratorDataset(generator_20, ["col1"])
  115. dataset = data1 + data2 + data3
  116. # Here i refers to index, d refers to data element
  117. for i, d in enumerate(dataset.create_tuple_iterator(output_numpy=True)):
  118. t = d
  119. logger.info("data: %i", t[0][0])
  120. assert i == t[0][0]
  121. assert sum([1 for _ in dataset]) == 20
  122. def test_concat_07():
  123. """
  124. Test concat: test concat one dataset with multi datasets (datasets list)
  125. """
  126. logger.info("test_concat_07")
  127. data1 = ds.GeneratorDataset(generator, ["col1"])
  128. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  129. data3 = ds.GeneratorDataset(generator_20, ["col1"])
  130. dataset = [data2] + [data3]
  131. data4 = data1 + dataset
  132. # Here i refers to index, d refers to data element
  133. for i, d in enumerate(data4.create_tuple_iterator(output_numpy=True)):
  134. t = d
  135. logger.info("data: %i", t[0][0])
  136. assert i == t[0][0]
  137. assert sum([1 for _ in data4]) == 20
  138. def test_concat_08():
  139. """
  140. Test concat: test concat 2 datasets, and then repeat
  141. """
  142. logger.info("test_concat_08")
  143. data1 = ds.GeneratorDataset(generator, ["col1"])
  144. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  145. data3 = data1 + data2
  146. data3 = data3.repeat(2)
  147. # Here i refers to index, d refers to data element
  148. for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
  149. t = d
  150. logger.info("data: %i", t[0][0])
  151. assert i % 10 == t[0][0]
  152. assert sum([1 for _ in data3]) == 20
  153. def test_concat_09():
  154. """
  155. Test concat: test concat 2 datasets, both of them have been repeat before
  156. """
  157. logger.info("test_concat_09")
  158. data1 = ds.GeneratorDataset(generator, ["col1"])
  159. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  160. data1 = data1.repeat(2)
  161. data2 = data2.repeat(2)
  162. data3 = data1 + data2
  163. res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]
  164. # Here i refers to index, d refers to data element
  165. for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
  166. t = d
  167. logger.info("data: %i", t[0][0])
  168. assert res[i] == t[0][0]
  169. assert sum([1 for _ in data3]) == 20
  170. def test_concat_10():
  171. """
  172. Test concat: test concat 2 datasets, one of them have repeat before
  173. """
  174. logger.info("test_concat_10")
  175. data1 = ds.GeneratorDataset(generator, ["col1"])
  176. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  177. data1 = data1.repeat(2)
  178. data3 = data1 + data2
  179. res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  180. # Here i refers to index, d refers to data element
  181. for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
  182. t = d
  183. logger.info("data: %i", t[0][0])
  184. assert res[i] == t[0][0]
  185. assert sum([1 for _ in data3]) == 13
  186. def test_concat_11():
  187. """
  188. Test concat: test dataset batch then concat
  189. """
  190. logger.info("test_concat_11")
  191. data1 = ds.GeneratorDataset(generator, ["col1"])
  192. data2 = ds.GeneratorDataset(generator_20, ["col1"])
  193. data1 = data1.batch(3)
  194. data2 = data2.batch(5)
  195. data3 = data1 + data2
  196. res = [0, 10, 15, 20]
  197. # Here i refers to index, d refers to data element
  198. for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
  199. t = d
  200. logger.info("data: %i", t[0][0])
  201. assert res[i] == t[0][0]
  202. assert sum([1 for _ in data3]) == 3
  203. def test_concat_12():
  204. """
  205. Test concat: test dataset concat then shuffle
  206. """
  207. logger.info("test_concat_12")
  208. data1 = ds.GeneratorDataset(generator, ["col1"])
  209. data2 = ds.GeneratorDataset(generator_10, ["col1"])
  210. data3 = data1 + data2
  211. res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]
  212. ds.config.set_seed(1)
  213. assert data3.get_dataset_size() == 10
  214. data3 = data3.shuffle(buffer_size=10)
  215. # Here i refers to index, d refers to data element
  216. for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
  217. t = d
  218. logger.info("data: %i", t[0][0])
  219. assert res[i] == t[0][0]
  220. assert sum([1 for _ in data3]) == 10
  221. def test_concat_13():
  222. """
  223. Test concat: test dataset batch then shuffle and concat
  224. """
  225. logger.info("test_concat_13")
  226. data1 = ds.GeneratorDataset(generator, ["col1"])
  227. data2 = ds.GeneratorDataset(generator_20, ["col1"])
  228. data1 = data1.batch(3)
  229. data2 = data2.batch(5)
  230. data3 = data1 + data2
  231. res = [15, 0, 10]
  232. ds.config.set_seed(1)
  233. assert data3.get_dataset_size() == 3
  234. data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))
  235. # Here i refers to index, d refers to data element
  236. for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
  237. t = d
  238. logger.info("data: %i", t[0][0])
  239. assert res[i] == t[0][0]
  240. assert sum([1 for _ in data3]) == 3
  241. def test_concat_14():
  242. """
  243. Test concat: create dataset with different dataset folder, and do diffrent operation then concat
  244. """
  245. logger.info("test_concat_14")
  246. DATA_DIR = "../data/dataset/testPK/data"
  247. DATA_DIR2 = "../data/dataset/testImageNetData/train/"
  248. data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=3)
  249. data2 = ds.ImageFolderDataset(DATA_DIR2, num_samples=2)
  250. transforms1 = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
  251. F.Resize((224, 224)),
  252. F.ToTensor()])
  253. data1 = data1.map(operations=transforms1, input_columns=["image"])
  254. data2 = data2.map(operations=transforms1, input_columns=["image"])
  255. data3 = data1 + data2
  256. expected, output = [], []
  257. for d in data1.create_tuple_iterator(output_numpy=True):
  258. expected.append(d[0])
  259. for d in data2.create_tuple_iterator(output_numpy=True):
  260. expected.append(d[0])
  261. for d in data3.create_tuple_iterator(output_numpy=True):
  262. output.append(d[0])
  263. assert len(expected) == len(output)
  264. np.array_equal(np.array(output), np.array(expected))
  265. assert sum([1 for _ in data3]) == 5
  266. assert data3.get_dataset_size() == 5
  267. def test_concat_15():
  268. """
  269. Test concat: create dataset with different format of dataset file, and then concat
  270. """
  271. logger.info("test_concat_15")
  272. DATA_DIR = "../data/dataset/testPK/data"
  273. DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  274. data1 = ds.ImageFolderDataset(DATA_DIR)
  275. data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])
  276. data1 = data1.project(["image"])
  277. data3 = data1 + data2
  278. assert sum([1 for _ in data3]) == 47
  279. if __name__ == "__main__":
  280. test_concat_01()
  281. test_concat_02()
  282. test_concat_03()
  283. test_concat_04()
  284. test_concat_05()
  285. test_concat_06()
  286. test_concat_07()
  287. test_concat_08()
  288. test_concat_09()
  289. test_concat_10()
  290. test_concat_11()
  291. test_concat_12()
  292. test_concat_13()
  293. test_concat_14()
  294. test_concat_15()