
test_serdes_dataset.py 11 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os

import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
from util import config_get_set_num_parallel_workers


def test_imagefolder(remove_json_files=True):
    """
    Test serialize/deserialize on a pipeline simulating the ResNet-50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # Define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Constructing DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDatasetV2(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=[rescale_op, resize_op])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as that of the previous serialize.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipelines (data2, data3, data4)
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(num_epochs=1),
                                          data2.create_dict_iterator(num_epochs=1),
                                          data3.create_dict_iterator(num_epochs=1),
                                          data4.create_dict_iterator(num_epochs=1)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        np.testing.assert_array_equal(item1['image'], item3['image'])
        np.testing.assert_array_equal(item1['label'], item2['label'])
        np.testing.assert_array_equal(item1['label'], item3['label'])
        np.testing.assert_array_equal(item3['image'], item4['image'])
        np.testing.assert_array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, 100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(input_columns="label", operations=one_hot_encode)

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1),
                                   data3.create_dict_iterator(num_epochs=1)):
        np.testing.assert_array_equal(data1['image'], data2['image'])
        np.testing.assert_array_equal(data1['image'], data3['image'])
        np.testing.assert_array_equal(data1['label'], data2['label'])
        np.testing.assert_array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0, data3, data4):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            np.testing.assert_array_equal(t1, d3[offset])
            np.testing.assert_array_equal(t1, d3[offset + num_cols])
            np.testing.assert_array_equal(t1, d4[offset])
            np.testing.assert_array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1
    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_random_crop():
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(input_columns="image", operations=decode_op)
    data1 = data1.map(input_columns="image", operations=random_crop_op)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(input_columns="image", operations=decode_op)

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(num_epochs=1),
                                     data1_1.create_dict_iterator(num_epochs=1),
                                     data2.create_dict_iterator(num_epochs=1)):
        np.testing.assert_array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def validate_jsonfile(filepath):
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test saving and loading a MindDataset pipeline
def test_minddataset(add_and_remove_cv_file):
    """Test serialize and deserialize on a CV MindDataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)
    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    test_random_crop()
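
Note: the round-trip pattern these tests repeat (serialize a pipeline to JSON, rebuild it with deserialize, then compare the two pipelines element by element) can be reduced to the minimal sketch below. It is not part of the test file; it assumes the same MindSpore release and the "../data/dataset/testMnistData" test data used above, and the function name roundtrip_sketch and the roundtrip_pipeline*.json file names are illustrative only.

import filecmp

import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c


def roundtrip_sketch():
    # Fix the seed so both pipelines draw samples in the same order
    ds.config.set_seed(1)

    # Build a small pipeline (same ops as test_mnist_dataset above)
    data1 = ds.MnistDataset("../data/dataset/testMnistData", 100)
    data1 = data1.map(input_columns="label", operations=c.OneHot(10))
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    # Serialize to a json file, then rebuild the pipeline from that file
    ds.serialize(data1, "roundtrip_pipeline.json")
    data2 = ds.deserialize(json_filepath="roundtrip_pipeline.json")

    # Re-serializing the rebuilt pipeline should reproduce the same json
    ds.serialize(data2, "roundtrip_pipeline_1.json")
    assert filecmp.cmp("roundtrip_pipeline.json", "roundtrip_pipeline_1.json")

    # Both pipelines should yield identical batches
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1),
                            data2.create_dict_iterator(num_epochs=1)):
        np.testing.assert_array_equal(item1["image"], item2["image"])
        np.testing.assert_array_equal(item1["label"], item2["label"])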