# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as F
from mindspore import log as logger


# In generator dataset: Number of rows is 3; its values are 0, 1, 2
def generator():
    for i in range(3):
        yield (np.array([i]),)


# In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9
def generator_10():
    for i in range(3, 10):
        yield (np.array([i]),)


# In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19
def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)


# In generator_29 dataset: Number of rows is 9; its values are 20, 21, 22 ... 28
def generator_29():
    for i in range(20, 29):
        yield (np.array([i]),)


def test_concat_01():
    """
    Test concat: test concat of 2 datasets that have the same column name and data type
    """
    logger.info("test_concat_01")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data1 + data2

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i == t[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_02():
    """
    Test concat: test concat of 2 datasets using the .concat() operation instead of the "+" operator
    """
    logger.info("test_concat_02")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data1.concat(data2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i == t[0][0]

    assert sum([1 for _ in data3]) == 10


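# Note: concat requires every input dataset to share column names, tensor rank,
# and data type. The mismatch is only detected when the pipeline actually runs,
# which is why test_concat_03/04/05 expect RuntimeError during iteration rather
# than at construction time.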
def test_concat_03():
    """
    Test concat: test concat of datasets that have different column names
    """
    logger.info("test_concat_03")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])
    data3 = data1 + data2

    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_04():
    """
    Test concat: test concat of datasets that have different ranks
    """
    logger.info("test_concat_04")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])
    data2 = data2.batch(3)
    data3 = data1 + data2

    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_05():
    """
    Test concat: test concat of datasets that have different data types
    """
    logger.info("test_concat_05")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    type_cast_op = C.TypeCast(mstype.float32)
    data1 = data1.map(operations=type_cast_op, input_columns=["col1"])
    data3 = data1 + data2

    try:
        for _, _ in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_06():
    """
    Test concat: test concat of multiple datasets at once
    """
    logger.info("test_concat_06")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])
    dataset = data1 + data2 + data3

    # Here i refers to index, d refers to data element
    for i, d in enumerate(dataset.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i == t[0][0]

    assert sum([1 for _ in dataset]) == 20


def test_concat_07():
    """
    Test concat: test concat of one dataset with a list of datasets
    """
    logger.info("test_concat_07")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])
    dataset = [data2] + [data3]
    data4 = data1 + dataset

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data4.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i == t[0][0]

    assert sum([1 for _ in data4]) == 20


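# Note: test_concat_08/09/10 exercise the ordering of repeat and concat.
# Repeating after concat replays the full 0..9 sequence, while repeating the
# inputs first duplicates each input's rows before they are joined (compare
# the hard-coded `res` lists below).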
def test_concat_08():
    """
    Test concat: test concat of 2 datasets, then repeat
    """
    logger.info("test_concat_08")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2
    data3 = data3.repeat(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert i % 10 == t[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_09():
    """
    Test concat: test concat of 2 datasets, both of which have been repeated beforehand
    """
    logger.info("test_concat_09")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data2 = data2.repeat(2)
    data3 = data1 + data2
    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_10():
    """
    Test concat: test concat of 2 datasets, one of which has been repeated beforehand
    """
    logger.info("test_concat_10")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data3 = data1 + data2
    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 13


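# Note: batching before concat makes each row of the result a whole batch, so
# `res` below holds the first element of each expected batch: [0, 1, 2] from
# data1, then [10..14] and [15..19] from data2.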
def test_concat_11():
    """
    Test concat: test dataset batch, then concat
    """
    logger.info("test_concat_11")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)
    data3 = data1 + data2
    res = [0, 10, 15]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 3


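# Note: ds.config.set_seed() makes the shuffle order deterministic, which is
# why test_concat_12/13 can hard-code the expected post-shuffle sequence in
# `res`.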
def test_concat_12():
    """
    Test concat: test dataset concat, then shuffle
    """
    logger.info("test_concat_12")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2
    res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 10
    data3 = data3.shuffle(buffer_size=10)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_13():
    """
    Test concat: test dataset batch, then concat, then shuffle
    """
    logger.info("test_concat_13")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)
    data3 = data1 + data2
    res = [15, 0, 10]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 3
    data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)):
        t = d
        logger.info("data: %i", t[0][0])
        assert res[i] == t[0][0]

    assert sum([1 for _ in data3]) == 3


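# Note: the remaining tests read real datasets from disk via relative paths
# (../data/dataset/...), so they assume the working directory used by the
# MindSpore dataset test suite.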
def test_concat_14():
    """
    Test concat: test concat on two different source datasets with different dataset operations
    """
    logger.info("test_concat_14")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = "../data/dataset/testImageNetData/train/"

    data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=3)
    data2 = ds.ImageFolderDataset(DATA_DIR2, num_samples=2)

    transforms1 = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                      F.Resize((224, 224)),
                                                                      F.ToTensor()])

    data1 = data1.map(operations=transforms1, input_columns=["image"])
    data2 = data2.map(operations=transforms1, input_columns=["image"])
    data3 = data1 + data2

    # Concat output must match iterating data1 and then data2 sequentially
    expected, output = [], []
    for d in data1.create_tuple_iterator(output_numpy=True):
        expected.append(d[0])
    for d in data2.create_tuple_iterator(output_numpy=True):
        expected.append(d[0])
    for d in data3.create_tuple_iterator(output_numpy=True):
        output.append(d[0])

    assert len(expected) == len(output)
    assert np.array_equal(np.array(output), np.array(expected))

    assert sum([1 for _ in data3]) == 5
    assert data3.get_dataset_size() == 5


def test_concat_15():
    """
    Test concat: create datasets from different file formats, then concat
    """
    logger.info("test_concat_15")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDataset(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])

    data1 = data1.project(["image"])
    data3 = data1 + data2

    assert sum([1 for _ in data3]) == 47


def test_concat_16():
    """
    Test concat: test get_dataset_size on nested concats
    """
    logger.info("test_concat_16")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDataset(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])
    data3 = ds.GeneratorDataset(generator, ["col1"])
    data4 = ds.GeneratorDataset(generator_10, ["col1"])

    data5 = data1 + data2
    data6 = data3 + data4
    data7 = data5 + data6

    ds.config.set_seed(1)

    # 57 is the total size of all 4 leaf datasets
    assert data7.get_dataset_size() == 57


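# Note: test_concat_17 shards a nested concat with DistributedSampler. Each of
# the 10 shards yields exactly get_dataset_size() rows, and with shuffle=False
# the shards together cover all 29 rows exactly once.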
def test_concat_17():
    """
    Test concat: test get_dataset_size on nested concats (with sampler)
    """
    logger.info("test_concat_17")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])
    data4 = ds.GeneratorDataset(generator_29, ["col1"])

    data5 = data1 + data2
    data6 = data3 + data4
    data7 = data5 + data6

    ds.config.set_seed(1)
    shard_num = 10
    counter = 0

    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data7.use_sampler(distributed_sampler)

        iter_counter = 0
        for _ in data7.create_dict_iterator(num_epochs=1, output_numpy=True):
            counter += 1
            iter_counter += 1
        assert data7.get_dataset_size() == iter_counter

    # 29 is the total size of all 4 leaf datasets
    assert counter == 29


if __name__ == "__main__":
    test_concat_01()
    test_concat_02()
    test_concat_03()
    test_concat_04()
    test_concat_05()
    test_concat_06()
    test_concat_07()
    test_concat_08()
    test_concat_09()
    test_concat_10()
    test_concat_11()
    test_concat_12()
    test_concat_13()
    test_concat_14()
    test_concat_15()
    test_concat_16()
    test_concat_17()