
test_serdes_dataset.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os
import pytest
import numpy as np
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
from util import config_get_set_num_parallel_workers, config_get_set_seed
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.py_transforms as py
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter

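# Each test below follows the same round trip: build a pipeline, serialize it with
# ds.serialize (to a JSON file and/or a Python dict), rebuild it with ds.deserialize,
# serialize the rebuilt pipeline again, and check that both the JSON output and the
# data produced by the original and rebuilt pipelines match.
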
def test_serdes_imagefolder_dataset(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Constructing DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    data1 = ds.ImageFolderDataset(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(operations=[vision.Decode(True)], input_columns=["image"])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(operations=[rescale_op, resize_op], input_columns=["image"])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as the previous serialization.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipelines (data2, data3, data4)
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data2.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data3.create_dict_iterator(num_epochs=1, output_numpy=True),
                                          data4.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        np.testing.assert_array_equal(item1['image'], item3['image'])
        np.testing.assert_array_equal(item1['label'], item2['label'])
        np.testing.assert_array_equal(item1['label'], item3['label'])
        np.testing.assert_array_equal(item3['image'], item4['image'])
        np.testing.assert_array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json file
    if remove_json_files:
        delete_json_files()

def test_serdes_mnist_dataset(remove_json_files=True):
    """
    Test serdes on mnist dataset pipeline.
    """
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, num_samples=100)
    one_hot_encode = c.OneHot(10)  # num_classes is input argument
    data1 = data1.map(operations=one_hot_encode, input_columns="label")

    # batch_size is input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                   data2.create_dict_iterator(num_epochs=1, output_numpy=True),
                                   data3.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data1['image'], data2['image'])
        np.testing.assert_array_equal(data1['image'], data3['image'])
        np.testing.assert_array_equal(data1['label'], data2['label'])
        np.testing.assert_array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_serdes_cifar10_dataset(remove_json_files=True):
    """
    Test serdes on Cifar10 dataset pipeline
    """
    data_dir = "../data/dataset/testCifar10Data"
    original_seed = config_get_set_seed(1)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    data1 = ds.Cifar10Dataset(data_dir, num_samples=10, shuffle=False)
    data1 = data1.take(6)

    trans = [
        vision.RandomCrop((32, 32), (4, 4, 4, 4)),
        vision.Resize((224, 224)),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]

    type_cast_op = c.TypeCast(mstype.int32)
    data1 = data1.map(operations=type_cast_op, input_columns="label")
    data1 = data1.map(operations=trans, input_columns="image")
    data1 = data1.batch(3, drop_remainder=True)
    data1 = data1.repeat(1)
    data2 = util_check_serialize_deserialize_file(data1, "cifar10_dataset_pipeline", remove_json_files)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        num_samples += 1

    assert num_samples == 2

    # Restore configuration num_parallel_workers
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)

def test_serdes_celeba_dataset(remove_json_files=True):
    """
    Test serdes on CelebA dataset pipeline.
    """
    DATA_DIR = "../data/dataset/testCelebAData/"
    data1 = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
    # define map operations
    data1 = data1.repeat(2)
    center_crop = vision.CenterCrop((80, 80))
    pad_op = vision.Pad(20, fill_value=(20, 20, 20))
    data1 = data1.map(operations=[center_crop, pad_op], input_columns=["image"], num_parallel_workers=8)
    data2 = util_check_serialize_deserialize_file(data1, "celeba_dataset_pipeline", remove_json_files)

    num_samples = 0
    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        num_samples += 1

    assert num_samples == 8


def test_serdes_csv_dataset(remove_json_files=True):
    """
    Test serdes on CSVDataset pipeline.
    """
    DATA_DIR = "../data/dataset/testCSV/1.csv"
    data1 = ds.CSVDataset(
        DATA_DIR,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    columns = ["col1", "col4", "col2"]
    data1 = data1.project(columns=columns)
    data2 = util_check_serialize_deserialize_file(data1, "csv_dataset_pipeline", remove_json_files)

    num_samples = 0
    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['col1'], item2['col1'])
        np.testing.assert_array_equal(item1['col2'], item2['col2'])
        np.testing.assert_array_equal(item1['col4'], item2['col4'])
        num_samples += 1

    assert num_samples == 3

def test_serdes_voc_dataset(remove_json_files=True):
    """
    Test serdes on VOC dataset pipeline.
    """
    data_dir = "../data/dataset/testVOC2012"
    original_seed = config_get_set_seed(1)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # define map operations
    random_color_adjust_op = vision.RandomColorAdjust(brightness=(0.5, 0.5))
    random_rotation_op = vision.RandomRotation((0, 90), expand=True, resample=Inter.BILINEAR, center=(50, 50),
                                                fill_value=150)

    data1 = ds.VOCDataset(data_dir, task="Detection", usage="train", decode=True)
    data1 = data1.map(operations=random_color_adjust_op, input_columns=["image"])
    data1 = data1.map(operations=random_rotation_op, input_columns=["image"])
    data1 = data1.skip(2)
    data2 = util_check_serialize_deserialize_file(data1, "voc_dataset_pipeline", remove_json_files)

    num_samples = 0
    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item2['image'])
        num_samples += 1

    assert num_samples == 7

    # Restore configuration num_parallel_workers
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_serdes_zip_dataset(remove_json_files=True):
    """
    Test serdes on zip dataset pipeline.
    """
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0.create_tuple_iterator(output_numpy=True), data3.create_tuple_iterator(output_numpy=True),
                          data4.create_tuple_iterator(output_numpy=True)):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            np.testing.assert_array_equal(t1, d3[offset])
            np.testing.assert_array_equal(t1, d3[offset + num_cols])
            np.testing.assert_array_equal(t1, d4[offset])
            np.testing.assert_array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1

    assert rows == 12

    if remove_json_files:
        delete_json_files()

def test_serdes_random_crop():
    """
    Test serdes on RandomCrop pipeline.
    """
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    original_seed = config_get_set_seed(1)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(operations=decode_op, input_columns="image")
    data1 = data1.map(operations=random_crop_op, input_columns="image")

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(operations=decode_op, input_columns="image")

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                     data1_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                     data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]

    # Restore configuration num_parallel_workers
    ds.config.set_seed(original_seed)
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_serdes_to_device(remove_json_files=True):
    """
    Test serdes on transfer dataset pipeline.
    """
    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
    data1 = data1.to_device()
    util_check_serialize_deserialize_file(data1, "transfer_dataset_pipeline", remove_json_files)


def test_serdes_pyvision(remove_json_files=True):
    """
    Test serdes on py_transform pipeline.
    """
    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([32, 32]),
        py_vision.ToTensor()
    ]
    data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])

    # Serialization of Python functions (via pickle) is not yet supported, so this
    # testcase expects serialization to fail with NotImplementedError.
    try:
        util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
        assert False
    except NotImplementedError as e:
        assert "python function is not yet supported" in str(e)

def test_serdes_uniform_augment(remove_json_files=True):
    """
    Test serdes on uniform augment.
    """
    data_dir = "../data/dataset/testPK/data"
    data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
    ds.config.set_seed(1)

    transforms_ua = [vision.RandomHorizontalFlip(),
                     vision.RandomVerticalFlip(),
                     vision.RandomColor(),
                     vision.RandomSharpness(),
                     vision.Invert(),
                     vision.AutoContrast(),
                     vision.Equalize()]
    transforms_all = [vision.Decode(), vision.Resize(size=[224, 224]),
                      vision.UniformAugment(transforms=transforms_ua, num_ops=5)]
    data = data.map(operations=transforms_all, input_columns="image", num_parallel_workers=1)
    util_check_serialize_deserialize_file(data, "uniform_augment_pipeline", remove_json_files)


def skip_test_serdes_fill(remove_json_files=True):
    """
    Test serdes on Fill data transform.
    """
    def gen():
        yield (np.array([4, 5, 6, 7], dtype=np.int32),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    fill_op = c.Fill(3)

    data = data.map(operations=fill_op, input_columns=["col"])
    expected = np.array([3, 3, 3, 3], dtype=np.int32)
    for data_row in data:
        np.testing.assert_array_equal(data_row[0].asnumpy(), expected)

    # FIXME - need proper serdes support for Fill's fill_value parameter
    util_check_serialize_deserialize_file(data, "fill_pipeline", remove_json_files)


def test_serdes_exception():
    """
    Test exception case in serdes
    """
    data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
    data1 = data1.filter(input_columns=["image", "label"], predicate=lambda data: data < 11, num_parallel_workers=4)
    data1_json = ds.serialize(data1)
    with pytest.raises(RuntimeError) as msg:
        ds.deserialize(input_dict=data1_json)
    assert "Filter is not yet supported by ds.engine.deserialize" in str(msg)

def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
    """
    Utility function for testing serdes files. It checks that a json file with the correct name
    is created after serializing, and that the file stays identical after repeated saving and loading.
    :param data_orig: original data pipeline to be serialized
    :param filename: filename to be saved as json format
    :param remove_json_files: whether to remove the json file after testing
    :return: The data pipeline after serializing and deserializing using the original pipeline
    """
    file1 = filename + ".json"
    file2 = filename + "_1.json"
    ds.serialize(data_orig, file1)
    assert validate_jsonfile(file1) is True
    assert validate_jsonfile("wrong_name.json") is False

    data_changed = ds.deserialize(json_filepath=file1)
    ds.serialize(data_changed, file2)
    assert validate_jsonfile(file2) is True
    assert filecmp.cmp(file1, file2)

    # Remove the generated json file
    if remove_json_files:
        delete_json_files()
    return data_changed


def validate_jsonfile(filepath):
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))

# Test save load minddataset
def skip_test_minddataset(add_and_remove_cv_file=True):
    """Test serdes on a cv minddataset pipeline."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)

    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    test_serdes_imagefolder_dataset()
    test_serdes_mnist_dataset()
    test_serdes_cifar10_dataset()
    test_serdes_celeba_dataset()
    test_serdes_csv_dataset()
    test_serdes_voc_dataset()
    test_serdes_zip_dataset()
    test_serdes_random_crop()
    test_serdes_to_device()
    test_serdes_pyvision()
    test_serdes_uniform_augment()
    skip_test_serdes_fill()
    test_serdes_exception()
    skip_test_minddataset()