# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest

from util import save_and_check

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore import log as logger

FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
GENERATE_GOLDEN = False


def test_case_tf_shape():
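    """Batch a rank-0 schema dataset by 2 and check that the output shapes are rank 1."""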
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
    ds1 = ds.TFRecordDataset(FILES, schema_file)
    ds1 = ds1.batch(2)
    for data in ds1.create_dict_iterator():
        logger.info(data)
    output_shape = ds1.output_shapes()
    assert len(output_shape[-1]) == 1


def test_case_tf_read_all_dataset():
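    """A schema without a row count reads the whole file: all 12 rows."""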
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
    ds1 = ds.TFRecordDataset(FILES, schema_file)
    assert ds1.get_dataset_size() == 12
    count = 0
    for _ in ds1.create_tuple_iterator():
        count += 1
    assert count == 12


def test_case_num_samples():
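    """num_samples=8 takes precedence over the 7-row schema; 8 rows are reported and read."""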
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
    ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
    assert ds1.get_dataset_size() == 8
    count = 0
    for _ in ds1.create_dict_iterator():
        count += 1
    assert count == 8


def test_case_num_samples2():
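    """Without num_samples, the 7-row schema caps the dataset at 7 rows."""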
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
    ds1 = ds.TFRecordDataset(FILES, schema_file)
    assert ds1.get_dataset_size() == 7
    count = 0
    for _ in ds1.create_dict_iterator():
        count += 1
    assert count == 7


def test_case_tf_shape_2():
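    """Batch the all-types dataset by 2 and check that the output shapes are rank 2."""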
    ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
    ds1 = ds1.batch(2)
    output_shape = ds1.output_shapes()
    assert len(output_shape[-1]) == 2


def test_case_tf_file():
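    """Read with the all-types schema and compare against the saved golden result."""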
    logger.info("reading data from: {}".format(FILES[0]))
    parameters = {"params": {}}
    data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    filename = "tfreader_result.npz"
    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_case_tf_file_no_schema():
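    """Read without a schema and compare against the saved golden result."""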
    logger.info("reading data from: {}".format(FILES[0]))
    parameters = {"params": {}}
    data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
    filename = "tf_file_no_schema.npz"
    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_case_tf_file_pad():
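    """Read with the datasetSchemaPadBytes10 schema and compare against the golden result."""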
    logger.info("reading data from: {}".format(FILES[0]))
    parameters = {"params": {}}
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
    data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES)
    filename = "tf_file_padBytes10.npz"
    save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_tf_files():
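    """Exercise single files, glob patterns, file lists, and num_samples, checking row counts."""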
    pattern = DATASET_ROOT + "/test.data"
    data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    assert sum([1 for _ in data]) == 12

    pattern = DATASET_ROOT + "/test2.data"
    data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    assert sum([1 for _ in data]) == 12

    pattern = DATASET_ROOT + "/*.data"
    data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES)
    assert sum([1 for _ in data]) == 24

    pattern = DATASET_ROOT + "/*.data"
    data = ds.TFRecordDataset(pattern, SCHEMA_FILE, num_samples=3, shuffle=ds.Shuffle.FILES)
    assert sum([1 for _ in data]) == 3

    data = ds.TFRecordDataset([DATASET_ROOT + "/test.data", DATASET_ROOT + "/test2.data"],
                              SCHEMA_FILE, num_samples=24, shuffle=ds.Shuffle.FILES)
    assert sum([1 for _ in data]) == 24


def test_tf_record_schema():
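    """An in-code Schema and the equivalent JSON schema file must yield identical rows."""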
    schema = ds.Schema()
    schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
    schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
    schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2])
    schema.add_column('col_binary', de_type=mstype.uint8, shape=[1])
    schema.add_column('col_float', de_type=mstype.float32, shape=[1])
    schema.add_column('col_sint16', de_type=mstype.int64, shape=[1])
    schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
    schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
    data1 = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)


def test_tf_record_shuffle():
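    """With the same seed, Shuffle.GLOBAL must equal Shuffle.FILES plus a dataset-level shuffle."""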
    ds.config.set_seed(1)
    data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    for d1, d2 in zip(data1, data2):
        for t1, t2 in zip(d1, d2):
            assert np.array_equal(t1, t2)


def skip_test_tf_record_shard():
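    """Sharded reads (currently disabled; the skip_ prefix hides it from pytest): two
    shards must differ within an epoch yet each cover the full dataset over enough epochs.
    """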
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def get_res(shard_id, num_repeats):
        data1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=shard_id, num_samples=3,
                                   shuffle=ds.Shuffle.FILES)
        data1 = data1.repeat(num_repeats)
        res = list()
        for item in data1.create_dict_iterator():
            res.append(item["scalars"][0])
        return res
    # Get separate results from the two workers. The results must satisfy two criteria:
    # 1. the two workers always produce different rows within the same epoch
    #    (e.g. worker 1 reads f1 & f3 while worker 2 reads f2 & f4), and
    # 2. given enough epochs, each worker sees the entire dataset
    #    (e.g. worker 1 reads f1 & f3 in epoch 1 and f2 & f4 in epoch 2).
    worker1_res = get_res(0, 16)
    worker2_res = get_res(1, 16)
    # Check criterion 1.
    for i in range(len(worker1_res)):
        assert worker1_res[i] != worker2_res[i]
    # Check criterion 2.
    assert set(worker2_res) == set(worker1_res)
    assert len(set(worker2_res)) == 12


def test_tf_shard_equal_rows():
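    """shard_equal_rows=True: each of 3 shards gets the same row count, distinct row-by-row."""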
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def get_res(num_shards, shard_id, num_repeats):
        ds1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, shard_equal_rows=True)
        ds1 = ds1.repeat(num_repeats)
        res = list()
        for data in ds1.create_dict_iterator():
            res.append(data["scalars"][0])
        return res

    worker1_res = get_res(3, 0, 2)
    worker2_res = get_res(3, 1, 2)
    worker3_res = get_res(3, 2, 2)
    # Every shard must produce a different row at each position.
    for i in range(len(worker1_res)):
        assert worker1_res[i] != worker2_res[i]
        assert worker2_res[i] != worker3_res[i]
    assert len(worker1_res) == 28
    # A single shard reads the full dataset.
    worker4_res = get_res(1, 0, 1)
    assert len(worker4_res) == 40


def test_case_tf_file_no_schema_columns_list():
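    """columns_list limits a schema-less read to the requested column; others raise KeyError."""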
    data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
    row = data.create_dict_iterator().get_next()
    assert row["col_sint16"] == [-32768]
    with pytest.raises(KeyError) as info:
        _ = row["col_sint32"]
    assert "col_sint32" in str(info.value)


def test_tf_record_schema_columns_list():
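    """columns_list limits an in-code-schema read to the requested column; others raise KeyError."""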
    schema = ds.Schema()
    schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
    schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2])
    schema.add_column('col_3d', de_type=mstype.int64, shape=[2, 2, 2])
    schema.add_column('col_binary', de_type=mstype.uint8, shape=[1])
    schema.add_column('col_float', de_type=mstype.float32, shape=[1])
    schema.add_column('col_sint16', de_type=mstype.int64, shape=[1])
    schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
    schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
    data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
    row = data.create_dict_iterator().get_next()
    assert row["col_sint16"] == [-32768]
    with pytest.raises(KeyError) as info:
        _ = row["col_sint32"]
    assert "col_sint32" in str(info.value)


def test_case_invalid_files():
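    """Unreadable and nonexistent inputs raise errors naming only the offending files."""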
    valid_file = "../data/dataset/testTFTestAllTypes/test.data"
    invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
    files = [invalid_file, valid_file, SCHEMA_FILE]

    data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    with pytest.raises(RuntimeError) as info:
        _ = data.create_dict_iterator().get_next()
    assert "cannot be opened" in str(info.value)
    assert "not valid tfrecord files" in str(info.value)
    assert valid_file not in str(info.value)
    assert invalid_file in str(info.value)
    assert SCHEMA_FILE in str(info.value)

    nonexistent_file = "this/file/does/not/exist"
    files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file]
    with pytest.raises(ValueError) as info:
        data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    assert "did not match any files" in str(info.value)
    assert valid_file not in str(info.value)
    assert invalid_file not in str(info.value)
    assert SCHEMA_FILE not in str(info.value)
    assert nonexistent_file in str(info.value)


if __name__ == '__main__':
    test_case_tf_shape()
    test_case_tf_file()
    test_case_tf_file_no_schema()
    test_case_tf_file_pad()
    test_tf_files()
    test_tf_record_schema()
    test_tf_record_shuffle()
    test_tf_shard_equal_rows()
    test_case_invalid_files()