You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_tfreader_op.py 12 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test TFRecordDataset Ops
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.common.dtype as mstype
  21. import mindspore.dataset as ds
  22. from mindspore import log as logger
  23. from util import save_and_check_dict
# Single-file TFRecord dataset exercising all supported column types.
FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
# Multi-file TFRecord dataset split across four shard files.
DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
               "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
               "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
               "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
# Set to True to regenerate the golden .npz files consumed by save_and_check_dict.
GENERATE_GOLDEN = False
  33. def test_tfrecord_shape():
  34. logger.info("test_tfrecord_shape")
  35. schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
  36. ds1 = ds.TFRecordDataset(FILES, schema_file)
  37. ds1 = ds1.batch(2)
  38. for data in ds1.create_dict_iterator():
  39. logger.info(data)
  40. output_shape = ds1.output_shapes()
  41. assert len(output_shape[-1]) == 1
  42. def test_tfrecord_read_all_dataset():
  43. logger.info("test_tfrecord_read_all_dataset")
  44. schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
  45. ds1 = ds.TFRecordDataset(FILES, schema_file)
  46. assert ds1.get_dataset_size() == 12
  47. count = 0
  48. for _ in ds1.create_tuple_iterator():
  49. count += 1
  50. assert count == 12
  51. def test_tfrecord_num_samples():
  52. logger.info("test_tfrecord_num_samples")
  53. schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
  54. ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
  55. assert ds1.get_dataset_size() == 8
  56. count = 0
  57. for _ in ds1.create_dict_iterator():
  58. count += 1
  59. assert count == 8
  60. def test_tfrecord_num_samples2():
  61. logger.info("test_tfrecord_num_samples2")
  62. schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
  63. ds1 = ds.TFRecordDataset(FILES, schema_file)
  64. assert ds1.get_dataset_size() == 7
  65. count = 0
  66. for _ in ds1.create_dict_iterator():
  67. count += 1
  68. assert count == 7
  69. def test_tfrecord_shape2():
  70. logger.info("test_tfrecord_shape2")
  71. ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
  72. ds1 = ds1.batch(2)
  73. output_shape = ds1.output_shapes()
  74. assert len(output_shape[-1]) == 2
  75. def test_tfrecord_files_basic():
  76. logger.info("test_tfrecord_files_basic")
  77. data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
  78. filename = "tfrecord_files_basic.npz"
  79. save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
  80. def test_tfrecord_no_schema():
  81. logger.info("test_tfrecord_no_schema")
  82. data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
  83. filename = "tfrecord_no_schema.npz"
  84. save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
  85. def test_tfrecord_pad():
  86. logger.info("test_tfrecord_pad")
  87. schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
  88. data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
  89. filename = "tfrecord_pad_bytes10.npz"
  90. save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN)
  91. def test_tfrecord_read_files():
  92. logger.info("test_tfrecord_read_files")
  93. pattern = DATASET_ROOT + "/test.data"
  94. data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
  95. assert sum([1 for _ in data]) == 12
  96. pattern = DATASET_ROOT + "/test2.data"
  97. data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
  98. assert sum([1 for _ in data]) == 12
  99. pattern = DATASET_ROOT + "/*.data"
  100. data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES)
  101. assert sum([1 for _ in data]) == 24
  102. pattern = DATASET_ROOT + "/*.data"
  103. data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=3, shuffle=ds.Shuffle.FILES)
  104. assert sum([1 for _ in data]) == 3
  105. data = ds.TFRecordDataset([DATASET_ROOT + "/test.data", DATASET_ROOT + "/test2.data"],
  106. SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES)
  107. assert sum([1 for _ in data]) == 24
  108. def test_tfrecord_multi_files():
  109. logger.info("test_tfrecord_multi_files")
  110. data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False)
  111. data1 = data1.repeat(1)
  112. num_iter = 0
  113. for _ in data1.create_dict_iterator():
  114. num_iter += 1
  115. assert num_iter == 12
  116. def test_tfrecord_schema():
  117. logger.info("test_tfrecord_schema")
  118. schema = ds.Schema()
  119. schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
  120. schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
  121. schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2])
  122. schema.add_column('col_binary', de_type=mstype.uint8, shape=[1])
  123. schema.add_column('col_float', de_type=mstype.float32, shape=[1])
  124. schema.add_column('col_sint16', de_type=mstype.int64, shape=[1])
  125. schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
  126. schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
  127. data1 = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
  128. data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
  129. for d1, d2 in zip(data1, data2):
  130. for t1, t2 in zip(d1, d2):
  131. assert np.array_equal(t1, t2)
  132. def test_tfrecord_shuffle():
  133. logger.info("test_tfrecord_shuffle")
  134. ds.config.set_seed(1)
  135. data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
  136. data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
  137. data2 = data2.shuffle(10000)
  138. for d1, d2 in zip(data1, data2):
  139. for t1, t2 in zip(d1, d2):
  140. assert np.array_equal(t1, t2)
  141. def test_tfrecord_shard():
  142. logger.info("test_tfrecord_shard")
  143. tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
  144. "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
  145. def get_res(shard_id, num_repeats):
  146. data1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=shard_id, num_samples=3,
  147. shuffle=ds.Shuffle.FILES)
  148. data1 = data1.repeat(num_repeats)
  149. res = list()
  150. for item in data1.create_dict_iterator():
  151. res.append(item["scalars"][0])
  152. return res
  153. # get separate results from two workers. the 2 results need to satisfy 2 criteria
  154. # 1. two workers always give different results in same epoch (e.g. wrkr1:f1&f3, wrkr2:f2&f4 in one epoch)
  155. # 2. with enough epochs, both workers will get the entire dataset (e,g. ep1_wrkr1: f1&f3, ep2,_wrkr1 f2&f4)
  156. worker1_res = get_res(0, 16)
  157. worker2_res = get_res(1, 16)
  158. # Confirm each worker gets 3x16=48 rows
  159. assert len(worker1_res) == 48
  160. assert len(worker1_res) == len(worker2_res)
  161. # check criteria 1
  162. for i, _ in enumerate(worker1_res):
  163. assert worker1_res[i] != worker2_res[i]
  164. # check criteria 2
  165. assert set(worker2_res) == set(worker1_res)
  166. def test_tfrecord_shard_equal_rows():
  167. logger.info("test_tfrecord_shard_equal_rows")
  168. tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
  169. "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]
  170. def get_res(num_shards, shard_id, num_repeats):
  171. ds1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, shard_equal_rows=True)
  172. ds1 = ds1.repeat(num_repeats)
  173. res = list()
  174. for data in ds1.create_dict_iterator():
  175. res.append(data["scalars"][0])
  176. return res
  177. worker1_res = get_res(3, 0, 2)
  178. worker2_res = get_res(3, 1, 2)
  179. worker3_res = get_res(3, 2, 2)
  180. # check criteria 1
  181. for i, _ in enumerate(worker1_res):
  182. assert worker1_res[i] != worker2_res[i]
  183. assert worker2_res[i] != worker3_res[i]
  184. # Confirm each worker gets same number of rows
  185. assert len(worker1_res) == 28
  186. assert len(worker1_res) == len(worker2_res)
  187. assert len(worker2_res) == len(worker3_res)
  188. worker4_res = get_res(1, 0, 1)
  189. assert len(worker4_res) == 40
  190. def test_tfrecord_no_schema_columns_list():
  191. logger.info("test_tfrecord_no_schema_columns_list")
  192. data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
  193. row = data.create_dict_iterator().get_next()
  194. assert row["col_sint16"] == [-32768]
  195. with pytest.raises(KeyError) as info:
  196. _ = row["col_sint32"]
  197. assert "col_sint32" in str(info.value)
  198. def test_tfrecord_schema_columns_list():
  199. logger.info("test_tfrecord_schema_columns_list")
  200. schema = ds.Schema()
  201. schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
  202. schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
  203. schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2])
  204. schema.add_column('col_binary', de_type=mstype.uint8, shape=[1])
  205. schema.add_column('col_float', de_type=mstype.float32, shape=[1])
  206. schema.add_column('col_sint16', de_type=mstype.int64, shape=[1])
  207. schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
  208. schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
  209. data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
  210. row = data.create_dict_iterator().get_next()
  211. assert row["col_sint16"] == [-32768]
  212. with pytest.raises(KeyError) as info:
  213. _ = row["col_sint32"]
  214. assert "col_sint32" in str(info.value)
  215. def test_tfrecord_invalid_files():
  216. logger.info("test_tfrecord_invalid_files")
  217. valid_file = "../data/dataset/testTFTestAllTypes/test.data"
  218. invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
  219. files = [invalid_file, valid_file, SCHEMA_FILE]
  220. data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
  221. with pytest.raises(RuntimeError) as info:
  222. _ = data.create_dict_iterator().get_next()
  223. assert "cannot be opened" in str(info.value)
  224. assert "not valid tfrecord files" in str(info.value)
  225. assert valid_file not in str(info.value)
  226. assert invalid_file in str(info.value)
  227. assert SCHEMA_FILE in str(info.value)
  228. nonexistent_file = "this/file/does/not/exist"
  229. files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file]
  230. with pytest.raises(ValueError) as info:
  231. data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
  232. assert "did not match any files" in str(info.value)
  233. assert valid_file not in str(info.value)
  234. assert invalid_file not in str(info.value)
  235. assert SCHEMA_FILE not in str(info.value)
  236. assert nonexistent_file in str(info.value)
  237. if __name__ == '__main__':
  238. test_tfrecord_shape()
  239. test_tfrecord_read_all_dataset()
  240. test_tfrecord_num_samples()
  241. test_tfrecord_num_samples2()
  242. test_tfrecord_shape2()
  243. test_tfrecord_files_basic()
  244. test_tfrecord_no_schema()
  245. test_tfrecord_pad()
  246. test_tfrecord_read_files()
  247. test_tfrecord_multi_files()
  248. test_tfrecord_schema()
  249. test_tfrecord_shuffle()
  250. test_tfrecord_shard()
  251. test_tfrecord_shard_equal_rows()
  252. test_tfrecord_no_schema_columns_list()
  253. test_tfrecord_schema_columns_list()
  254. test_tfrecord_invalid_files()