
test_minddataset.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for mindrecord
"""
import collections
import json
import os
import re
import string
import pytest
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord"
CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
OLD_NLP_FILE_NAME = "../data/mindrecord/testOldVersion/aclImdb.mindrecord"
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt"

@pytest.fixture
def add_and_remove_cv_file():
    """add/remove cv file"""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"id": {"type": "int32"},
                      "file_name": {"type": "string"},
                      "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_cv_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


@pytest.fixture
def add_and_remove_nlp_file():
    """add/remove nlp file"""
    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
    nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
                       "rating": {"type": "float32"},
                       "input_ids": {"type": "int64",
                                     "shape": [-1]},
                       "input_mask": {"type": "int64",
                                      "shape": [1, -1]},
                       "segment_ids": {"type": "int64",
                                       "shape": [2, -1]}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.add_index(["id", "rating"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_nlp_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))

@pytest.fixture
def add_and_remove_nlp_compress_file():
    """add/remove nlp compress file"""
    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = []
    for row_id in range(16):
        data.append({
            "label": row_id,
            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                            2147483647], dtype=np.int32), [-1]),
            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
                                            256, -32768, 32767, -32769, 32768,
                                            -2147483648, 2147483647, -2147483649, 2147483649,
                                            -9223372036854775808, 9223372036854775807]), [1, -1]),
            "array_c": str.encode("nlp data"),
            "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
        })
    nlp_schema_json = {"label": {"type": "int32"},
                       "array_a": {"type": "int32",
                                   "shape": [-1]},
                       "array_b": {"type": "int64",
                                   "shape": [1, -1]},
                       "array_c": {"type": "bytes"},
                       "array_d": {"type": "int64",
                                   "shape": [2, -1]}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_nlp_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))

def test_nlp_compress_data(add_and_remove_nlp_compress_file):
    """tutorial for nlp minddataset."""
    data = []
    for row_id in range(16):
        data.append({
            "label": row_id,
            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                            2147483647], dtype=np.int32), [-1]),
            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
                                            256, -32768, 32767, -32769, 32768,
                                            -2147483648, 2147483647, -2147483649, 2147483649,
                                            -9223372036854775808, 9223372036854775807]), [1, -1]),
            "array_c": str.encode("nlp data"),
            "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
        })
    num_readers = 1
    data_set = ds.MindDataset(
        NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
    assert data_set.get_dataset_size() == 16
    num_iter = 0
    for x, item in zip(data, data_set.create_dict_iterator()):
        assert (item["array_a"] == x["array_a"]).all()
        assert (item["array_b"] == x["array_b"]).all()
        assert item["array_c"].tobytes() == x["array_c"]
        assert (item["array_d"] == x["array_d"]).all()
        assert item["label"] == x["label"]
        num_iter += 1
    assert num_iter == 16


def test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file):
    """tutorial for nlp minddataset."""
    num_readers = 1
    data_set = ds.MindDataset(
        NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
    old_data_set = ds.MindDataset(
        OLD_NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
    assert old_data_set.get_dataset_size() == 16
    num_iter = 0
    for x, item in zip(old_data_set.create_dict_iterator(), data_set.create_dict_iterator()):
        assert (item["array_a"] == x["array_a"]).all()
        assert (item["array_b"] == x["array_b"]).all()
        assert (item["array_c"] == x["array_c"]).all()
        assert (item["array_d"] == x["array_d"]).all()
        assert item["label"] == x["label"]
        num_iter += 1
    assert num_iter == 16

def test_cv_minddataset_writer_tutorial():
    """tutorial for cv dataset writer."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards, shard_id=partition_id)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4) == 3
    assert partitions(5) == 2
    assert partitions(9) == 2

def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    num_shards = 3
    epoch1 = []
    epoch2 = []
    epoch3 = []

    for partition_id in range(num_shards):
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  num_shards=num_shards, shard_id=partition_id)
        data_set = data_set.repeat(3)
        num_iter = 0
        for item in data_set.create_dict_iterator():
            logger.info("-------------- partition : {} ------------------------".format(partition_id))
            logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
            logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
            num_iter += 1
            if num_iter <= 4:
                epoch1.append(item["file_name"])  # save epoch 1 list
            elif num_iter <= 8:
                epoch2.append(item["file_name"])  # save epoch 2 list
            else:
                epoch3.append(item["file_name"])  # save epoch 3 list
        assert num_iter == 12
        assert len(epoch1) == 4
        assert len(epoch2) == 4
        assert len(epoch3) == 4
        assert epoch1 not in (epoch2, epoch3)
        assert epoch2 not in (epoch1, epoch3)
        assert epoch3 not in (epoch1, epoch2)
        epoch1 = []
        epoch2 = []
        epoch3 = []

def test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    ds.config.set_seed(54321)
    epoch1 = []
    epoch2 = []
    epoch3 = []

    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    data_set = data_set.repeat(3)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 10:
            epoch1.append(item["file_name"])  # save epoch 1 list
        elif num_iter <= 20:
            epoch2.append(item["file_name"])  # save epoch 2 list
        else:
            epoch3.append(item["file_name"])  # save epoch 3 list
    assert num_iter == 30
    assert len(epoch1) == 10
    assert len(epoch2) == 10
    assert len(epoch3) == 10
    assert epoch1 not in (epoch2, epoch3)
    assert epoch2 not in (epoch1, epoch3)
    assert epoch3 not in (epoch1, epoch2)

    epoch1_new_dataset = []
    epoch2_new_dataset = []
    epoch3_new_dataset = []

    data_set2 = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    data_set2 = data_set2.repeat(3)
    num_iter = 0
    for item in data_set2.create_dict_iterator():
        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 10:
            epoch1_new_dataset.append(item["file_name"])  # save epoch 1 list
        elif num_iter <= 20:
            epoch2_new_dataset.append(item["file_name"])  # save epoch 2 list
        else:
            epoch3_new_dataset.append(item["file_name"])  # save epoch 3 list
    assert num_iter == 30
    assert len(epoch1_new_dataset) == 10
    assert len(epoch2_new_dataset) == 10
    assert len(epoch3_new_dataset) == 10
    assert epoch1_new_dataset not in (epoch2_new_dataset, epoch3_new_dataset)
    assert epoch2_new_dataset not in (epoch1_new_dataset, epoch3_new_dataset)
    assert epoch3_new_dataset not in (epoch1_new_dataset, epoch2_new_dataset)

    assert epoch1 == epoch1_new_dataset
    assert epoch2 == epoch2_new_dataset
    assert epoch3 == epoch3_new_dataset

    ds.config.set_seed(12345)
    epoch1_new_dataset2 = []
    epoch2_new_dataset2 = []
    epoch3_new_dataset2 = []

    data_set3 = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    data_set3 = data_set3.repeat(3)
    num_iter = 0
    for item in data_set3.create_dict_iterator():
        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 10:
            epoch1_new_dataset2.append(item["file_name"])  # save epoch 1 list
        elif num_iter <= 20:
            epoch2_new_dataset2.append(item["file_name"])  # save epoch 2 list
        else:
            epoch3_new_dataset2.append(item["file_name"])  # save epoch 3 list
    assert num_iter == 30
    assert len(epoch1_new_dataset2) == 10
    assert len(epoch2_new_dataset2) == 10
    assert len(epoch3_new_dataset2) == 10
    assert epoch1_new_dataset2 not in (epoch2_new_dataset2, epoch3_new_dataset2)
    assert epoch2_new_dataset2 not in (epoch1_new_dataset2, epoch3_new_dataset2)
    assert epoch3_new_dataset2 not in (epoch1_new_dataset2, epoch2_new_dataset2)

    assert epoch1 != epoch1_new_dataset2
    assert epoch2 != epoch2_new_dataset2
    assert epoch3 != epoch3_new_dataset2

def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- get dataset size {} -----------------".format(num_iter))
        logger.info(
            "-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info(
            "-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 20
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              num_shards=4, shard_id=3)
    assert data_set.get_dataset_size() == 3


def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    decode_op = vision.Decode()
    data_set = data_set.map(
        input_columns=["data"], operations=decode_op, num_parallel_workers=2)
    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
    data_set = data_set.map(input_columns="data",
                            operations=resize_op, num_parallel_workers=2)
    data_set = data_set.batch(2)
    data_set = data_set.repeat(2)
    num_iter = 0
    labels = []
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- get dataset size {} -----------------".format(num_iter))
        logger.info(
            "-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info(
            "-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
        labels.append(item["label"])
    assert num_iter == 10
    logger.info("repeat shuffle: {}".format(labels))
    assert len(labels) == 10
    assert labels[0:5] == labels[0:5]
    assert labels[0:5] != labels[5:5]

def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    decode_op = vision.Decode()
    data_set = data_set.map(
        input_columns=["data"], operations=decode_op, num_parallel_workers=2)
    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
    data_set = data_set.map(input_columns="data",
                            operations=resize_op, num_parallel_workers=2)
    data_set = data_set.batch(32, drop_remainder=True)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- get dataset size {} -----------------".format(num_iter))
        logger.info(
            "-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info(
            "-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 0


def test_cv_minddataset_issue_888(add_and_remove_cv_file):
    """issue 888 test."""
    columns_list = ["data", "label"]
    num_readers = 2
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1)
    data_set = data_set.shuffle(2)
    data_set = data_set.repeat(9)
    num_iter = 0
    for _ in data_set.create_dict_iterator():
        num_iter += 1
    assert num_iter == 18

def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, block_reader=True)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- block reader repeat two {} -----------------".format(num_iter))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 20


def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["id", "data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False,
                              block_reader=True)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- block reader repeat two {} -----------------".format(num_iter))
        logger.info(
            "-------------- item[id]: {} ----------------------------".format(item["id"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 20

def test_cv_minddataset_reader_file_list(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset([CV_FILE_NAME + str(x)
                               for x in range(FILES_NUM)], columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10


def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset([CV_FILE_NAME + "0"], columns_list, num_readers)
    assert data_set.get_dataset_size() < 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter < 10

def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    if os.path.exists(CV1_FILE_NAME):
        os.remove(CV1_FILE_NAME)
    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
        os.remove("{}.db".format(CV1_FILE_NAME))
    if os.path.exists(CV2_FILE_NAME):
        os.remove(CV2_FILE_NAME)
    if os.path.exists("{}.db".format(CV2_FILE_NAME)):
        os.remove("{}.db".format(CV2_FILE_NAME))
    writer = FileWriter(CV1_FILE_NAME, 1)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"id": {"type": "int32"},
                      "file_name": {"type": "string"},
                      "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "CV1_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(CV2_FILE_NAME, 1)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"id": {"type": "int32"},
                      "file_name": {"type": "string"},
                      "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "CV2_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME],
                              columns_list, num_readers)
    assert data_set.get_dataset_size() == 30
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 30
    if os.path.exists(CV1_FILE_NAME):
        os.remove(CV1_FILE_NAME)
    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
        os.remove("{}.db".format(CV1_FILE_NAME))
    if os.path.exists(CV2_FILE_NAME):
        os.remove(CV2_FILE_NAME)
    if os.path.exists("{}.db".format(CV2_FILE_NAME)):
        os.remove("{}.db".format(CV2_FILE_NAME))


def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
    paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"id": {"type": "int32"},
                      "file_name": {"type": "string"},
                      "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "CV1_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()

    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] + [CV1_FILE_NAME + str(x) for x in range(2, 4)],
                              columns_list, num_readers)
    assert data_set.get_dataset_size() < 20
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter < 20
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))

def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10


def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file):
    """tutorial for nlp minddataset."""
    num_readers = 4
    data_set = ds.MindDataset(NLP_FILE_NAME + "0", None, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- nlp reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- num_iter: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[id]: {} ------------------------".format(item["id"]))
        logger.info(
            "-------------- item[rating]: {} --------------------".format(item["rating"]))
        logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
            item["input_ids"], item["input_ids"].shape))
        logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format(
            item["input_mask"], item["input_mask"].shape))
        logger.info("-------------- item[segment_ids]: {}, shape: {} -----------------".format(
            item["segment_ids"], item["segment_ids"].shape))
        assert item["input_ids"].shape == (50,)
        assert item["input_mask"].shape == (1, 50)
        assert item["segment_ids"].shape == (2, 25)
        num_iter += 1
    assert num_iter == 10

def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    for _ in range(5):
        num_iter = 0
        for data in data_set:
            logger.info("data is {}".format(data))
            num_iter += 1
        assert num_iter == 10
        data_set.reset()


def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)

    resize_height = 32
    resize_width = 32

    # define map operations
    decode_op = vision.Decode()
    resize_op = vision.Resize(
        (resize_height, resize_width), ds.transforms.vision.Inter.LINEAR)

    data_set = data_set.map(
        input_columns=["data"], operations=decode_op, num_parallel_workers=4)
    data_set = data_set.map(
        input_columns=["data"], operations=resize_op, num_parallel_workers=4)

    data_set = data_set.batch(2)
    assert data_set.get_dataset_size() == 5
    for _ in range(5):
        num_iter = 0
        for data in data_set:
            logger.info("data is {}".format(data))
            num_iter += 1
        assert num_iter == 5
        data_set.reset()


def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    data_set = ds.MindDataset(CV_FILE_NAME + "0")
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10


def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info(
            "-------------- repeat two test {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} -----------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} ----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} -----------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ---------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 20

def get_data(dir_name):
    """
    usage: get data from imagenet dataset

    params:
        dir_name: directory containing folder images and annotation information
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list
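
# Illustrative note (an assumption inferred from the parsing in get_data above,
# not part of the original tests): each line of annotation.txt is expected to be
# a comma-separated "<image file name>,<integer label>" pair, e.g. a hypothetical
# line "cat_0001.jpg,3", with the image itself stored under <dir_name>/images/.
# Images referenced by the annotation file but missing on disk are skipped via
# the FileNotFoundError handler.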

def get_multi_bytes_data(file_name, bytes_num=3):
    """
    Return raw data of multi-bytes dataset.

    Args:
        file_name (str): String of multi-bytes dataset's path.
        bytes_num (int): Number of bytes fields.

    Returns:
        List
    """
    if not os.path.exists(file_name):
        raise IOError("map file {} does not exist".format(file_name))
    dir_name = os.path.dirname(file_name)
    with open(file_name, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    row_num = 0
    for line in lines:
        try:
            img10_path = line.strip('\n').split(" ")
            img5 = []
            for path in img10_path[:bytes_num]:
                with open(os.path.join(dir_name, path), "rb") as file_reader:
                    img5 += [file_reader.read()]
            data_json = {"image_{}".format(i): img5[i]
                         for i in range(len(img5))}
            data_json.update({"id": row_num})
            row_num += 1
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_mkv_data(dir_name):
    """
    Return raw data of Vehicle_and_Person dataset.

    Args:
        dir_name (str): String of Vehicle_and_Person dataset's path.

    Returns:
        List
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "Image")
    label_dir = os.path.join(dir_name, "prelabel")

    data_list = []
    file_list = os.listdir(label_dir)

    index = 1
    for item in file_list:
        if os.path.splitext(item)[1] == '.json':
            file_path = os.path.join(label_dir, item)

            image_name = ''.join([os.path.splitext(item)[0], ".jpg"])
            image_path = os.path.join(img_dir, image_name)

            with open(file_path, "r") as load_f:
                load_dict = json.load(load_f)

            if os.path.exists(image_path):
                with open(image_path, "rb") as file_reader:
                    img = file_reader.read()
                data_json = {"file_name": image_name,
                             "prelabel": str(load_dict),
                             "data": img,
                             "id": index}
                data_list.append(data_json)
            index += 1
    logger.info('{} images are missing'.format(
        len(file_list) - len(data_list)))
    return data_list

def get_nlp_data(dir_name, vocab_file, num):
    """
    Return raw data of aclImdb dataset.

    Args:
        dir_name (str): String of aclImdb dataset's path.
        vocab_file (str): String of dictionary's path.
        num (int): Number of samples.

    Returns:
        List
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    for root, _, files in os.walk(dir_name):
        for index, file_name_extension in enumerate(files):
            if index < num:
                file_path = os.path.join(root, file_name_extension)
                file_name, _ = file_name_extension.split('.', 1)
                id_, rating = file_name.split('_', 1)
                with open(file_path, 'r') as f:
                    raw_content = f.read()

                dictionary = load_vocab(vocab_file)
                vectors = [dictionary.get('[CLS]')]
                vectors += [dictionary.get(i) if i in dictionary
                            else dictionary.get('[UNK]')
                            for i in re.findall(r"[\w']+|[{}]"
                                                .format(string.punctuation),
                                                raw_content)]
                vectors += [dictionary.get('[SEP]')]
                input_, mask, segment = inputs(vectors)
                input_ids = np.reshape(np.array(input_), [-1])
                input_mask = np.reshape(np.array(mask), [1, -1])
                segment_ids = np.reshape(np.array(segment), [2, -1])
                data = {
                    "label": 1,
                    "id": id_,
                    "rating": float(rating),
                    "input_ids": input_ids,
                    "input_mask": input_mask,
                    "segment_ids": segment_ids
                }
                yield data


def convert_to_uni(text):
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    raise Exception("The type %s can not be converted to unicode." % type(text))


def load_vocab(vocab_file):
    """load vocabulary to translate statement."""
    vocab = collections.OrderedDict()
    vocab.setdefault('blank', 2)
    index = 0
    with open(vocab_file) as reader:
        while True:
            tmp = reader.readline()
            if not tmp:
                break
            token = convert_to_uni(tmp)
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def inputs(vectors, maxlen=50):
    length = len(vectors)
    if length > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    input_ = vectors + [0] * (maxlen - length)
    mask = [1] * length + [0] * (maxlen - length)
    segment = [0] * maxlen
    return input_, mask, segment
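
# Illustrative note (not part of the original tests): with a hypothetical short
# token-id vector such as [101, 7, 8, 102] and maxlen=6, inputs() returns
#     input_  -> [101, 7, 8, 102, 0, 0]   zero-padded up to maxlen
#     mask    -> [1, 1, 1, 1, 0, 0]       1 marks real tokens, 0 marks padding
#     segment -> [0, 0, 0, 0, 0, 0]       a single all-zero segment
# while a vector longer than maxlen is truncated to maxlen and fully masked with 1s.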

def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
    mindrecord_file_name = "test.mindrecord"
    if os.path.exists("{}".format(mindrecord_file_name)):
        os.remove("{}".format(mindrecord_file_name))
    if os.path.exists("{}.db".format(mindrecord_file_name)):
        os.remove("{}.db".format(mindrecord_file_name))
    data = [{"file_name": "001.jpg", "label": 4,
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8'),
             "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
            {"file_name": "002.jpg", "label": 5,
             "image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8'),
             "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
            {"file_name": "003.jpg", "label": 6,
             "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8'),
             "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
            {"file_name": "004.jpg", "label": 7,
             "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image4 bytes abc", encoding='UTF-8'),
             "image2": bytes("image4 bytes def", encoding='UTF-8'),
             "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image4 bytes mno", encoding='UTF-8'),
             "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
            {"file_name": "005.jpg", "label": 8,
             "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
             "image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8'),
             "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
            {"file_name": "006.jpg", "label": 9,
             "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
             "image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8'),
             "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"file_name": {"type": "string"},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "image3": {"type": "bytes"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]},
              "label": {"type": "int32"}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    # change data value to list
    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        new_data['source_sos_ids'] = item["source_sos_ids"]
        new_data['source_sos_mask'] = item["source_sos_mask"]
        new_data['target_sos_ids'] = item["target_sos_ids"]
        new_data['target_sos_mask'] = item["target_sos_mask"]
        new_data['target_eos_ids'] = item["target_eos_ids"]
        new_data['target_eos_mask'] = item["target_eos_mask"]
        data_value_to_list.append(new_data)

    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 13
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["source_sos_ids",
                                            "source_sos_mask", "target_sos_ids"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 3
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] == data[num_iter][field]).all()
            else:
                assert item[field] == data[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    num_readers = 1
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=[
                                  "image2", "source_sos_mask", "image3", "target_sos_ids"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 4
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    num_readers = 3
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["target_sos_ids",
                                            "image4", "source_sos_ids"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 3
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    num_readers = 3
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["target_sos_ids", "image5",
                                            "image4", "image3", "source_sos_ids"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 5
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    num_readers = 1
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["target_eos_mask", "image5",
                                            "image2", "source_sos_mask", "label"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 5
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["label", "target_eos_mask", "image1", "target_eos_ids", "source_sos_mask",
                                            "image2", "image4", "image3", "source_sos_ids", "image5", "file_name"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 11
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))


def test_write_with_multi_bytes_and_MindDataset():
    mindrecord_file_name = "test.mindrecord"
    data = [{"file_name": "001.jpg", "label": 43,
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91,
             "image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61,
             "image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29,
             "image1": bytes("image4 bytes abc", encoding='UTF-8'),
             "image2": bytes("image4 bytes def", encoding='UTF-8'),
             "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image4 bytes mno", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78,
             "image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37,
             "image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8')}
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"file_name": {"type": "string"},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "label": {"type": "int32"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()
    # convert the expected values: the scalar label becomes a 1-D int32 array and
    # the bytes fields become uint8 arrays, matching what MindDataset returns
    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    # read the full schema (all seven columns)
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 7
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read a three-column projection of bytes fields
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["image1", "image2", "image5"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 3
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read two of the bytes columns
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["image2", "image4"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 2
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read two bytes columns requested in reverse schema order
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["image5", "image2"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 2
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read bytes columns together with the int32 label
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["image5", "image2", "label"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 3
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read bytes columns together with the string file_name
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["image4", "image5",
                                            "image2", "image3", "file_name"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 5
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
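

# A small aside (not part of the original tests): for raw bytes, the
# np.asarray(list(...), dtype=np.uint8) conversion used above produces the same
# array as np.frombuffer, which is a cheaper way to build the expected values.
# The function name below is hypothetical and exists only as a sketch.
def _bytes_to_uint8_array_sketch():
    raw = bytes("image1 bytes abc", encoding='UTF-8')
    via_list = np.asarray(list(raw), dtype=np.uint8)
    via_buffer = np.frombuffer(raw, dtype=np.uint8)
    # both conversions yield identical uint8 arrays
    assert (via_list == via_buffer).all()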


def test_write_with_multi_array_and_MindDataset():
    mindrecord_file_name = "test.mindrecord"
    data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
            {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
             "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
             "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
             "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
             "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
             "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
             "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "source_eos_ids": {"type": "int64", "shape": [-1]},
              "source_eos_mask": {"type": "int64", "shape": [-1]},
              "target_sos_ids": {"type": "int64", "shape": [-1]},
              "target_sos_mask": {"type": "int64", "shape": [-1]},
              "target_eos_ids": {"type": "int64", "shape": [-1]},
              "target_eos_mask": {"type": "int64", "shape": [-1]}}
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()
    # the fields are already int64 ndarrays, so the expected values are used as-is
    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['source_sos_ids'] = item["source_sos_ids"]
        new_data['source_sos_mask'] = item["source_sos_mask"]
        new_data['source_eos_ids'] = item["source_eos_ids"]
        new_data['source_eos_mask'] = item["source_eos_mask"]
        new_data['target_sos_ids'] = item["target_sos_ids"]
        new_data['target_sos_mask'] = item["target_sos_mask"]
        new_data['target_eos_ids'] = item["target_eos_ids"]
        new_data['target_eos_mask'] = item["target_eos_mask"]
        data_value_to_list.append(new_data)

    # read the full schema (all eight columns)
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 8
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read a six-column projection
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["source_eos_ids", "source_eos_mask",
                                            "target_sos_ids", "target_sos_mask",
                                            "target_eos_ids", "target_eos_mask"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 6
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read a three-column projection
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["source_sos_ids",
                                            "target_sos_ids",
                                            "target_eos_mask"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 3
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read three columns requested in reverse schema order
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["target_eos_mask",
                                            "source_eos_mask",
                                            "source_sos_mask"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 3
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read a single column
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["target_eos_ids"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 1
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    # read all eight columns in reverse schema order with a single worker
    num_readers = 1
    data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                              columns_list=["target_eos_mask", "target_eos_ids",
                                            "target_sos_mask", "target_sos_ids",
                                            "source_eos_mask", "source_eos_ids",
                                            "source_sos_mask", "source_sos_ids"],
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        assert len(item) == 8
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6

    os.remove("{}".format(mindrecord_file_name))
    os.remove("{}.db".format(mindrecord_file_name))
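

# Sketch only (not part of the original tests): the os.remove calls above run
# only when every assertion passes, so a failing test can leave test.mindrecord
# and test.mindrecord.db behind. A hypothetical helper like the one below could
# be called from a try/finally around the assertions to guarantee cleanup.
def _cleanup_mindrecord_sketch(mindrecord_file_name):
    """Remove a mindrecord file and its companion .db index if they exist."""
    for name in [mindrecord_file_name, "{}.db".format(mindrecord_file_name)]:
        if os.path.exists(name):
            os.remove(name)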