
test_serdes_dataset.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os

import numpy as np
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
from util import config_get_set_num_parallel_workers

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter
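
# The tests below all exercise the same serialize/deserialize round trip. As an
# illustrative sketch only (SomeDataset and "pipeline.json" are placeholders,
# not part of these tests), the pattern is roughly:
#
#     data = ds.SomeDataset(...)                                # build a pipeline
#     ds.serialize(data, "pipeline.json")                       # serialize to a json file
#     pipeline_dict = ds.serialize(data)                        # or serialize to a python dict
#     restored = ds.deserialize(json_filepath="pipeline.json")
#     restored_from_dict = ds.deserialize(input_dict=pipeline_dict)
#     # iterate the original and restored pipelines and assert they yield identical rows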


def test_imagefolder(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Constructing DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDatasetV2(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=[rescale_op, resize_op])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same to the previous serialize.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                          data3.create_dict_iterator(), data4.create_dict_iterator()):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        np.testing.assert_array_equal(item1['image'], item3['image'])
        np.testing.assert_array_equal(item1['label'], item2['label'])
        np.testing.assert_array_equal(item1['label'], item3['label'])
        np.testing.assert_array_equal(item3['image'], item4['image'])
        np.testing.assert_array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json file
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
    """
    Test serialize and deserialize on a MnistDataset pipeline.
    """
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, 100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(input_columns="label", operations=one_hot_encode)

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                   data3.create_dict_iterator()):
        np.testing.assert_array_equal(data1['image'], data2['image'])
        np.testing.assert_array_equal(data1['image'], data3['image'])
        np.testing.assert_array_equal(data1['label'], data2['label'])
        np.testing.assert_array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
    """
    Test serialize and deserialize on a zipped TFRecordDataset pipeline.
    """
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0, data3, data4):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            np.testing.assert_array_equal(t1, d3[offset])
            np.testing.assert_array_equal(t1, d3[offset + num_cols])
            np.testing.assert_array_equal(t1, d4[offset])
            np.testing.assert_array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1
    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_random_crop():
    """
    Test serialize and deserialize on a pipeline that includes RandomCrop.
    """
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(input_columns="image", operations=decode_op)
    data1 = data1.map(input_columns="image", operations=random_crop_op)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(input_columns="image", operations=decode_op)

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
                                     data2.create_dict_iterator()):
        np.testing.assert_array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def validate_jsonfile(filepath):
    """Return True if filepath exists and contains a valid json dictionary."""
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    """Remove the json files generated by the tests from the current directory."""
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test save load minddataset
def test_minddataset(add_and_remove_cv_file):
    """Tutorial for CV MindDataset serialize and deserialize."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)
    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator():
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    test_random_crop()
    # test_minddataset is not called here because it requires the add_and_remove_cv_file pytest fixture.