test_serdes_dataset.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os

import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
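
# The tests below share a common round-trip pattern: build a dataset pipeline,
# serialize it (to a json file and/or a Python dict), deserialize it back,
# re-serialize the reconstructed pipeline, and then check that both serialized
# forms match and that the original and reconstructed pipelines yield the same data.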


def test_imagefolder(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Construct the DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDatasetV2(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=[rescale_op, resize_op])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as the previous serialization.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipelines
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                          data3.create_dict_iterator(), data4.create_dict_iterator()):
        assert np.array_equal(item1['image'], item2['image'])
        assert np.array_equal(item1['image'], item3['image'])
        assert np.array_equal(item1['label'], item2['label'])
        assert np.array_equal(item1['label'], item3['label'])
        assert np.array_equal(item3['image'], item4['image'])
        assert np.array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
    """
    Test serialization and deserialization of an MnistDataset pipeline with OneHot and batch ops.
    """
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, 100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(input_columns="label", operations=one_hot_encode)

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                   data3.create_dict_iterator()):
        assert np.array_equal(data1['image'], data2['image'])
        assert np.array_equal(data1['image'], data3['image'])
        assert np.array_equal(data1['label'], data2['label'])
        assert np.array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
    """
    Test serialization and deserialization of a zipped pipeline built from TFRecordDatasets.
    """
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0, data3, data4):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            assert np.array_equal(t1, d3[offset])
            assert np.array_equal(t1, d3[offset + num_cols])
            assert np.array_equal(t1, d4[offset])
            assert np.array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1
    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_random_crop():
    """
    Test serializing a pipeline containing a RandomCrop op into a Python dict and reconstructing it.
    """
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(input_columns="image", operations=decode_op)
    data1 = data1.map(input_columns="image", operations=random_crop_op)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(input_columns="image", operations=decode_op)

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
                                     data2.create_dict_iterator()):
        assert np.array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]


def validate_jsonfile(filepath):
    """
    Return True if filepath exists and contains valid JSON that loads into a dict.
    """
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    """
    Remove the json files generated by these tests from the current directory.
    """
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test save and load of a MindDataset pipeline
def test_minddataset(add_and_remove_cv_file):
    """Tutorial for serializing and deserializing a CV MindDataset pipeline."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)

    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator():
        num_iter += 1
    assert num_iter == 5
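

# Note: test_minddataset is not called from __main__ below; it takes the
# add_and_remove_cv_file pytest fixture as an argument, so it is presumably
# meant to be run through pytest rather than as a standalone script.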
if __name__ == '__main__':
    test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    test_random_crop()