You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_zip.py 8.9 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. from util import save_and_check_dict, save_and_check_md5
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
# Test fixture paths, relative to the test working directory.
# Each DATA_DIR_* is a list of TFRecord data files; each SCHEMA_DIR_* is the
# JSON schema describing that dataset's columns.
# Dataset in DIR_1 has 5 rows and 5 columns
DATA_DIR_1 = ["../data/dataset/testTFBert5Rows1/5TFDatas.data"]
SCHEMA_DIR_1 = "../data/dataset/testTFBert5Rows1/datasetSchema.json"
# Dataset in DIR_2 has 5 rows and 2 columns
DATA_DIR_2 = ["../data/dataset/testTFBert5Rows2/5TFDatas.data"]
SCHEMA_DIR_2 = "../data/dataset/testTFBert5Rows2/datasetSchema.json"
# Dataset in DIR_3 has 3 rows and 2 columns
DATA_DIR_3 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_3 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
# Dataset in DIR_4 has 5 rows and 7 columns
DATA_DIR_4 = ["../data/dataset/testTFBert5Rows/5TFDatas.data"]
SCHEMA_DIR_4 = "../data/dataset/testTFBert5Rows/datasetSchema.json"
# When True, save_and_check_* regenerates the golden reference files
# instead of comparing against them.
GENERATE_GOLDEN = False
  31. def test_zip_01():
  32. """
  33. Test zip: zip 2 datasets, #rows-data1 == #rows-data2, #cols-data1 < #cols-data2
  34. """
  35. logger.info("test_zip_01")
  36. ds.config.set_seed(1)
  37. data1 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
  38. data2 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  39. dataz = ds.zip((data1, data2))
  40. # Note: zipped dataset has 5 rows and 7 columns
  41. filename = "zip_01_result.npz"
  42. save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
  43. def test_zip_02():
  44. """
  45. Test zip: zip 2 datasets, #rows-data1 < #rows-data2, #cols-data1 == #cols-data2
  46. """
  47. logger.info("test_zip_02")
  48. ds.config.set_seed(1)
  49. data1 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
  50. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
  51. dataz = ds.zip((data1, data2))
  52. # Note: zipped dataset has 3 rows and 4 columns
  53. filename = "zip_02_result.npz"
  54. save_and_check_md5(dataz, filename, generate_golden=GENERATE_GOLDEN)
  55. def test_zip_03():
  56. """
  57. Test zip: zip 2 datasets, #rows-data1 > #rows-data2, #cols-data1 > #cols-data2
  58. """
  59. logger.info("test_zip_03")
  60. ds.config.set_seed(1)
  61. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  62. data2 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
  63. dataz = ds.zip((data1, data2))
  64. # Note: zipped dataset has 3 rows and 7 columns
  65. filename = "zip_03_result.npz"
  66. save_and_check_md5(dataz, filename, generate_golden=GENERATE_GOLDEN)
  67. def test_zip_04():
  68. """
  69. Test zip: zip >2 datasets
  70. """
  71. logger.info("test_zip_04")
  72. ds.config.set_seed(1)
  73. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  74. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
  75. data3 = ds.TFRecordDataset(DATA_DIR_3, SCHEMA_DIR_3)
  76. dataz = ds.zip((data1, data2, data3))
  77. # Note: zipped dataset has 3 rows and 9 columns
  78. filename = "zip_04_result.npz"
  79. save_and_check_md5(dataz, filename, generate_golden=GENERATE_GOLDEN)
  80. def test_zip_05():
  81. """
  82. Test zip: zip dataset with renamed columns
  83. """
  84. logger.info("test_zip_05")
  85. ds.config.set_seed(1)
  86. data1 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4, shuffle=True)
  87. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=True)
  88. data2 = data2.rename(input_columns="input_ids", output_columns="new_input_ids")
  89. data2 = data2.rename(input_columns="segment_ids", output_columns="new_segment_ids")
  90. dataz = ds.zip((data1, data2))
  91. # Note: zipped dataset has 5 rows and 9 columns
  92. filename = "zip_05_result.npz"
  93. save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
  94. def test_zip_06():
  95. """
  96. Test zip: zip dataset with renamed columns and repeat zipped dataset
  97. """
  98. logger.info("test_zip_06")
  99. ds.config.set_seed(1)
  100. data1 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4, shuffle=False)
  101. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2, shuffle=False)
  102. data2 = data2.rename(input_columns="input_ids", output_columns="new_input_ids")
  103. data2 = data2.rename(input_columns="segment_ids", output_columns="new_segment_ids")
  104. dataz = ds.zip((data1, data2))
  105. dataz = dataz.repeat(2)
  106. # Note: resultant dataset has 10 rows and 9 columns
  107. filename = "zip_06_result.npz"
  108. save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
  109. def test_zip_exception_01():
  110. """
  111. Test zip: zip same datasets
  112. """
  113. logger.info("test_zip_exception_01")
  114. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  115. try:
  116. dataz = ds.zip((data1, data1))
  117. num_iter = 0
  118. for _, item in enumerate(dataz.create_dict_iterator()):
  119. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  120. num_iter += 1
  121. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  122. except Exception as e:
  123. logger.info("Got an exception in DE: {}".format(str(e)))
  124. def test_zip_exception_02():
  125. """
  126. Test zip: zip datasets with duplicate column name
  127. """
  128. logger.info("test_zip_exception_02")
  129. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  130. data2 = ds.TFRecordDataset(DATA_DIR_4, SCHEMA_DIR_4)
  131. try:
  132. dataz = ds.zip((data1, data2))
  133. num_iter = 0
  134. for _, item in enumerate(dataz.create_dict_iterator()):
  135. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  136. num_iter += 1
  137. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  138. except Exception as e:
  139. logger.info("Got an exception in DE: {}".format(str(e)))
  140. def test_zip_exception_03():
  141. """
  142. Test zip: zip with tuple of 1 dataset
  143. """
  144. logger.info("test_zip_exception_03")
  145. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  146. try:
  147. dataz = ds.zip((data1))
  148. dataz = dataz.repeat(2)
  149. num_iter = 0
  150. for _, item in enumerate(dataz.create_dict_iterator()):
  151. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  152. num_iter += 1
  153. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  154. except Exception as e:
  155. logger.info("Got an exception in DE: {}".format(str(e)))
  156. def test_zip_exception_04():
  157. """
  158. Test zip: zip with empty tuple of datasets
  159. """
  160. logger.info("test_zip_exception_04")
  161. try:
  162. dataz = ds.zip(())
  163. dataz = dataz.repeat(2)
  164. num_iter = 0
  165. for _, item in enumerate(dataz.create_dict_iterator()):
  166. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  167. num_iter += 1
  168. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  169. except Exception as e:
  170. logger.info("Got an exception in DE: {}".format(str(e)))
  171. def test_zip_exception_05():
  172. """
  173. Test zip: zip with non-tuple of 2 datasets
  174. """
  175. logger.info("test_zip_exception_05")
  176. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  177. data2 = ds.TFRecordDataset(DATA_DIR_2, SCHEMA_DIR_2)
  178. try:
  179. dataz = ds.zip(data1, data2)
  180. num_iter = 0
  181. for _, item in enumerate(dataz.create_dict_iterator()):
  182. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  183. num_iter += 1
  184. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  185. except Exception as e:
  186. logger.info("Got an exception in DE: {}".format(str(e)))
  187. def test_zip_exception_06():
  188. """
  189. Test zip: zip with non-tuple of 1 dataset
  190. """
  191. logger.info("test_zip_exception_06")
  192. data1 = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
  193. try:
  194. dataz = ds.zip(data1)
  195. num_iter = 0
  196. for _, item in enumerate(dataz.create_dict_iterator()):
  197. logger.info("item[input_mask] is {}".format(item["input_mask"]))
  198. num_iter += 1
  199. logger.info("Number of data in zipped dataz: {}".format(num_iter))
  200. except Exception as e:
  201. logger.info("Got an exception in DE: {}".format(str(e)))
  202. if __name__ == '__main__':
  203. test_zip_01()
  204. #test_zip_02()
  205. #test_zip_03()
  206. #test_zip_04()
  207. #test_zip_05()
  208. #test_zip_06()
  209. #test_zip_exception_01()
  210. #test_zip_exception_02()
  211. #test_zip_exception_03()
  212. #test_zip_exception_04()
  213. #test_zip_exception_05()
  214. #test_zip_exception_06()