test_serdes_dataset.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os

import numpy as np
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
from util import config_get_set_num_parallel_workers, config_get_set_seed

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter
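
# All of the tests below exercise the same serialize/deserialize ("serdes")
# round trip. A minimal sketch of the pattern, assuming a generic `dataset`
# pipeline object and an arbitrary output path (names here are illustrative,
# not part of this file):
#
#     ds.serialize(dataset, "pipeline.json")                    # pipeline -> JSON file
#     restored = ds.deserialize(json_filepath="pipeline.json")  # JSON file -> pipeline
#     pipeline_dict = ds.serialize(dataset)                     # pipeline -> Python dict
#     restored2 = ds.deserialize(input_dict=pipeline_dict)      # Python dict -> pipeline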


def skip_test_imagefolder(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Constructing DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    data1 = ds.ImageFolderDataset(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(operations=[vision.Decode(True)], input_columns=["image"])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(operations=[rescale_op, resize_op], input_columns=["image"])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as the previous serialization.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipelines
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data2.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data3.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data4.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        np.testing.assert_array_equal(item1['image'], item3['image'])
        np.testing.assert_array_equal(item1['label'], item2['label'])
        np.testing.assert_array_equal(item1['label'], item3['label'])
        np.testing.assert_array_equal(item3['image'], item4['image'])
        np.testing.assert_array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
    """
    Test serdes on mnist dataset pipeline.
    """
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, num_samples=100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(operations=one_hot_encode, input_columns="label")

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                   data2.create_dict_iterator(num_epochs=1, output_numpy=True),
                                   data3.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data1['image'], data2['image'])
        np.testing.assert_array_equal(data1['image'], data3['image'])
        np.testing.assert_array_equal(data1['label'], data2['label'])
        np.testing.assert_array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
    """
    Test serdes on zip dataset pipeline.
    """
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0.create_tuple_iterator(output_numpy=True),
                          data3.create_tuple_iterator(output_numpy=True),
                          data4.create_tuple_iterator(output_numpy=True)):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            np.testing.assert_array_equal(t1, d3[offset])
            np.testing.assert_array_equal(t1, d3[offset + num_cols])
            np.testing.assert_array_equal(t1, d4[offset])
            np.testing.assert_array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1

    assert rows == 12

    if remove_json_files:
        delete_json_files()


def skip_test_random_crop():
    """
    Test serdes on RandomCrop pipeline.
    """
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    original_seed = config_get_set_seed(1)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(operations=decode_op, input_columns="image")
    data1 = data1.map(operations=random_crop_op, input_columns="image")

    # Serializing into Python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(operations=decode_op, input_columns="image")

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                     data1_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                     data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]

    # Restore the original seed and num_parallel_workers configuration
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def validate_jsonfile(filepath):
    """Check that filepath exists and contains a JSON object (dict)."""
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    """Remove all .json files in the current directory."""
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test save load minddataset
def skip_test_minddataset(add_and_remove_cv_file):
    """Test serdes on CV MindDataset pipeline."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into Python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)
    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    skip_test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    skip_test_random_crop()