
test_save_op.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for saveOp.
"""
import os
from string import punctuation

import numpy as np
import pytest

import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter

TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
FILES_NUM = 1
num_readers = 1


def remove_file(file_name):
    """Remove a mindrecord file and its index (.db) file if they exist."""
    if os.path.exists("{}".format(file_name)):
        os.remove("{}".format(file_name))
    if os.path.exists("{}.db".format(file_name)):
        os.remove("{}.db".format(file_name))
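

# The save-op cases below follow a common round trip: build a dataset (from a
# FileWriter-produced mindrecord file, a Python generator, or TFRecord files),
# re-save it with Dataset.save(), reload the result with MindDataset, and
# compare it row by row against the expected values. A minimal sketch of the
# mindrecord round trip (file names here are illustrative only):
#
#     writer = FileWriter("sample.mindrecord", FILES_NUM)
#     writer.add_schema({"label": {"type": "int32"}}, "schema")
#     writer.write_raw_data([{"label": 1}])
#     writer.commit()
#     d1 = ds.MindDataset("sample.mindrecord", None, num_readers, shuffle=False)
#     d1.save("sample.mindrecord_auto", FILES_NUM)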
def test_case_00():
    """
    Feature: save op
    Description: all bin data
    Expectation: generated mindrecord file
    """
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data = [{"image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8')}]
    schema = {
        "image1": {"type": "bytes"},
        "image2": {"type": "bytes"},
        "image3": {"type": "bytes"},
        "image4": {"type": "bytes"},
        "image5": {"type": "bytes"}}
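    # Note: MindDataset returns "bytes" fields as 1-D uint8 arrays, which is
    # why the expected values below are rebuilt with
    # np.asarray(list(...), dtype=np.uint8).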
    writer = FileWriter(file_name, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    d1 = ds.MindDataset(file_name, None, num_readers, shuffle=False)
    d1.save(file_name_auto, FILES_NUM)
    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)
    d2 = ds.MindDataset(dataset_file=file_name_auto,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 5
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 5
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 5
    remove_file(file_name)
    remove_file(file_name_auto)


def test_case_01():  # only raw data
    """
    Feature: save op
    Description: all raw data
    Expectation: generated mindrecord file
    """
    file_name_auto = './'
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data = [{"file_name": "001.jpg", "label": 43},
            {"file_name": "002.jpg", "label": 91},
            {"file_name": "003.jpg", "label": 61},
            {"file_name": "004.jpg", "label": 29},
            {"file_name": "005.jpg", "label": 78},
            {"file_name": "006.jpg", "label": 37}]
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"}
              }
    writer = FileWriter(file_name, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    d1 = ds.MindDataset(file_name, None, num_readers, shuffle=False)
    d1.save(file_name_auto, FILES_NUM)
    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
        data_value_to_list.append(new_data)
    d2 = ds.MindDataset(dataset_file=file_name_auto,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(item)
        assert len(item) == 2
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6
    remove_file(file_name)
    remove_file(file_name_auto)


def test_case_02():  # multi-bytes
    """
    Feature: save op
    Description: multiple byte fields
    Expectation: generated mindrecord file
    """
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data = [{"file_name": "001.jpg", "label": 43,
             "float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12345,
             "float64": 1987654321.123456785,
             "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91,
             "float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12445,
             "float64": 1987654321.123456786,
             "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61,
             "float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12545,
             "float64": 1987654321.123456787,
             "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29,
             "float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12645,
             "float64": 1987654321.123456788,
             "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image4 bytes abc", encoding='UTF-8'),
             "image2": bytes("image4 bytes def", encoding='UTF-8'),
             "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image4 bytes mno", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78,
             "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12745,
             "float64": 1987654321.123456789,
             "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37,
             "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12745,
             "float64": 1987654321.123456789,
             "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8')}
            ]
    schema = {"file_name": {"type": "string"},
              "float32_array": {"type": "float32", "shape": [-1]},
              "float64_array": {"type": "float64", "shape": [-1]},
              "float32": {"type": "float32"},
              "float64": {"type": "float64"},
              "source_sos_ids": {"type": "int32", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "label": {"type": "int32"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer = FileWriter(file_name, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    d1 = ds.MindDataset(file_name, None, num_readers, shuffle=False)
    d1.save(file_name_auto, FILES_NUM)
    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['float32_array'] = item["float32_array"]
        new_data['float64_array'] = item["float64_array"]
        new_data['float32'] = item["float32"]
        new_data['float64'] = item["float64"]
        new_data['source_sos_ids'] = item["source_sos_ids"]
        new_data['source_sos_mask'] = item["source_sos_mask"]
        new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)
    d2 = ds.MindDataset(dataset_file=file_name_auto,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 13
        for field in item:
            if isinstance(item[field], np.ndarray):
                if item[field].dtype == np.float32:
                    assert (item[field] ==
                            np.array(data_value_to_list[num_iter][field], np.float32)).all()
                else:
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6
    remove_file(file_name)
    remove_file(file_name_auto)
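

# The remaining cases build the input dataset from Python generators (and, in
# test_case_07, from TFRecord files) rather than from an existing mindrecord
# file. generator_1d yields ten single-element arrays: [0], [1], ..., [9].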
def generator_1d():
    for i in range(10):
        yield (np.array([i]),)


def test_case_03():
    """
    Feature: save op
    Description: 1D numpy array
    Expectation: generated mindrecord file
    """
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
    d1.save(file_name_auto)
    d2 = ds.MindDataset(dataset_file=file_name_auto,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    i = 0
    # each data is a dictionary
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        golden = np.array([i])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
    remove_file(file_name_auto)


def generator_with_type(t):
    for i in range(64):
        yield (np.array([i], dtype=t),)


def type_tester(t):
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    logger.info("Test with Type {}".format(t.__name__))
    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], shuffle=False)
    data1 = data1.batch(4)
    data1 = data1.repeat(3)
    data1.save(file_name_auto)
    d2 = ds.MindDataset(dataset_file=file_name_auto,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    i = 0
    num_repeat = 0
    # each data is a dictionary
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        logger.info(item)
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 4
        if i == 64:
            i = 0
            num_repeat += 1
    assert num_repeat == 3
    remove_file(file_name_auto)


def test_case_04():
    """
    Feature: save op
    Description: save data of various integer and float types via type_tester
    Expectation: generated mindrecord file
    """
    # uint8 is excluded: it would drop its shape, as mindrecord stores uint8 as bytes
    types = [np.int8, np.int16, np.int32, np.int64,
             np.uint16, np.uint32, np.float32, np.float64]
    for t in types:
        type_tester(t)


def test_case_05():
    """
    Feature: save op
    Description: saving with an invalid num_files (0) is rejected
    Expectation: exception
    """
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
    with pytest.raises(Exception, match="num_files should between 0 and 1000."):
        d1.save(file_name, 0)


def test_case_06():
    """
    Feature: save op
    Description: saving to the unsupported 'tfrecord' format is rejected
    Expectation: exception
    """
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)
    with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
        d1.save(file_name, 1, "tfrecord")


def cast_name(key):
    """
    Cast schema names containing special characters to valid names.
    """
    special_symbols = set('{}{}'.format(punctuation, ' '))
    special_symbols.remove('_')
    new_key = ['_' if x in special_symbols else x for x in key]
    casted_key = ''.join(new_key)
    return casted_key
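
# Illustrative example (the column name here is hypothetical): every
# punctuation character or space except '_' maps to '_', so
# cast_name("image/class:label") == "image_class_label". test_case_07 relies
# on this when looking up the saved columns by their original TFRecord keys.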


def test_case_07():
    """
    Feature: save op
    Description: save tfrecord files
    Expectation: generated mindrecord file
    """
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
    tf_data = []
    for x in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        tf_data.append(x)
    d1.save(file_name_auto, FILES_NUM)
    d2 = ds.MindDataset(dataset_file=file_name_auto,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    mr_data = []
    for x in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        mr_data.append(x)
    count = 0
    for x in tf_data:
        for k, v in x.items():
            if isinstance(v, np.ndarray):
                assert (v == mr_data[count][cast_name(k)]).all()
            else:
                assert v == mr_data[count][cast_name(k)]
        count += 1
    assert count == 10
    remove_file(file_name_auto)
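

# generator_dynamic_1d yields 1-D arrays whose length grows from 1 to 5 and
# then resets ([0], [0, 1], ..., [0..4], then [5], [5, 6], ...), so only
# dimension 0 varies between rows; test_case_08 checks that save() handles it.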
def generator_dynamic_1d():
    arr = []
    for i in range(10):
        if i % 5 == 0:
            arr = []
        arr += [i]
        yield (np.array(arr),)
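

# generator_dynamic_2d_0 yields 2-D arrays of shape (1, 5) for the first five
# rows and (2, 5) afterwards: the shapes differ only in dimension 0, which
# save() supports (see test_case_09).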
def generator_dynamic_2d_0():
    for i in range(10):
        if i < 5:
            yield (np.arange(5).reshape([1, 5]),)
        else:
            yield (np.arange(10).reshape([2, 5]),)
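

# generator_dynamic_2d_1 yields shapes (5, 1) and then (5, 2): the shapes
# differ in dimension 1, which save() rejects (see test_case_10).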
def generator_dynamic_2d_1():
    for i in range(10):
        if i < 5:
            yield (np.arange(5).reshape([5, 1]),)
        else:
            yield (np.arange(10).reshape([5, 2]),)


def test_case_08():
    """
    Feature: save op
    Description: save dynamic 1D numpy array
    Expectation: generated mindrecord file
    """
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_dynamic_1d, ["data"], shuffle=False)
    d1.save(file_name_auto)
    d2 = ds.MindDataset(dataset_file=file_name_auto,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    i = 0
    arr = []
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        if i % 5 == 0:
            arr = []
        arr += [i]
        golden = np.array(arr)
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
    remove_file(file_name_auto)


def test_case_09():
    """
    Feature: save op
    Description: save dynamic 2D numpy array
    Expectation: generated mindrecord file
    """
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_dynamic_2d_0, ["data"], shuffle=False)
    d1.save(file_name_auto)
    d2 = ds.MindDataset(dataset_file=file_name_auto,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    i = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        if i < 5:
            golden = np.arange(5).reshape([1, 5])
        else:
            golden = np.arange(10).reshape([2, 5])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
    remove_file(file_name_auto)


def test_case_10():
    """
    Feature: save op
    Description: save 2D Tensor of different shape
    Expectation: Exception
    """
    file_name_auto = './'
    file_name_auto += os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    file_name_auto += '_auto'
    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_dynamic_2d_1, ["data"], shuffle=False)
    with pytest.raises(Exception,
                       match="Error: besides dimension 0, other dimension shape is different from the previous's"):
        d1.save(file_name_auto)
    remove_file(file_name_auto)