
test_serdes_dataset.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE (dataset engine).
"""
import filecmp
import glob
import json
import os

import numpy as np

from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
from util import config_get_set_num_parallel_workers

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter


def test_imagefolder(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Constructing DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDataset(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(operations=[vision.Decode(True)], input_columns=["image"])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(operations=[rescale_op, resize_op], input_columns=["image"])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as in the previous serialization.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the
    # deserialized pipelines (data2, data3, data4)
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(num_epochs=1),
                                          data2.create_dict_iterator(num_epochs=1),
                                          data3.create_dict_iterator(num_epochs=1),
                                          data4.create_dict_iterator(num_epochs=1)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        np.testing.assert_array_equal(item1['image'], item3['image'])
        np.testing.assert_array_equal(item1['label'], item2['label'])
        np.testing.assert_array_equal(item1['label'], item3['label'])
        np.testing.assert_array_equal(item3['image'], item4['image'])
        np.testing.assert_array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
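    """
    Serialize and deserialize a MnistDataset pipeline (OneHot + batch) and check that the
    original and reconstructed pipelines produce identical data.
    """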
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, num_samples=100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(operations=one_hot_encode, input_columns="label")

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1),
                                   data3.create_dict_iterator(num_epochs=1)):
        np.testing.assert_array_equal(data1['image'], data2['image'])
        np.testing.assert_array_equal(data1['image'], data3['image'])
        np.testing.assert_array_equal(data1['label'], data2['label'])
        np.testing.assert_array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
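    """
    Serialize and deserialize a zipped TFRecordDataset pipeline and check that the data from
    the reconstructed pipeline matches the original column-by-column.
    """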
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0, data3, data4):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            np.testing.assert_array_equal(t1, d3[offset])
            np.testing.assert_array_equal(t1, d3[offset + num_cols])
            np.testing.assert_array_equal(t1, d4[offset])
            np.testing.assert_array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1
    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_random_crop():
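    """
    Serialize a TFRecordDataset pipeline with Decode and RandomCrop into a python dictionary
    and a json string, then rebuild the pipeline from the dictionary and compare its output.
    """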
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(operations=decode_op, input_columns="image")
    data1 = data1.map(operations=random_crop_op, input_columns="image")

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(operations=decode_op, input_columns="image")

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(num_epochs=1),
                                     data1_1.create_dict_iterator(num_epochs=1),
                                     data2.create_dict_iterator(num_epochs=1)):
        np.testing.assert_array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def validate_jsonfile(filepath):
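    """Return True if filepath exists and holds a valid json dictionary, otherwise False."""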
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
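    """Remove the json files generated by these tests from the current working directory."""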
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test save load minddataset
def test_minddataset(add_and_remove_cv_file):
    """Test serialize and deserialize on a CV MindDataset pipeline."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)
    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    test_random_crop()