
test_serdes_dataset.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os
import pytest
import numpy as np

from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
from util import config_get_set_num_parallel_workers, config_get_set_seed

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter
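
# Every test below exercises the same round trip: build a pipeline, serialize it
# (to a JSON file or a Python dict), deserialize it back, and verify that the
# original and the restored pipelines produce identical data. A minimal sketch of
# that pattern, with an illustrative dataset path that is not part of this suite:
#
#     data = ds.MnistDataset("/path/to/mnist", num_samples=100)  # any pipeline
#     ds.serialize(data, "pipeline.json")                        # save pipeline to JSON
#     restored = ds.deserialize(json_filepath="pipeline.json")   # rebuild the pipeline
#     for a, b in zip(data.create_dict_iterator(num_epochs=1, output_numpy=True),
#                     restored.create_dict_iterator(num_epochs=1, output_numpy=True)):
#         np.testing.assert_array_equal(a["image"], b["image"])  # outputs must match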


def test_serdes_imagefolder_dataset(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Constructing DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    data1 = ds.ImageFolderDataset(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(operations=[vision.Decode(True)], input_columns=["image"])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(operations=[rescale_op, resize_op], input_columns=["image"])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as the previous serialization.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0
    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipelines
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data2.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data3.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data4.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        np.testing.assert_array_equal(item1['image'], item3['image'])
        np.testing.assert_array_equal(item1['label'], item2['label'])
        np.testing.assert_array_equal(item1['label'], item3['label'])
        np.testing.assert_array_equal(item3['image'], item4['image'])
        np.testing.assert_array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json file
    if remove_json_files:
        delete_json_files()


def test_serdes_mnist_dataset(remove_json_files=True):
    """
    Test serdes on mnist dataset pipeline.
    """
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, num_samples=100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(operations=one_hot_encode, input_columns="label")

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                   data2.create_dict_iterator(num_epochs=1, output_numpy=True),
                                   data3.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data1['image'], data2['image'])
        np.testing.assert_array_equal(data1['image'], data3['image'])
        np.testing.assert_array_equal(data1['label'], data2['label'])
        np.testing.assert_array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_serdes_zip_dataset(remove_json_files=True):
    """
    Test serdes on zip dataset pipeline.
    """
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0.create_tuple_iterator(output_numpy=True),
                          data3.create_tuple_iterator(output_numpy=True),
                          data4.create_tuple_iterator(output_numpy=True)):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            np.testing.assert_array_equal(t1, d3[offset])
            np.testing.assert_array_equal(t1, d3[offset + num_cols])
            np.testing.assert_array_equal(t1, d4[offset])
            np.testing.assert_array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1
    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_serdes_random_crop():
    """
    Test serdes on RandomCrop pipeline.
    """
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    original_seed = config_get_set_seed(1)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(operations=decode_op, input_columns="image")
    data1 = data1.map(operations=random_crop_op, input_columns="image")

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(operations=decode_op, input_columns="image")

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                     data1_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                     data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]

    # Restore configuration num_parallel_workers
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_serdes_cifar10_dataset(remove_json_files=True):
    """
    Test serdes on Cifar10 dataset pipeline.
    """
    data_dir = "../data/dataset/testCifar10Data"
    original_seed = config_get_set_seed(1)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    data1 = ds.Cifar10Dataset(data_dir, num_samples=10, shuffle=False)
    data1 = data1.take(6)
    trans = [
        vision.RandomCrop((32, 32), (4, 4, 4, 4)),
        vision.Resize((224, 224)),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]

    type_cast_op = c.TypeCast(mstype.int32)
    data1 = data1.map(operations=type_cast_op, input_columns="label")
    data1 = data1.map(operations=trans, input_columns="image")
    data1 = data1.batch(3, drop_remainder=True)
    data1 = data1.repeat(1)
    data2 = util_check_serialize_deserialize_file(data1, "cifar10_dataset_pipeline", remove_json_files)

    num_samples = 0
    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        num_samples += 1

    assert num_samples == 2

    # Restore configuration num_parallel_workers
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_serdes_celeba_dataset(remove_json_files=True):
    """
    Test serdes on Celeba dataset pipeline.
    """
    DATA_DIR = "../data/dataset/testCelebAData/"
    data1 = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
    # define map operations
    data1 = data1.repeat(2)
    center_crop = vision.CenterCrop((80, 80))
    pad_op = vision.Pad(20, fill_value=(20, 20, 20))
    data1 = data1.map(operations=[center_crop, pad_op], input_columns=["image"], num_parallel_workers=8)
    data2 = util_check_serialize_deserialize_file(data1, "celeba_dataset_pipeline", remove_json_files)

    num_samples = 0
    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        num_samples += 1

    assert num_samples == 8


def test_serdes_csv_dataset(remove_json_files=True):
    """
    Test serdes on CSVDataset pipeline.
    """
    DATA_DIR = "../data/dataset/testCSV/1.csv"
    data1 = ds.CSVDataset(
        DATA_DIR,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    columns = ["col1", "col4", "col2"]
    data1 = data1.project(columns=columns)
    data2 = util_check_serialize_deserialize_file(data1, "csv_dataset_pipeline", remove_json_files)

    num_samples = 0
    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['col1'], item2['col1'])
        np.testing.assert_array_equal(item1['col2'], item2['col2'])
        np.testing.assert_array_equal(item1['col4'], item2['col4'])
        num_samples += 1

    assert num_samples == 3


def test_serdes_voc_dataset(remove_json_files=True):
    """
    Test serdes on VOC dataset pipeline.
    """
    data_dir = "../data/dataset/testVOC2012"
    original_seed = config_get_set_seed(1)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # define map operations
    random_color_adjust_op = vision.RandomColorAdjust(brightness=(0.5, 0.5))
    random_rotation_op = vision.RandomRotation((0, 90), expand=True, resample=Inter.BILINEAR, center=(50, 50),
                                               fill_value=150)

    data1 = ds.VOCDataset(data_dir, task="Detection", usage="train", decode=True)
    data1 = data1.map(operations=random_color_adjust_op, input_columns=["image"])
    data1 = data1.map(operations=random_rotation_op, input_columns=["image"])
    data1 = data1.skip(2)
    data2 = util_check_serialize_deserialize_file(data1, "voc_dataset_pipeline", remove_json_files)

    num_samples = 0
    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        num_samples += 1

    assert num_samples == 7

    # Restore configuration num_parallel_workers
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_serdes_to_device(remove_json_files=True):
    """
    Test serdes on a pipeline that ends with a to_device transfer node.
    """
    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
    data1 = data1.to_device()
    util_check_serialize_deserialize_file(data1, "transfer_dataset_pipeline", remove_json_files)


def test_serdes_exception():
    """
    Test exception case in serdes
    """
    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
    data1 = data1.filter(input_columns=["image", "label"], predicate=lambda data: data < 11, num_parallel_workers=4)
    data1_json = ds.serialize(data1)
    with pytest.raises(RuntimeError) as msg:
        ds.deserialize(input_dict=data1_json)
    assert "Filter is not yet supported by ds.engine.deserialize" in str(msg)


def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
    """
    Utility function for testing serdes files. It checks that a json file with the correct name is
    created after serializing, and that its content stays the same after repeated saving and loading.
    :param data_orig: original data pipeline to be serialized
    :param filename: filename to be saved as json format
    :param remove_json_files: whether to remove the json file after testing
    :return: The data pipeline after serializing and deserializing using the original pipeline
    """
    file1 = filename + ".json"
    file2 = filename + "_1.json"
    ds.serialize(data_orig, file1)
    assert validate_jsonfile(file1) is True
    assert validate_jsonfile("wrong_name.json") is False

    data_changed = ds.deserialize(json_filepath=file1)
    ds.serialize(data_changed, file2)
    assert validate_jsonfile(file2) is True
    assert filecmp.cmp(file1, file2)

    # Remove the generated json file
    if remove_json_files:
        delete_json_files()
    return data_changed


def validate_jsonfile(filepath):
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test save load minddataset
def skip_test_minddataset(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)
    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    test_serdes_imagefolder_dataset()
    test_serdes_mnist_dataset()
    test_serdes_cifar10_dataset()
    test_serdes_celeba_dataset()
    test_serdes_csv_dataset()
    test_serdes_voc_dataset()
    test_serdes_zip_dataset()
    test_serdes_random_crop()
    test_serdes_exception()