
test_serdes_dataset.py 9.0 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os

import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore.dataset.transforms.vision import Inter
from mindspore import log as logger
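
# The tests below exercise a serialize/deserialize round trip: build a pipeline,
# serialize it (to a JSON file and/or a Python dict), deserialize it back into a
# new pipeline, and then compare the reconstructed pipeline against the original,
# by re-serializing to JSON and/or by iterating and comparing the rows.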


def test_imagefolder(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Constructing DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDatasetV2(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=[rescale_op, resize_op])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as the previous serialize.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipelines
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                          data3.create_dict_iterator(), data4.create_dict_iterator()):
        assert np.array_equal(item1['image'], item2['image'])
        assert np.array_equal(item1['image'], item3['image'])
        assert np.array_equal(item1['label'], item2['label'])
        assert np.array_equal(item1['label'], item3['label'])
        assert np.array_equal(item3['image'], item4['image'])
        assert np.array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
    """
    Test serializing and deserializing an MnistDataset pipeline with one-hot encoding and batching.
    """
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, 100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(input_columns="label", operations=one_hot_encode)

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                   data3.create_dict_iterator()):
        assert np.array_equal(data1['image'], data2['image'])
        assert np.array_equal(data1['image'], data3['image'])
        assert np.array_equal(data1['label'], data2['label'])
        assert np.array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
    """
    Test serializing and deserializing a zipped TFRecordDataset pipeline.
    """
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0, data3, data4):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            assert np.array_equal(t1, d3[offset])
            assert np.array_equal(t1, d3[offset + num_cols])
            assert np.array_equal(t1, d4[offset])
            assert np.array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1

    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_random_crop():
    """
    Test serializing a pipeline with RandomCrop into a dict and reconstructing it.
    """
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

    # First dataset
    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(input_columns="image", operations=decode_op)
    data1 = data1.map(input_columns="image", operations=random_crop_op)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(input_columns="image", operations=decode_op)

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
                                     data2.create_dict_iterator()):
        assert np.array_equal(item1['image'], item1_1['image'])
        image2 = item2["image"]


def validate_jsonfile(filepath):
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


if __name__ == '__main__':
    test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    test_random_crop()