You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_tfreader_op.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. from util import save_and_check
  17. import mindspore.common.dtype as mstype
  18. import mindspore.dataset as ds
  19. from mindspore import log as logger
  20. import pytest
  # Directory that holds the TFRecord fixtures used by the tests in this file.
  DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
  # Default list of TFRecord files handed to ds.TFRecordDataset.
  FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
  # JSON schema file passed together with FILES.
  SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
  # Forwarded to save_and_check() as its generate_golden argument.
  GENERATE_GOLDEN = False
  def test_case_tf_shape():
      """
      Read FILES with the datasetSchemaRank0.json schema, batch by 2, and check
      that the last entry of output_shapes() has exactly one dimension.
      """
      rank0_schema = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json"
      batched = ds.TFRecordDataset(FILES, rank0_schema).batch(2)

      # Run the pipeline once end to end, logging every batch it yields.
      for batch in batched.create_dict_iterator():
          logger.info(batch)

      last_shape = batched.output_shapes()[-1]
      assert len(last_shape) == 1
  def test_case_tf_shape_2():
      """
      Read FILES with the default SCHEMA_FILE, batch by 2, and check that the
      last entry of output_shapes() has exactly two dimensions.
      """
      batched = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2)
      last_shape = batched.output_shapes()[-1]
      assert len(last_shape) == 2
  def test_case_tf_file():
      """
      Read FILES with SCHEMA_FILE (file-level shuffle) and pass the dataset to
      save_and_check() under the result name "tfreader_result.npz".
      """
      logger.info("reading data from: {}".format(FILES[0]))
      dataset = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
      # NOTE(review): save_and_check is imported from util; presumably it compares
      # against (or regenerates, per GENERATE_GOLDEN) the stored result file.
      save_and_check(dataset, {"params": {}}, "tfreader_result.npz",
                     generate_golden=GENERATE_GOLDEN)
  def test_case_tf_file_no_schema():
      """
      Read FILES without any schema (file-level shuffle) and pass the dataset to
      save_and_check() under the result name "tf_file_no_schema.npz".
      """
      logger.info("reading data from: {}".format(FILES[0]))
      dataset = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES)
      # NOTE(review): save_and_check is imported from util; presumably it compares
      # against (or regenerates, per GENERATE_GOLDEN) the stored result file.
      save_and_check(dataset, {"params": {}}, "tf_file_no_schema.npz",
                     generate_golden=GENERATE_GOLDEN)
  def test_case_tf_file_pad():
      """
      Read FILES with the datasetSchemaPadBytes10.json schema (file-level shuffle)
      and pass the dataset to save_and_check() under "tf_file_padBytes10.npz".
      """
      logger.info("reading data from: {}".format(FILES[0]))
      pad_schema = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json"
      dataset = ds.TFRecordDataset(FILES, pad_schema, shuffle=ds.Shuffle.FILES)
      # NOTE(review): save_and_check is imported from util; presumably it compares
      # against (or regenerates, per GENERATE_GOLDEN) the stored result file.
      save_and_check(dataset, {"params": {}}, "tf_file_padBytes10.npz",
                     generate_golden=GENERATE_GOLDEN)
  def test_tf_files():
      """
      Check the number of rows produced for different ways of naming input files.

      Covers a single file path, a glob pattern (with and without a num_samples
      cap below the total) and an explicit list of two files. Every dataset uses
      SCHEMA_FILE and file-level shuffle.
      """
      import os

      def count_rows(dataset_files, **kwargs):
          """Build a TFRecordDataset over dataset_files and return how many rows it yields."""
          dataset = ds.TFRecordDataset(dataset_files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES, **kwargs)
          # A generator is enough here; there is no need to build a list of ones.
          return sum(1 for _ in dataset)

      # os.path.join avoids the "//" that DATASET_ROOT + "/name" produced, since
      # DATASET_ROOT already ends with a path separator.
      first_file = os.path.join(DATASET_ROOT, "test.data")
      second_file = os.path.join(DATASET_ROOT, "test2.data")
      all_data_pattern = os.path.join(DATASET_ROOT, "*.data")

      # Each fixture file on its own holds 12 rows.
      assert count_rows(first_file) == 12
      assert count_rows(second_file) == 12

      # The glob pattern: all 24 rows, then capped at 3 by num_samples.
      assert count_rows(all_data_pattern, num_samples=24) == 24
      assert count_rows(all_data_pattern, num_samples=3) == 3

      # An explicit list of both files yields the same 24 rows.
      assert count_rows([first_file, second_file], num_samples=24) == 24
  def test_tf_record_schema():
      """
      Check that a programmatically built ds.Schema yields the same data as SCHEMA_FILE.

      Both datasets read FILES with file-level shuffle and are compared row by
      row and column by column with np.array_equal.
      """
      import itertools

      schema = ds.Schema()
      column_specs = [
          ('col_1d', mstype.int64, [2]),
          ('col_2d', mstype.int64, [2, 2]),
          ('col_3d', mstype.int64, [2, 2, 2]),
          ('col_binary', mstype.uint8, [1]),
          ('col_float', mstype.float32, [1]),
          ('col_sint16', mstype.int64, [1]),
          ('col_sint32', mstype.int64, [1]),
          ('col_sint64', mstype.int64, [1]),
      ]
      for name, de_type, shape in column_specs:
          schema.add_column(name, de_type=de_type, shape=shape)

      data1 = ds.TFRecordDataset(FILES, schema=schema, shuffle=ds.Shuffle.FILES)
      data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)

      # Plain zip() stops at the shorter input, so a dataset that yields fewer
      # rows (or none) would pass unnoticed. zip_longest with a sentinel turns
      # any length mismatch into a failure.
      exhausted = object()
      for row1, row2 in itertools.zip_longest(data1, data2, fillvalue=exhausted):
          assert row1 is not exhausted and row2 is not exhausted, \
              "datasets yielded a different number of rows"
          assert len(row1) == len(row2)
          for t1, t2 in zip(row1, row2):
              assert np.array_equal(t1, t2)
  def test_tf_record_shuffle():
      """
      Check that Shuffle.GLOBAL matches Shuffle.FILES followed by shuffle(10000).

      The seed is fixed to 1 through ds.config.set_seed before either dataset is
      built; the two datasets are then compared row by row and column by column.
      """
      import itertools

      ds.config.set_seed(1)
      data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
      data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
      data2 = data2.shuffle(10000)

      # Plain zip() stops at the shorter input, so a dataset that yields fewer
      # rows (or none) would pass unnoticed. zip_longest with a sentinel turns
      # any length mismatch into a failure.
      exhausted = object()
      for row1, row2 in itertools.zip_longest(data1, data2, fillvalue=exhausted):
          assert row1 is not exhausted and row2 is not exhausted, \
              "datasets yielded a different number of rows"
          assert len(row1) == len(row2)
          for t1, t2 in zip(row1, row2):
              assert np.array_equal(t1, t2)
  def skip_test_tf_record_shard():
      """
      Check two-way sharding of four TFRecord files over 16 repeats.

      The name does not start with "test", so pytest does not collect this
      function. It reads the first "scalars" value of every row seen by each of
      two shards and checks the two criteria described in the body.
      """
      tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                  "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

      def get_res(shard_id, num_repeats):
          """Return the first "scalars" value of every row seen by the given shard."""
          data1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=shard_id, num_samples=3,
                                     shuffle=ds.Shuffle.FILES)
          data1 = data1.repeat(num_repeats)
          return [item["scalars"][0] for item in data1.create_dict_iterator()]

      # Get separate results from two workers. The two results need to satisfy two criteria:
      # 1. the two workers always give different results in the same epoch
      #    (e.g. wrkr1: f1 & f3, wrkr2: f2 & f4 in one epoch)
      # 2. with enough epochs, both workers will get the entire dataset
      #    (e.g. ep1_wrkr1: f1 & f3, ep2_wrkr1: f2 & f4)
      worker1_res = get_res(0, 16)
      worker2_res = get_res(1, 16)

      # Compare position by position only after confirming both workers produced
      # the same number of rows; otherwise a shorter result would either raise
      # IndexError or leave part of the longer one unchecked.
      assert len(worker1_res) == len(worker2_res)

      # check criteria 1
      for first, second in zip(worker1_res, worker2_res):
          assert first != second

      # check criteria 2
      assert set(worker2_res) == set(worker1_res)
      assert len(set(worker2_res)) == 12
  def test_tf_shard_equal_rows():
      """
      Check sharding with shard_equal_rows=True over four TFRecord files.

      Three shards, each repeated twice, must produce 28 rows apiece and must
      differ position by position between neighbouring shards. A single
      unsharded pass must produce 40 rows.
      """
      tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                  "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

      def get_res(num_shards, shard_id, num_repeats):
          """Return the first "scalars" value of every row seen by the given shard."""
          ds1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, shard_equal_rows=True)
          ds1 = ds1.repeat(num_repeats)
          return [data["scalars"][0] for data in ds1.create_dict_iterator()]

      worker1_res = get_res(3, 0, 2)
      worker2_res = get_res(3, 1, 2)
      worker3_res = get_res(3, 2, 2)

      # Verify the sizes before the positional comparison so that a short shard
      # fails with a clear assertion instead of an IndexError.
      # NOTE(review): 28 rows over 2 repeats is 14 per epoch, presumably
      # ceil(40 / 3) so that every shard gets the same row count - confirm.
      assert len(worker1_res) == 28
      assert len(worker2_res) == len(worker1_res)
      assert len(worker3_res) == len(worker1_res)

      # Neighbouring shards must never return the same value at the same position.
      for first, second, third in zip(worker1_res, worker2_res, worker3_res):
          assert first != second
          assert second != third

      # A single shard read once sees the whole dataset.
      worker4_res = get_res(1, 0, 1)
      assert len(worker4_res) == 40
  def test_case_tf_file_no_schema_columns_list():
      """
      Read FILES without a schema, restricted to columns_list=["col_sint16"], and
      check that the first row exposes that column but not "col_sint32".
      """
      dataset = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
      first_row = dataset.create_dict_iterator().get_next()
      assert first_row["col_sint16"] == [-32768]

      # A column left out of columns_list must be absent from the row dict.
      with pytest.raises(KeyError) as info:
          _ = first_row["col_sint32"]
      error_text = str(info.value)
      assert "col_sint32" in error_text
  def test_tf_record_schema_columns_list():
      """
      Read FILES with a programmatically built ds.Schema, restricted to
      columns_list=["col_sint16"], and check that the first row exposes that
      column but not "col_sint32".
      """
      schema = ds.Schema()
      column_specs = (
          ('col_1d', mstype.int64, [2]),
          ('col_2d', mstype.int64, [2, 2]),
          ('col_3d', mstype.int64, [2, 2, 2]),
          ('col_binary', mstype.uint8, [1]),
          ('col_float', mstype.float32, [1]),
          ('col_sint16', mstype.int64, [1]),
          ('col_sint32', mstype.int64, [1]),
          ('col_sint64', mstype.int64, [1]),
      )
      for column_name, column_type, column_shape in column_specs:
          schema.add_column(column_name, de_type=column_type, shape=column_shape)

      dataset = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
      first_row = dataset.create_dict_iterator().get_next()
      assert first_row["col_sint16"] == [-32768]

      # A column left out of columns_list must be absent from the row dict,
      # even though the schema declares it.
      with pytest.raises(KeyError) as info:
          _ = first_row["col_sint32"]
      error_text = str(info.value)
      assert "col_sint32" in error_text
  def test_case_invalid_files():
      """
      Check the errors raised for unreadable and for non-existent input files.

      Part 1: a file list mixing a valid TFRecord file with two other files
      (invalidFile.txt and the JSON schema) must raise RuntimeError when the
      first row is fetched, and the message must name only the two bad files.
      Part 2: adding a path that does not exist must raise ValueError as soon
      as the dataset is constructed, and the message must name only that path.
      """
      valid_file = "../data/dataset/testTFTestAllTypes/test.data"
      invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"

      # Part 1: existing files that are not TFRecord data.
      dataset = ds.TFRecordDataset([invalid_file, valid_file, SCHEMA_FILE], SCHEMA_FILE,
                                   shuffle=ds.Shuffle.FILES)
      with pytest.raises(RuntimeError) as info:
          dataset.create_dict_iterator().get_next()
      runtime_message = str(info.value)
      assert "cannot be opened" in runtime_message
      assert "not valid tfrecord files" in runtime_message
      assert valid_file not in runtime_message
      assert invalid_file in runtime_message
      assert SCHEMA_FILE in runtime_message

      # Part 2: a path that matches no file at all.
      nonexistent_file = "this/file/does/not/exist"
      with pytest.raises(ValueError) as info:
          ds.TFRecordDataset([invalid_file, valid_file, SCHEMA_FILE, nonexistent_file], SCHEMA_FILE,
                             shuffle=ds.Shuffle.FILES)
      value_message = str(info.value)
      assert "did not match any files" in value_message
      assert valid_file not in value_message
      assert invalid_file not in value_message
      assert SCHEMA_FILE not in value_message
      assert nonexistent_file in value_message
  # Run every test in this file when it is executed directly as a script.
  # skip_test_tf_record_shard is left out on purpose: its name marks it as disabled.
  if __name__ == '__main__':
      test_case_tf_shape()
      test_case_tf_shape_2()
      test_case_tf_file()
      test_case_tf_file_no_schema()
      test_case_tf_file_pad()
      test_tf_files()
      test_tf_record_schema()
      test_tf_record_shuffle()
      test_tf_shard_equal_rows()
      test_case_tf_file_no_schema_columns_list()
      test_tf_record_schema_columns_list()
      test_case_invalid_files()