# test_zip.py (9.1 kB) — repository-viewer chrome (topic-rule text, file
# size, and a fused run of gutter line numbers) removed from this scrape.
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. from util import save_and_check_dict
  16. from mindspore import log as logger
  17. import mindspore.dataset as ds
# Test fixtures: TFRecord data files plus their JSON schemas. Row/column
# counts below are per the comments that accompanied each fixture.
# Dataset in DIR_1 has 5 rows and 5 columns
DATA_DIR_1 = ["../data/dataset/testTFBert5Rows1/5TFDatas.data"]
SCHEMA_DIR_1 = "../data/dataset/testTFBert5Rows1/datasetSchema.json"
# Dataset in DIR_2 has 5 rows and 2 columns
DATA_DIR_2 = ["../data/dataset/testTFBert5Rows2/5TFDatas.data"]
SCHEMA_DIR_2 = "../data/dataset/testTFBert5Rows2/datasetSchema.json"
# Dataset in DIR_3 has 3 rows and 2 columns
DATA_DIR_3 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_3 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
# Dataset in DIR_4 has 5 rows and 7 columns
DATA_DIR_4 = ["../data/dataset/testTFBert5Rows/5TFDatas.data"]
SCHEMA_DIR_4 = "../data/dataset/testTFBert5Rows/datasetSchema.json"
# When True, save_and_check_dict regenerates the golden .npz files
# instead of comparing against them.
GENERATE_GOLDEN = False
  31. def test_zip_01():
  32. """
  33. Test zip: zip 2 datasets, #rows-data1 == #rows-data2, #cols-data1 < #cols-data2
  34. """
  35. logger.info("test_zip_01")
  36. ds.config.set_seed(1)
  37. data1 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
  38. data2 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  39. dataz = ds.zip((data1, data2))
  40. # Note: zipped dataset has 5 rows and 7 columns
  41. filename = "zip_01_result.npz"
  42. parameters = {"params": {}}
  43. save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
  44. def test_zip_02():
  45. """
  46. Test zip: zip 2 datasets, #rows-data1 < #rows-data2, #cols-data1 == #cols-data2
  47. """
  48. logger.info("test_zip_02")
  49. ds.config.set_seed(1)
  50. data1 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
  51. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
  52. dataz = ds.zip((data1, data2))
  53. # Note: zipped dataset has 3 rows and 4 columns
  54. filename = "zip_02_result.npz"
  55. parameters = {"params": {}}
  56. save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
  57. def test_zip_03():
  58. """
  59. Test zip: zip 2 datasets, #rows-data1 > #rows-data2, #cols-data1 > #cols-data2
  60. """
  61. logger.info("test_zip_03")
  62. ds.config.set_seed(1)
  63. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  64. data2 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
  65. dataz = ds.zip((data1, data2))
  66. # Note: zipped dataset has 3 rows and 7 columns
  67. filename = "zip_03_result.npz"
  68. parameters = {"params": {}}
  69. save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
  70. def test_zip_04():
  71. """
  72. Test zip: zip >2 datasets
  73. """
  74. logger.info("test_zip_04")
  75. ds.config.set_seed(1)
  76. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  77. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
  78. data3 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
  79. dataz = ds.zip((data1, data2, data3))
  80. # Note: zipped dataset has 3 rows and 9 columns
  81. filename = "zip_04_result.npz"
  82. parameters = {"params": {}}
  83. save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
  84. def test_zip_05():
  85. """
  86. Test zip: zip dataset with renamed columns
  87. """
  88. logger.info("test_zip_05")
  89. ds.config.set_seed(1)
  90. data1 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4, shuffle=True)
  91. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=True)
  92. data2 = data2.rename(input_columns="input_ids", output_columns="new_input_ids")
  93. data2 = data2.rename(input_columns="segment_ids", output_columns="new_segment_ids")
  94. dataz = ds.zip((data1, data2))
  95. # Note: zipped dataset has 5 rows and 9 columns
  96. filename = "zip_05_result.npz"
  97. parameters = {"params": {}}
  98. save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
  99. def test_zip_06():
  100. """
  101. Test zip: zip dataset with renamed columns and repeat zipped dataset
  102. """
  103. logger.info("test_zip_06")
  104. ds.config.set_seed(1)
  105. data1 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4, shuffle=False)
  106. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=False)
  107. data2 = data2.rename(input_columns="input_ids", output_columns="new_input_ids")
  108. data2 = data2.rename(input_columns="segment_ids", output_columns="new_segment_ids")
  109. dataz = ds.zip((data1, data2))
  110. dataz = dataz.repeat(2)
  111. # Note: resultant dataset has 10 rows and 9 columns
  112. filename = "zip_06_result.npz"
  113. parameters = {"params": {}}
  114. save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
  115. def test_zip_exception_01():
  116. """
  117. Test zip: zip same datasets
  118. """
  119. logger.info("test_zip_exception_01")
  120. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  121. try:
  122. dataz = ds.zip((data1, data1))
  123. num_iter = 0
  124. for _, item in enumerate(dataz.create_dict_iterator()):
  125. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  126. num_iter += 1
  127. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  128. except BaseException as e:
  129. logger.info("Got an exception in DE: {}".format(str(e)))
  130. def skip_test_zip_exception_02():
  131. """
  132. Test zip: zip datasets with duplicate column name
  133. """
  134. logger.info("test_zip_exception_02")
  135. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  136. data2 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4)
  137. try:
  138. dataz = ds.zip((data1, data2))
  139. num_iter = 0
  140. for _, item in enumerate(dataz.create_dict_iterator()):
  141. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  142. num_iter += 1
  143. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  144. except BaseException as e:
  145. logger.info("Got an exception in DE: {}".format(str(e)))
  146. def test_zip_exception_03():
  147. """
  148. Test zip: zip with tuple of 1 dataset
  149. """
  150. logger.info("test_zip_exception_03")
  151. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  152. try:
  153. dataz = ds.zip((data1))
  154. dataz = dataz.repeat(2)
  155. num_iter = 0
  156. for _, item in enumerate(dataz.create_dict_iterator()):
  157. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  158. num_iter += 1
  159. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  160. except BaseException as e:
  161. logger.info("Got an exception in DE: {}".format(str(e)))
  162. def test_zip_exception_04():
  163. """
  164. Test zip: zip with empty tuple of datasets
  165. """
  166. logger.info("test_zip_exception_04")
  167. try:
  168. dataz = ds.zip(())
  169. dataz = dataz.repeat(2)
  170. num_iter = 0
  171. for _, item in enumerate(dataz.create_dict_iterator()):
  172. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  173. num_iter += 1
  174. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  175. except BaseException as e:
  176. logger.info("Got an exception in DE: {}".format(str(e)))
  177. def test_zip_exception_05():
  178. """
  179. Test zip: zip with non-tuple of 2 datasets
  180. """
  181. logger.info("test_zip_exception_05")
  182. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  183. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
  184. try:
  185. dataz = ds.zip(data1, data2)
  186. num_iter = 0
  187. for _, item in enumerate(dataz.create_dict_iterator()):
  188. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  189. num_iter += 1
  190. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  191. except BaseException as e:
  192. logger.info("Got an exception in DE: {}".format(str(e)))
  193. def test_zip_exception_06():
  194. """
  195. Test zip: zip with non-tuple of 1 dataset
  196. """
  197. logger.info("test_zip_exception_06")
  198. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  199. try:
  200. dataz = ds.zip(data1)
  201. num_iter = 0
  202. for _, item in enumerate(dataz.create_dict_iterator()):
  203. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  204. num_iter += 1
  205. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  206. except BaseException as e:
  207. logger.info("Got an exception in DE: {}".format(str(e)))
  208. if __name__ == '__main__':
  209. test_zip_01()
  210. test_zip_02()
  211. test_zip_03()
  212. test_zip_04()
  213. test_zip_05()
  214. test_zip_06()
  215. test_zip_exception_01()
  216. test_zip_exception_02()
  217. test_zip_exception_03()
  218. test_zip_exception_04()
  219. test_zip_exception_05()
  220. test_zip_exception_06()