You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_serdes_dataset.py 23 kB

5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing dataset serialize and deserialize in DE
  17. """
  18. import filecmp
  19. import glob
  20. import json
  21. import os
  22. import pytest
  23. import numpy as np
  24. from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
  25. from util import config_get_set_num_parallel_workers, config_get_set_seed
  26. import mindspore.common.dtype as mstype
  27. import mindspore.dataset as ds
  28. import mindspore.dataset.transforms.c_transforms as c
  29. import mindspore.dataset.transforms.py_transforms as py
  30. import mindspore.dataset.vision.c_transforms as vision
  31. import mindspore.dataset.vision.py_transforms as py_vision
  32. from mindspore import log as logger
  33. from mindspore.dataset.vision import Inter
  34. def test_serdes_imagefolder_dataset(remove_json_files=True):
  35. """
  36. Test simulating resnet50 dataset pipeline.
  37. """
  38. data_dir = "../data/dataset/testPK/data"
  39. ds.config.set_seed(1)
  40. # define data augmentation parameters
  41. rescale = 1.0 / 255.0
  42. shift = 0.0
  43. resize_height, resize_width = 224, 224
  44. weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
  45. # Constructing DE pipeline
  46. sampler = ds.WeightedRandomSampler(weights, 11)
  47. child_sampler = ds.SequentialSampler()
  48. sampler.add_child(child_sampler)
  49. data1 = ds.ImageFolderDataset(data_dir, sampler=sampler)
  50. data1 = data1.repeat(1)
  51. data1 = data1.map(operations=[vision.Decode(True)], input_columns=["image"])
  52. rescale_op = vision.Rescale(rescale, shift)
  53. resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
  54. data1 = data1.map(operations=[rescale_op, resize_op], input_columns=["image"])
  55. data1_1 = ds.TFRecordDataset(["../data/dataset/testTFTestAllTypes/test.data"], num_samples=6).batch(2).repeat(10)
  56. data1 = data1.zip(data1_1)
  57. # Serialize the dataset pre-processing pipeline.
  58. # data1 should still work after saving.
  59. ds.serialize(data1, "imagenet_dataset_pipeline.json")
  60. ds1_dict = ds.serialize(data1)
  61. assert validate_jsonfile("imagenet_dataset_pipeline.json") is True
  62. # Print the serialized pipeline to stdout
  63. ds.show(data1)
  64. # Deserialize the serialized json file
  65. data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")
  66. # Serialize the pipeline we just deserialized.
  67. # The content of the json file should be the same to the previous serialize.
  68. ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
  69. assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
  70. assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')
  71. assert data1.get_dataset_size() == data2.get_dataset_size()
  72. # Deserialize the latest json file again
  73. data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
  74. data4 = ds.deserialize(input_dict=ds1_dict)
  75. num_samples = 0
  76. # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
  77. for item1, item2, item3, item4 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
  78. data2.create_dict_iterator(num_epochs=1, output_numpy=True),
  79. data3.create_dict_iterator(num_epochs=1, output_numpy=True),
  80. data4.create_dict_iterator(num_epochs=1, output_numpy=True)):
  81. np.testing.assert_array_equal(item1['image'], item2['image'])
  82. np.testing.assert_array_equal(item1['image'], item3['image'])
  83. np.testing.assert_array_equal(item1['label'], item2['label'])
  84. np.testing.assert_array_equal(item1['label'], item3['label'])
  85. np.testing.assert_array_equal(item3['image'], item4['image'])
  86. np.testing.assert_array_equal(item3['label'], item4['label'])
  87. num_samples += 1
  88. logger.info("Number of data in data1: {}".format(num_samples))
  89. assert num_samples == 11
  90. # Remove the generated json file
  91. if remove_json_files:
  92. delete_json_files()
  93. def test_serdes_mnist_dataset(remove_json_files=True):
  94. """
  95. Test serdes on mnist dataset pipeline.
  96. """
  97. data_dir = "../data/dataset/testMnistData"
  98. ds.config.set_seed(1)
  99. data1 = ds.MnistDataset(data_dir, num_samples=100)
  100. one_hot_encode = c.OneHot(10) # num_classes is input argument
  101. data1 = data1.map(operations=one_hot_encode, input_columns="label")
  102. # batch_size is input argument
  103. data1 = data1.batch(batch_size=10, drop_remainder=True)
  104. ds.serialize(data1, "mnist_dataset_pipeline.json")
  105. assert validate_jsonfile("mnist_dataset_pipeline.json") is True
  106. data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
  107. ds.serialize(data2, "mnist_dataset_pipeline_1.json")
  108. assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
  109. assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')
  110. data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")
  111. num = 0
  112. for data1, data2, data3 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
  113. data2.create_dict_iterator(num_epochs=1, output_numpy=True),
  114. data3.create_dict_iterator(num_epochs=1, output_numpy=True)):
  115. np.testing.assert_array_equal(data1['image'], data2['image'])
  116. np.testing.assert_array_equal(data1['image'], data3['image'])
  117. np.testing.assert_array_equal(data1['label'], data2['label'])
  118. np.testing.assert_array_equal(data1['label'], data3['label'])
  119. num += 1
  120. logger.info("mnist total num samples is {}".format(str(num)))
  121. assert num == 10
  122. if remove_json_files:
  123. delete_json_files()
  124. def test_serdes_cifar10_dataset(remove_json_files=True):
  125. """
  126. Test serdes on Cifar10 dataset pipeline
  127. """
  128. data_dir = "../data/dataset/testCifar10Data"
  129. original_seed = config_get_set_seed(1)
  130. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  131. data1 = ds.Cifar10Dataset(data_dir, num_samples=10, shuffle=False)
  132. data1 = data1.take(6)
  133. trans = [
  134. vision.RandomCrop((32, 32), (4, 4, 4, 4)),
  135. vision.Resize((224, 224)),
  136. vision.Rescale(1.0 / 255.0, 0.0),
  137. vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
  138. vision.HWC2CHW()
  139. ]
  140. type_cast_op = c.TypeCast(mstype.int32)
  141. data1 = data1.map(operations=type_cast_op, input_columns="label")
  142. data1 = data1.map(operations=trans, input_columns="image")
  143. data1 = data1.batch(3, drop_remainder=True)
  144. data1 = data1.repeat(1)
  145. # json files are needed for create iterator, remove_json_files = False
  146. data2 = util_check_serialize_deserialize_file(data1, "cifar10_dataset_pipeline", False)
  147. num_samples = 0
  148. # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
  149. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
  150. data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  151. np.testing.assert_array_equal(item1['image'], item2['image'])
  152. num_samples += 1
  153. assert num_samples == 2
  154. # Restore configuration num_parallel_workers
  155. ds.config.set_seed(original_seed)
  156. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  157. if remove_json_files:
  158. delete_json_files()
  159. def test_serdes_celeba_dataset(remove_json_files=True):
  160. """
  161. Test serdes on Celeba dataset pipeline.
  162. """
  163. DATA_DIR = "../data/dataset/testCelebAData/"
  164. data1 = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
  165. # define map operations
  166. data1 = data1.repeat(2)
  167. center_crop = vision.CenterCrop((80, 80))
  168. pad_op = vision.Pad(20, fill_value=(20, 20, 20))
  169. data1 = data1.map(operations=[center_crop, pad_op], input_columns=["image"], num_parallel_workers=8)
  170. # json files are needed for create iterator, remove_json_files = False
  171. data2 = util_check_serialize_deserialize_file(data1, "celeba_dataset_pipeline", False)
  172. num_samples = 0
  173. # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
  174. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
  175. data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  176. np.testing.assert_array_equal(item1['image'], item2['image'])
  177. num_samples += 1
  178. assert num_samples == 8
  179. if remove_json_files:
  180. delete_json_files()
  181. def test_serdes_csv_dataset(remove_json_files=True):
  182. """
  183. Test serdes on Csvdataset pipeline.
  184. """
  185. DATA_DIR = "../data/dataset/testCSV/1.csv"
  186. data1 = ds.CSVDataset(
  187. DATA_DIR,
  188. column_defaults=["1", "2", "3", "4"],
  189. column_names=['col1', 'col2', 'col3', 'col4'],
  190. shuffle=False)
  191. columns = ["col1", "col4", "col2"]
  192. data1 = data1.project(columns=columns)
  193. # json files are needed for create iterator, remove_json_files = False
  194. data2 = util_check_serialize_deserialize_file(data1, "csv_dataset_pipeline", False)
  195. num_samples = 0
  196. # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
  197. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
  198. data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  199. np.testing.assert_array_equal(item1['col1'], item2['col1'])
  200. np.testing.assert_array_equal(item1['col2'], item2['col2'])
  201. np.testing.assert_array_equal(item1['col4'], item2['col4'])
  202. num_samples += 1
  203. assert num_samples == 3
  204. if remove_json_files:
  205. delete_json_files()
  206. def test_serdes_voc_dataset(remove_json_files=True):
  207. """
  208. Test serdes on VOC dataset pipeline.
  209. """
  210. data_dir = "../data/dataset/testVOC2012"
  211. original_seed = config_get_set_seed(1)
  212. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  213. # define map operations
  214. random_color_adjust_op = vision.RandomColorAdjust(brightness=(0.5, 0.5))
  215. random_rotation_op = vision.RandomRotation((0, 90), expand=True, resample=Inter.BILINEAR, center=(50, 50),
  216. fill_value=150)
  217. data1 = ds.VOCDataset(data_dir, task="Detection", usage="train", decode=True)
  218. data1 = data1.map(operations=random_color_adjust_op, input_columns=["image"])
  219. data1 = data1.map(operations=random_rotation_op, input_columns=["image"])
  220. data1 = data1.skip(2)
  221. # json files are needed for create iterator, remove_json_files = False
  222. data2 = util_check_serialize_deserialize_file(data1, "voc_dataset_pipeline", False)
  223. num_samples = 0
  224. # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
  225. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
  226. data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  227. np.testing.assert_array_equal(item1['image'], item2['image'])
  228. num_samples += 1
  229. assert num_samples == 7
  230. # Restore configuration num_parallel_workers
  231. ds.config.set_seed(original_seed)
  232. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  233. if remove_json_files:
  234. delete_json_files()
  235. def test_serdes_zip_dataset(remove_json_files=True):
  236. """
  237. Test serdes on zip dataset pipeline.
  238. """
  239. files = ["../data/dataset/testTFTestAllTypes/test.data"]
  240. schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
  241. ds.config.set_seed(1)
  242. ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
  243. data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
  244. data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
  245. data2 = data2.shuffle(10000)
  246. data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
  247. "col_1d", "col_2d", "col_3d", "col_binary"],
  248. output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
  249. "column_1d", "column_2d", "column_3d", "column_binary"])
  250. data3 = ds.zip((data1, data2))
  251. ds.serialize(data3, "zip_dataset_pipeline.json")
  252. assert validate_jsonfile("zip_dataset_pipeline.json") is True
  253. assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False
  254. data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
  255. ds.serialize(data4, "zip_dataset_pipeline_1.json")
  256. assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
  257. assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')
  258. rows = 0
  259. for d0, d3, d4 in zip(ds0.create_tuple_iterator(output_numpy=True), data3.create_tuple_iterator(output_numpy=True),
  260. data4.create_tuple_iterator(output_numpy=True)):
  261. num_cols = len(d0)
  262. offset = 0
  263. for t1 in d0:
  264. np.testing.assert_array_equal(t1, d3[offset])
  265. np.testing.assert_array_equal(t1, d3[offset + num_cols])
  266. np.testing.assert_array_equal(t1, d4[offset])
  267. np.testing.assert_array_equal(t1, d4[offset + num_cols])
  268. offset += 1
  269. rows += 1
  270. assert rows == 12
  271. if remove_json_files:
  272. delete_json_files()
  273. def test_serdes_random_crop():
  274. """
  275. Test serdes on RandomCrop pipeline.
  276. """
  277. logger.info("test_random_crop")
  278. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  279. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  280. original_seed = config_get_set_seed(1)
  281. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  282. # First dataset
  283. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
  284. decode_op = vision.Decode()
  285. random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
  286. data1 = data1.map(operations=decode_op, input_columns="image")
  287. data1 = data1.map(operations=random_crop_op, input_columns="image")
  288. # Serializing into python dictionary
  289. ds1_dict = ds.serialize(data1)
  290. # Serializing into json object
  291. _ = json.dumps(ds1_dict, indent=2)
  292. # Reconstruct dataset pipeline from its serialized form
  293. data1_1 = ds.deserialize(input_dict=ds1_dict)
  294. # Second dataset
  295. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
  296. data2 = data2.map(operations=decode_op, input_columns="image")
  297. for item1, item1_1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
  298. data1_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  299. data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  300. np.testing.assert_array_equal(item1['image'], item1_1['image'])
  301. _ = item2["image"]
  302. # Restore configuration num_parallel_workers
  303. ds.config.set_seed(original_seed)
  304. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  305. def test_serdes_to_device(remove_json_files=True):
  306. """
  307. Test serdes on transfer dataset pipeline.
  308. """
  309. data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  310. schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  311. data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
  312. data1 = data1.to_device()
  313. util_check_serialize_deserialize_file(data1, "transfer_dataset_pipeline", remove_json_files)
  314. def test_serdes_pyvision(remove_json_files=True):
  315. """
  316. Test serdes on py_transform pipeline.
  317. """
  318. data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  319. schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  320. data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
  321. transforms1 = [
  322. py_vision.Decode(),
  323. py_vision.CenterCrop([32, 32]),
  324. py_vision.ToTensor()
  325. ]
  326. transforms2 = [
  327. py_vision.RandomColorAdjust(),
  328. py_vision.FiveCrop(1),
  329. py_vision.Grayscale(),
  330. py.OneHotOp(1)
  331. ]
  332. data1 = data1.map(operations=py.Compose(transforms1), input_columns=["image"])
  333. data1 = data1.map(operations=py.RandomApply(transforms2), input_columns=["image"])
  334. util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
  335. def test_serdes_uniform_augment(remove_json_files=True):
  336. """
  337. Test serdes on uniform augment.
  338. """
  339. data_dir = "../data/dataset/testPK/data"
  340. data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
  341. ds.config.set_seed(1)
  342. transforms_ua = [vision.RandomHorizontalFlip(),
  343. vision.RandomVerticalFlip(),
  344. vision.RandomColor(),
  345. vision.RandomSharpness(),
  346. vision.Invert(),
  347. vision.AutoContrast(),
  348. vision.Equalize()]
  349. transforms_all = [vision.Decode(), vision.Resize(size=[224, 224]),
  350. vision.UniformAugment(transforms=transforms_ua, num_ops=5)]
  351. data = data.map(operations=transforms_all, input_columns="image", num_parallel_workers=1)
  352. util_check_serialize_deserialize_file(data, "uniform_augment_pipeline", remove_json_files)
  353. def skip_test_serdes_fill(remove_json_files=True):
  354. """
  355. Test serdes on Fill data transform.
  356. """
  357. def gen():
  358. yield (np.array([4, 5, 6, 7], dtype=np.int32),)
  359. data = ds.GeneratorDataset(gen, column_names=["col"])
  360. fill_op = c.Fill(3)
  361. data = data.map(operations=fill_op, input_columns=["col"])
  362. expected = np.array([3, 3, 3, 3], dtype=np.int32)
  363. for data_row in data:
  364. np.testing.assert_array_equal(data_row[0].asnumpy(), expected)
  365. util_check_serialize_deserialize_file(data, "fill_pipeline", remove_json_files)
  366. def test_serdes_exception():
  367. """
  368. Test exception case in serdes
  369. """
  370. data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  371. schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  372. data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
  373. data1 = data1.filter(input_columns=["image", "label"], predicate=lambda data: data < 11, num_parallel_workers=4)
  374. data1_json = ds.serialize(data1)
  375. with pytest.raises(RuntimeError) as msg:
  376. data2 = ds.deserialize(input_dict=data1_json)
  377. ds.serialize(data2, "filter_dataset_fail.json")
  378. assert "Filter operation is not supported" in str(msg)
  379. delete_json_files()
  380. def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files):
  381. """
  382. Utility function for testing serdes files. It is to check if a json file is indeed created with correct name
  383. after serializing and if it remains the same after repeatedly saving and loading.
  384. :param data_orig: original data pipeline to be serialized
  385. :param filename: filename to be saved as json format
  386. :param remove_json_files: whether to remove the json file after testing
  387. :return: The data pipeline after serializing and deserializing using the original pipeline
  388. """
  389. file1 = filename + ".json"
  390. file2 = filename + "_1.json"
  391. ds.serialize(data_orig, file1)
  392. assert validate_jsonfile(file1) is True
  393. assert validate_jsonfile("wrong_name.json") is False
  394. data_changed = ds.deserialize(json_filepath=file1)
  395. ds.serialize(data_changed, file2)
  396. assert validate_jsonfile(file2) is True
  397. assert filecmp.cmp(file1, file2, shallow=False)
  398. # Remove the generated json file
  399. if remove_json_files:
  400. delete_json_files()
  401. return data_changed
  402. def validate_jsonfile(filepath):
  403. try:
  404. file_exist = os.path.exists(filepath)
  405. with open(filepath, 'r') as jfile:
  406. loaded_json = json.load(jfile)
  407. except IOError:
  408. return False
  409. return file_exist and isinstance(loaded_json, dict)
  410. def delete_json_files():
  411. file_list = glob.glob('*.json')
  412. for f in file_list:
  413. try:
  414. os.remove(f)
  415. except IOError:
  416. logger.info("Error while deleting: {}".format(f))
  417. # Test save load minddataset
  418. def skip_test_minddataset(add_and_remove_cv_file=True):
  419. """tutorial for cv minderdataset."""
  420. columns_list = ["data", "file_name", "label"]
  421. num_readers = 4
  422. indices = [1, 2, 3, 5, 7]
  423. sampler = ds.SubsetRandomSampler(indices)
  424. data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
  425. sampler=sampler)
  426. # Serializing into python dictionary
  427. ds1_dict = ds.serialize(data_set)
  428. # Serializing into json object
  429. ds1_json = json.dumps(ds1_dict, sort_keys=True)
  430. # Reconstruct dataset pipeline from its serialized form
  431. data_set = ds.deserialize(input_dict=ds1_dict)
  432. ds2_dict = ds.serialize(data_set)
  433. # Serializing into json object
  434. ds2_json = json.dumps(ds2_dict, sort_keys=True)
  435. assert ds1_json == ds2_json
  436. _ = get_data(CV_DIR_NAME)
  437. assert data_set.get_dataset_size() == 5
  438. num_iter = 0
  439. for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
  440. num_iter += 1
  441. assert num_iter == 5
  442. if __name__ == '__main__':
  443. test_serdes_imagefolder_dataset()
  444. test_serdes_mnist_dataset()
  445. test_serdes_cifar10_dataset()
  446. test_serdes_celeba_dataset()
  447. test_serdes_csv_dataset()
  448. test_serdes_voc_dataset()
  449. test_serdes_zip_dataset()
  450. test_serdes_random_crop()
  451. test_serdes_to_device()
  452. test_serdes_pyvision()
  453. test_serdes_uniform_augment()
  454. skip_test_serdes_fill()
  455. test_serdes_exception()
  456. skip_test_minddataset()