
test_serdes_dataset.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os
import numpy as np
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
from util import config_get_set_num_parallel_workers
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter


def test_imagefolder(remove_json_files=True):
    """
    Test serialize and deserialize on a pipeline that simulates the resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Constructing DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDataset(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(operations=[vision.Decode(True)], input_columns=["image"])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(operations=[rescale_op, resize_op], input_columns=["image"])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as the previously serialized file.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against
    # the deserialized pipelines (data2, data3 and data4)
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data2.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data3.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data4.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        np.testing.assert_array_equal(item1['image'], item3['image'])
        np.testing.assert_array_equal(item1['label'], item2['label'])
        np.testing.assert_array_equal(item1['label'], item3['label'])
        np.testing.assert_array_equal(item3['image'], item4['image'])
        np.testing.assert_array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
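    """
    Test serialize and deserialize on a MnistDataset pipeline (OneHot and batch ops):
    serialize to JSON, deserialize, re-serialize, and check that the JSON files and
    the produced data match.
    """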
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, num_samples=100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(operations=one_hot_encode, input_columns="label")

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                   data2.create_dict_iterator(num_epochs=1, output_numpy=True),
                                   data3.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data1['image'], data2['image'])
        np.testing.assert_array_equal(data1['image'], data3['image'])
        np.testing.assert_array_equal(data1['label'], data2['label'])
        np.testing.assert_array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
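    """
    Test serialize and deserialize on a zip of two TFRecordDataset pipelines (with shuffle
    and rename ops), and check that the deserialized pipeline produces the same rows.
    """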
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0.create_tuple_iterator(output_numpy=True),
                          data3.create_tuple_iterator(output_numpy=True),
                          data4.create_tuple_iterator(output_numpy=True)):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            np.testing.assert_array_equal(t1, d3[offset])
            np.testing.assert_array_equal(t1, d3[offset + num_cols])
            np.testing.assert_array_equal(t1, d4[offset])
            np.testing.assert_array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1
    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_random_crop():
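    """
    Test serialize and deserialize through a Python dictionary on a pipeline with
    Decode and RandomCrop ops, and check that the reconstructed pipeline matches the original.
    """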
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(operations=decode_op, input_columns="image")
    data1 = data1.map(operations=random_crop_op, input_columns="image")

    # Serializing into Python dictionary
    ds1_dict = ds.serialize(data1)

    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(operations=decode_op, input_columns="image")

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                     data1_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                     data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def validate_jsonfile(filepath):
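    """Return True if filepath exists and contains a valid JSON object (dict)."""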
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
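    """Remove the .json files that the tests above generate in the current directory."""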
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test save/load (serialize and deserialize) on MindDataset
def test_minddataset(add_and_remove_cv_file):
    """Test serialize and deserialize on a CV MindDataset pipeline with a SubsetRandomSampler."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into Python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)
    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    test_random_crop()