test_minddataset.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for mindrecord.
"""
import collections
import json
import math
import os
import re
import string
import pytest
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt"
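
# NOTE: FileWriter(file_name, FILES_NUM) shards its output into FILES_NUM data
# files plus one "<shard>.db" index file per shard. For example, with
# FILES_NUM = 4 and a (hypothetical) file_name of "test_foo", the writer
# produces test_foo0, test_foo0.db, ..., test_foo3, test_foo3.db. That is why
# the fixtures below remove both the data file and the .db file for every
# shard, both before writing and after the test finishes.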
@pytest.fixture
def add_and_remove_cv_file():
    """add/remove cv file"""
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(file_name, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
@pytest.fixture
def add_and_remove_nlp_file():
    """add/remove nlp file"""
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(file_name, FILES_NUM)
        data = list(get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10))
        nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
                           "rating": {"type": "float32"},
                           "input_ids": {"type": "int64",
                                         "shape": [-1]},
                           "input_mask": {"type": "int64",
                                          "shape": [1, -1]},
                           "segment_ids": {"type": "int64",
                                           "shape": [2, -1]}
                           }
        writer.set_header_size(1 << 14)
        writer.set_page_size(1 << 15)
        writer.add_schema(nlp_schema_json, "nlp_schema")
        writer.add_index(["id", "rating"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_nlp_data"
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
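
# NOTE: in the schema above, a shape of [-1] marks a variable-length 1-D field,
# while [1, -1] and [2, -1] fix the leading dimension and leave the trailing
# one variable. This matches the shape assertions in
# test_nlp_minddataset_reader_basic_tutorial further down: with inputs()
# padding every sample to 50 tokens, the fields come back as (50,), (1, 50)
# and (2, 25) respectively.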
@pytest.fixture
def add_and_remove_nlp_compress_file():
    """add/remove nlp compress file"""
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(file_name, FILES_NUM)
        data = []
        for row_id in range(16):
            data.append({
                "label": row_id,
                "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
                                                255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                                2147483647], dtype=np.int32), [-1]),
                "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
                                                256, -32768, 32767, -32769, 32768,
                                                -2147483648, 2147483647, -2147483649, 2147483649,
                                                -922337036854775808, 9223372036854775807]), [1, -1]),
                "array_c": str.encode("nlp data"),
                "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
            })
        nlp_schema_json = {"label": {"type": "int32"},
                           "array_a": {"type": "int32",
                                       "shape": [-1]},
                           "array_b": {"type": "int64",
                                       "shape": [1, -1]},
                           "array_c": {"type": "bytes"},
                           "array_d": {"type": "int64",
                                       "shape": [2, -1]}
                           }
        writer.set_header_size(1 << 14)
        writer.set_page_size(1 << 15)
        writer.add_schema(nlp_schema_json, "nlp_schema")
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_nlp_data"
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
def test_nlp_compress_data(add_and_remove_nlp_compress_file):
    """tutorial for nlp minddataset."""
    data = []
    for row_id in range(16):
        data.append({
            "label": row_id,
            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                            2147483647], dtype=np.int32), [-1]),
            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
                                            256, -32768, 32767, -32769, 32768,
                                            -2147483648, 2147483647, -2147483649, 2147483649,
                                            -922337036854775808, 9223372036854775807]), [1, -1]),
            "array_c": str.encode("nlp data"),
            "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
        })
    num_readers = 1
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(
        file_name + "0", None, num_readers, shuffle=False)
    assert data_set.get_dataset_size() == 16
    num_iter = 0
    for x, item in zip(data, data_set.create_dict_iterator(num_epochs=1, output_numpy=True)):
        assert (item["array_a"] == x["array_a"]).all()
        assert (item["array_b"] == x["array_b"]).all()
        assert item["array_c"].tobytes() == x["array_c"]
        assert (item["array_d"] == x["array_d"]).all()
        assert item["label"] == x["label"]
        num_iter += 1
    assert num_iter == 16
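
# NOTE: the "array_c" field is written as Python bytes but comes back from the
# iterator as a NumPy array, which is why the comparison above goes through
# item["array_c"].tobytes() instead of a direct equality check.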
def test_cv_minddataset_writer_tutorial():
    """tutorial for cv dataset writer."""
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(file_name, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]

    def partitions(num_shards):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(file_name + "0", columns_list, num_readers,
                                      num_shards=num_shards, shard_id=partition_id)
            num_iter = 0
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4) == 3
    assert partitions(5) == 2
    assert partitions(9) == 2
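
# NOTE: partitions() returns the row count of the last shard it read, and the
# asserted counts are consistent with each shard being padded to
# ceil(total / num_shards) rows for the 10-record fixture:
# ceil(10 / 4) = 3, ceil(10 / 5) = 2, ceil(10 / 9) = 2.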
def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]

    def partitions(num_shards):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(file_name + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=1)
            assert data_set.get_dataset_size() == 1
            num_iter = 0
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4) == 1
    assert partitions(5) == 1
    assert partitions(9) == 1
def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]

    def partitions(num_shards):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(file_name + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=2)
            assert data_set.get_dataset_size() == 2
            num_iter = 0
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4) == 2
    assert partitions(5) == 2
    assert partitions(9) == 2
def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]

    def partitions(num_shards, expect):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(file_name + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=3)
            assert data_set.get_dataset_size() == expect
            num_iter = 0
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4, 3) == 3
    assert partitions(5, 2) == 2
    assert partitions(9, 2) == 2
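
# NOTE: with num_samples set, each shard yields
# min(num_samples, ceil(10 / num_shards)) rows. num_samples=3 gives
# min(3, 3) = 3 for 4 shards but min(3, 2) = 2 for 5 or 9 shards, which is
# what the (num_shards, expect) pairs above encode; num_samples of 1 or 2 in
# the two preceding tests is always the binding limit.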
def test_cv_minddataset_partition_num_samples_3(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers, num_shards=1, shard_id=0, num_samples=5)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 5
def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    num_shards = 3
    epoch1 = []
    epoch2 = []
    epoch3 = []
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    for partition_id in range(num_shards):
        data_set = ds.MindDataset(file_name + "0", columns_list, num_readers,
                                  num_shards=num_shards, shard_id=partition_id)
        data_set = data_set.repeat(3)
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("-------------- partition : {} ------------------------".format(partition_id))
            logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
            logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
            num_iter += 1
            if num_iter <= 4:
                epoch1.append(item["file_name"])  # save epoch 1 list
            elif num_iter <= 8:
                epoch2.append(item["file_name"])  # save epoch 2 list
            else:
                epoch3.append(item["file_name"])  # save epoch 3 list
        assert num_iter == 12
        assert len(epoch1) == 4
        assert len(epoch2) == 4
        assert len(epoch3) == 4
        assert epoch1 not in (epoch2, epoch3)
        assert epoch2 not in (epoch1, epoch3)
        assert epoch3 not in (epoch1, epoch2)
        epoch1 = []
        epoch2 = []
        epoch3 = []
def test_cv_minddataset_partition_tutorial_check_whole_reshuffle_result_per_epoch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    num_shards = 3
    epoch_result = [[["", "", "", ""], ["", "", "", ""], ["", "", "", ""]],  # save partition 0 result
                    [["", "", "", ""], ["", "", "", ""], ["", "", "", ""]],  # save partition 1 result
                    [["", "", "", ""], ["", "", "", ""], ["", "", "", ""]]]  # save partition 2 result
    for partition_id in range(num_shards):
        data_set = ds.MindDataset(file_name + "0", columns_list, num_readers,
                                  num_shards=num_shards, shard_id=partition_id)
        data_set = data_set.repeat(3)
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("-------------- partition : {} ------------------------".format(partition_id))
            logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
            logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
            # 3 partitions in total, 4 results per epoch, 12 results overall
            epoch_result[partition_id][int(num_iter / 4)][num_iter % 4] = item["file_name"]  # save epoch result
            num_iter += 1
        assert num_iter == 12
        assert epoch_result[partition_id][0] not in (epoch_result[partition_id][1], epoch_result[partition_id][2])
        assert epoch_result[partition_id][1] not in (epoch_result[partition_id][0], epoch_result[partition_id][2])
        assert epoch_result[partition_id][2] not in (epoch_result[partition_id][1], epoch_result[partition_id][0])
        epoch_result[partition_id][0].sort()
        epoch_result[partition_id][1].sort()
        epoch_result[partition_id][2].sort()
        assert epoch_result[partition_id][0] != epoch_result[partition_id][1]
        assert epoch_result[partition_id][1] != epoch_result[partition_id][2]
        assert epoch_result[partition_id][2] != epoch_result[partition_id][0]
def test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    ds.config.set_seed(54321)
    epoch1 = []
    epoch2 = []
    epoch3 = []
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers)
    data_set = data_set.repeat(3)
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 10:
            epoch1.append(item["file_name"])  # save epoch 1 list
        elif num_iter <= 20:
            epoch2.append(item["file_name"])  # save epoch 2 list
        else:
            epoch3.append(item["file_name"])  # save epoch 3 list
    assert num_iter == 30
    assert len(epoch1) == 10
    assert len(epoch2) == 10
    assert len(epoch3) == 10
    assert epoch1 not in (epoch2, epoch3)
    assert epoch2 not in (epoch1, epoch3)
    assert epoch3 not in (epoch1, epoch2)

    epoch1_new_dataset = []
    epoch2_new_dataset = []
    epoch3_new_dataset = []
    data_set2 = ds.MindDataset(file_name + "0", columns_list, num_readers)
    data_set2 = data_set2.repeat(3)
    num_iter = 0
    for item in data_set2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 10:
            epoch1_new_dataset.append(item["file_name"])  # save epoch 1 list
        elif num_iter <= 20:
            epoch2_new_dataset.append(item["file_name"])  # save epoch 2 list
        else:
            epoch3_new_dataset.append(item["file_name"])  # save epoch 3 list
    assert num_iter == 30
    assert len(epoch1_new_dataset) == 10
    assert len(epoch2_new_dataset) == 10
    assert len(epoch3_new_dataset) == 10
    assert epoch1_new_dataset not in (epoch2_new_dataset, epoch3_new_dataset)
    assert epoch2_new_dataset not in (epoch1_new_dataset, epoch3_new_dataset)
    assert epoch3_new_dataset not in (epoch1_new_dataset, epoch2_new_dataset)
    assert epoch1 == epoch1_new_dataset
    assert epoch2 == epoch2_new_dataset
    assert epoch3 == epoch3_new_dataset

    ds.config.set_seed(12345)
    epoch1_new_dataset2 = []
    epoch2_new_dataset2 = []
    epoch3_new_dataset2 = []
    data_set3 = ds.MindDataset(file_name + "0", columns_list, num_readers)
    data_set3 = data_set3.repeat(3)
    num_iter = 0
    for item in data_set3.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 10:
            epoch1_new_dataset2.append(item["file_name"])  # save epoch 1 list
        elif num_iter <= 20:
            epoch2_new_dataset2.append(item["file_name"])  # save epoch 2 list
        else:
            epoch3_new_dataset2.append(item["file_name"])  # save epoch 3 list
    assert num_iter == 30
    assert len(epoch1_new_dataset2) == 10
    assert len(epoch2_new_dataset2) == 10
    assert len(epoch3_new_dataset2) == 10
    assert epoch1_new_dataset2 not in (epoch2_new_dataset2, epoch3_new_dataset2)
    assert epoch2_new_dataset2 not in (epoch1_new_dataset2, epoch3_new_dataset2)
    assert epoch3_new_dataset2 not in (epoch1_new_dataset2, epoch2_new_dataset2)
    assert epoch1 != epoch1_new_dataset2
    assert epoch2 != epoch2_new_dataset2
    assert epoch3 != epoch3_new_dataset2
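
# NOTE: the test above hinges on ds.config.set_seed(): two pipelines created
# under the same seed (54321) replay identical shuffle orders epoch by epoch,
# so the per-epoch lists compare equal, while switching the seed to 12345
# produces a different order and the same comparisons become unequal.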
def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- get dataset size {} -----------------".format(num_iter))
        logger.info(
            "-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info(
            "-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 20
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers,
                              num_shards=4, shard_id=3)
    assert data_set.get_dataset_size() == 3
def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers)
    decode_op = vision.Decode()
    data_set = data_set.map(
        input_columns=["data"], operations=decode_op, num_parallel_workers=2)
    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
    data_set = data_set.map(operations=resize_op, input_columns="data",
                            num_parallel_workers=2)
    data_set = data_set.batch(2)
    data_set = data_set.repeat(2)
    num_iter = 0
    labels = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- get dataset size {} -----------------".format(num_iter))
        logger.info(
            "-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info(
            "-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
        labels.append(item["label"])
    assert num_iter == 10
    logger.info("repeat shuffle: {}".format(labels))
    assert len(labels) == 10
    # 10 records batched by 2 and repeated twice give 5 batches per epoch; the
    # repeat should reshuffle, so the label order of the two epochs differs.
    epoch1_labels = [label.tolist() for label in labels[0:5]]
    epoch2_labels = [label.tolist() for label in labels[5:10]]
    assert epoch1_labels != epoch2_labels
def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers)
    decode_op = vision.Decode()
    data_set = data_set.map(
        input_columns=["data"], operations=decode_op, num_parallel_workers=2)
    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
    data_set = data_set.map(operations=resize_op, input_columns="data",
                            num_parallel_workers=2)
    data_set = data_set.batch(32, drop_remainder=True)
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- get dataset size {} -----------------".format(num_iter))
        logger.info(
            "-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info(
            "-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 0
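
# NOTE: batch(32, drop_remainder=True) emits floor(10 / 32) = 0 batches for the
# 10-record fixture, so the loop body above never runs and num_iter stays 0.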
def test_cv_minddataset_issue_888(add_and_remove_cv_file):
    """issue 888 test."""
    columns_list = ["data", "label"]
    num_readers = 2
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1)
    data_set = data_set.shuffle(2)
    data_set = data_set.repeat(9)
    num_iter = 0
    for _ in data_set.create_dict_iterator(num_epochs=1):
        num_iter += 1
    # 10 records over 5 shards give 2 rows for this shard; repeat(9) yields 18.
    assert num_iter == 18
def test_cv_minddataset_reader_file_list(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset([file_name + str(x)
                               for x in range(FILES_NUM)], columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10
def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset([file_name + "0"], columns_list, num_readers)
    assert data_set.get_dataset_size() < 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter < 10
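
# NOTE: the two tests above differ only in how the shard files are passed.
# Listing all four shard files reads the full 10 records, while a one-element
# list restricts reading to the rows stored in that single shard (fewer than
# 10). Passing one shard as a plain string, as other tests here do, loads the
# whole sharded group instead.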
def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    CV1_FILE_NAME = "../data/mindrecord/test_cv_minddataset_reader_two_dataset_1.mindrecord"
    CV2_FILE_NAME = "../data/mindrecord/test_cv_minddataset_reader_two_dataset_2.mindrecord"
    try:
        if os.path.exists(CV1_FILE_NAME):
            os.remove(CV1_FILE_NAME)
        if os.path.exists("{}.db".format(CV1_FILE_NAME)):
            os.remove("{}.db".format(CV1_FILE_NAME))
        if os.path.exists(CV2_FILE_NAME):
            os.remove(CV2_FILE_NAME)
        if os.path.exists("{}.db".format(CV2_FILE_NAME)):
            os.remove("{}.db".format(CV2_FILE_NAME))
        writer = FileWriter(CV1_FILE_NAME, 1)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "CV1_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()

        writer = FileWriter(CV2_FILE_NAME, 1)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "CV2_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()

        columns_list = ["data", "file_name", "label"]
        num_readers = 4
        file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
        data_set = ds.MindDataset([file_name + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME],
                                  columns_list, num_readers)
        assert data_set.get_dataset_size() == 30
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
            logger.info(
                "-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 30
    except Exception as error:
        if os.path.exists(CV1_FILE_NAME):
            os.remove(CV1_FILE_NAME)
        if os.path.exists("{}.db".format(CV1_FILE_NAME)):
            os.remove("{}.db".format(CV1_FILE_NAME))
        if os.path.exists(CV2_FILE_NAME):
            os.remove(CV2_FILE_NAME)
        if os.path.exists("{}.db".format(CV2_FILE_NAME)):
            os.remove("{}.db".format(CV2_FILE_NAME))
        raise error
    else:
        if os.path.exists(CV1_FILE_NAME):
            os.remove(CV1_FILE_NAME)
        if os.path.exists("{}.db".format(CV1_FILE_NAME)):
            os.remove("{}.db".format(CV1_FILE_NAME))
        if os.path.exists(CV2_FILE_NAME):
            os.remove(CV2_FILE_NAME)
        if os.path.exists("{}.db".format(CV2_FILE_NAME)):
            os.remove("{}.db".format(CV2_FILE_NAME))
def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    CV1_FILE_NAME = "../data/mindrecord/test_cv_minddataset_reader_two_dataset_partition_1"
    paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "CV1_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        columns_list = ["data", "file_name", "label"]
        num_readers = 4
        file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
        data_set = ds.MindDataset([file_name + str(x) for x in range(2)] +
                                  [CV1_FILE_NAME + str(x) for x in range(2, 4)],
                                  columns_list, num_readers)
        assert data_set.get_dataset_size() < 20
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
            logger.info(
                "-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter < 20
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10
def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file):
    """tutorial for nlp minddataset."""
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", None, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- nlp reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- num_iter: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[id]: {} ------------------------".format(item["id"]))
        logger.info(
            "-------------- item[rating]: {} --------------------".format(item["rating"]))
        logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
            item["input_ids"], item["input_ids"].shape))
        logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format(
            item["input_mask"], item["input_mask"].shape))
        logger.info("-------------- item[segment_ids]: {}, shape: {} -----------------".format(
            item["segment_ids"], item["segment_ids"].shape))
        assert item["input_ids"].shape == (50,)
        assert item["input_mask"].shape == (1, 50)
        assert item["segment_ids"].shape == (2, 25)
        num_iter += 1
    assert num_iter == 10
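
# NOTE: the shape assertions above follow from inputs(vectors, maxlen=50):
# every sample is padded or truncated to 50 values, so "input_ids" reshaped
# with [-1] is (50,), "input_mask" reshaped with [1, -1] is (1, 50), and
# "segment_ids" reshaped with [2, -1] splits the 50 values into (2, 25).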
def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    for _ in range(5):
        num_iter = 0
        for data in data_set.create_tuple_iterator(num_epochs=1, output_numpy=True):
            logger.info("data is {}".format(data))
            num_iter += 1
        assert num_iter == 10
        data_set.reset()
def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers)
    resize_height = 32
    resize_width = 32
    # define map operations
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width))
    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4)
    data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4)
    data_set = data_set.batch(2)
    assert data_set.get_dataset_size() == 5
    for _ in range(5):
        num_iter = 0
        for data in data_set.create_tuple_iterator(num_epochs=1, output_numpy=True):
            logger.info("data is {}".format(data))
            num_iter += 1
        assert num_iter == 5
        data_set.reset()
def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0")
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10
def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(file_name + "0", columns_list, num_readers)
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- repeat two test {} ------------------------".format(num_iter))
        logger.info(
            "-------------- len(item[data]): {} -----------------------".format(len(item["data"])))
        logger.info(
            "-------------- item[data]: {} ----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} -----------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ---------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 20
def get_data(dir_name):
    """
    Return raw data of the ImageNet test dataset.

    Args:
        dir_name (str): Directory containing the images folder and the annotation file.

    Returns:
        List
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list
def get_multi_bytes_data(file_name, bytes_num=3):
    """
    Return raw data of multi-bytes dataset.

    Args:
        file_name (str): String of multi-bytes dataset's path.
        bytes_num (int): Number of bytes fields.

    Returns:
        List
    """
    if not os.path.exists(file_name):
        raise IOError("map file {} does not exist".format(file_name))
    dir_name = os.path.dirname(file_name)
    with open(file_name, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    row_num = 0
    for line in lines:
        try:
            img10_path = line.strip('\n').split(" ")
            img5 = []
            for path in img10_path[:bytes_num]:
                with open(os.path.join(dir_name, path), "rb") as file_reader:
                    img5 += [file_reader.read()]
            data_json = {"image_{}".format(i): img5[i]
                         for i in range(len(img5))}
            data_json.update({"id": row_num})
            row_num += 1
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list
def get_mkv_data(dir_name):
    """
    Return raw data of Vehicle_and_Person dataset.

    Args:
        dir_name (str): String of Vehicle_and_Person dataset's path.

    Returns:
        List
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "Image")
    label_dir = os.path.join(dir_name, "prelabel")
    data_list = []
    file_list = os.listdir(label_dir)
    index = 1
    for item in file_list:
        if os.path.splitext(item)[1] == '.json':
            file_path = os.path.join(label_dir, item)
            image_name = ''.join([os.path.splitext(item)[0], ".jpg"])
            image_path = os.path.join(img_dir, image_name)
            with open(file_path, "r") as load_f:
                load_dict = json.load(load_f)
            if os.path.exists(image_path):
                with open(image_path, "rb") as file_reader:
                    img = file_reader.read()
                data_json = {"file_name": image_name,
                             "prelabel": str(load_dict),
                             "data": img,
                             "id": index}
                data_list.append(data_json)
                index += 1
    logger.info('{} images are missing'.format(
        len(file_list) - len(data_list)))
    return data_list
def get_nlp_data(dir_name, vocab_file, num):
    """
    Yield raw data of the aclImdb dataset.

    Args:
        dir_name (str): String of aclImdb dataset's path.
        vocab_file (str): String of dictionary's path.
        num (int): Number of samples.

    Yields:
        Dict, one sample at a time.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    for root, _, files in os.walk(dir_name):
        for index, file_name_extension in enumerate(files):
            if index < num:
                file_path = os.path.join(root, file_name_extension)
                file_name, _ = file_name_extension.split('.', 1)
                id_, rating = file_name.split('_', 1)
                with open(file_path, 'r') as f:
                    raw_content = f.read()
                dictionary = load_vocab(vocab_file)
                vectors = [dictionary.get('[CLS]')]
                vectors += [dictionary.get(i) if i in dictionary
                            else dictionary.get('[UNK]')
                            for i in re.findall(r"[\w']+|[{}]"
                                                .format(string.punctuation),
                                                raw_content)]
                vectors += [dictionary.get('[SEP]')]
                input_, mask, segment = inputs(vectors)
                input_ids = np.reshape(np.array(input_), [-1])
                input_mask = np.reshape(np.array(mask), [1, -1])
                segment_ids = np.reshape(np.array(segment), [2, -1])
                data = {
                    "label": 1,
                    "id": id_,
                    "rating": float(rating),
                    "input_ids": input_ids,
                    "input_mask": input_mask,
                    "segment_ids": segment_ids
                }
                yield data
def convert_to_uni(text):
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    raise Exception("The type %s cannot be converted!" % type(text))
def load_vocab(vocab_file):
    """load vocabulary to translate statement."""
    vocab = collections.OrderedDict()
    vocab.setdefault('blank', 2)
    index = 0
    with open(vocab_file) as reader:
        while True:
            tmp = reader.readline()
            if not tmp:
                break
            token = convert_to_uni(tmp)
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab
def inputs(vectors, maxlen=50):
    length = len(vectors)
    if length > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    input_ = vectors + [0] * (maxlen - length)
    mask = [1] * length + [0] * (maxlen - length)
    segment = [0] * maxlen
    return input_, mask, segment
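
# A minimal sketch of inputs() padding behaviour; the helper below is
# illustrative only (the leading underscore keeps pytest from collecting it)
# and its values are hypothetical.
def _demo_inputs_padding():
    """Show that short vectors are zero-padded and masked up to maxlen."""
    input_, mask, segment = inputs([7, 8, 9], maxlen=5)
    assert input_ == [7, 8, 9, 0, 0]   # padded with zeros to maxlen
    assert mask == [1, 1, 1, 0, 0]     # 1 for real tokens, 0 for padding
    assert segment == [0, 0, 0, 0, 0]  # single-segment input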
  997. def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
    mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    try:
        if os.path.exists("{}".format(mindrecord_file_name)):
            os.remove("{}".format(mindrecord_file_name))
        if os.path.exists("{}.db".format(mindrecord_file_name)):
            os.remove("{}.db".format(mindrecord_file_name))
        data = [{"file_name": "001.jpg", "label": 4,
                 "image1": bytes("image1 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image1 bytes def", encoding='UTF-8'),
                 "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image1 bytes mno", encoding='UTF-8'),
                 "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
                {"file_name": "002.jpg", "label": 5,
                 "image1": bytes("image2 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image2 bytes def", encoding='UTF-8'),
                 "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image2 bytes mno", encoding='UTF-8'),
                 "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
                {"file_name": "003.jpg", "label": 6,
                 "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "image1": bytes("image3 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image3 bytes def", encoding='UTF-8'),
                 "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image3 bytes mno", encoding='UTF-8'),
                 "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
                {"file_name": "004.jpg", "label": 7,
                 "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "image1": bytes("image4 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image4 bytes def", encoding='UTF-8'),
                 "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image4 bytes mno", encoding='UTF-8'),
                 "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
                {"file_name": "005.jpg", "label": 8,
                 "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
                 "image1": bytes("image5 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image5 bytes def", encoding='UTF-8'),
                 "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image5 bytes mno", encoding='UTF-8'),
                 "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
                {"file_name": "006.jpg", "label": 9,
                 "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
                 "image1": bytes("image6 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image6 bytes def", encoding='UTF-8'),
                 "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image6 bytes mno", encoding='UTF-8'),
                 "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
                ]

        writer = FileWriter(mindrecord_file_name)
        schema = {"file_name": {"type": "string"},
                  "image1": {"type": "bytes"},
                  "image2": {"type": "bytes"},
                  "source_sos_ids": {"type": "int64", "shape": [-1]},
                  "source_sos_mask": {"type": "int64", "shape": [-1]},
                  "image3": {"type": "bytes"},
                  "image4": {"type": "bytes"},
                  "image5": {"type": "bytes"},
                  "target_sos_ids": {"type": "int64", "shape": [-1]},
                  "target_sos_mask": {"type": "int64", "shape": [-1]},
                  "target_eos_ids": {"type": "int64", "shape": [-1]},
                  "target_eos_mask": {"type": "int64", "shape": [-1]},
                  "label": {"type": "int32"}}
        writer.add_schema(schema, "data is so cool")
        writer.write_raw_data(data)
        writer.commit()

        # change data value to list
        data_value_to_list = []
        for item in data:
            new_data = {}
            new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
            new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
            new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
            new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
            new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
            new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
            new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
            new_data['source_sos_ids'] = item["source_sos_ids"]
            new_data['source_sos_mask'] = item["source_sos_mask"]
            new_data['target_sos_ids'] = item["target_sos_ids"]
            new_data['target_sos_mask'] = item["target_sos_mask"]
            new_data['target_eos_ids'] = item["target_eos_ids"]
            new_data['target_eos_mask'] = item["target_eos_mask"]
            data_value_to_list.append(new_data)

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 13
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["source_sos_ids",
                                                "source_sos_mask", "target_sos_ids"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 3
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] == data[num_iter][field]).all()
                else:
                    assert item[field] == data[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 1
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 4
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 3
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["target_sos_ids",
                                                "image4", "source_sos_ids"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 3
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 3
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["target_sos_ids", "image5",
                                                "image4", "image3", "source_sos_ids"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 5
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 1
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["target_eos_mask", "image5",
                                                "image2", "source_sos_mask", "label"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 5
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["label", "target_eos_mask", "image1", "target_eos_ids",
                                                "source_sos_mask", "image2", "image4", "image3",
                                                "source_sos_ids", "image5", "file_name"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 11
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6
    except Exception as error:
        os.remove("{}".format(mindrecord_file_name))
        os.remove("{}.db".format(mindrecord_file_name))
        raise error
    else:
        os.remove("{}".format(mindrecord_file_name))
        os.remove("{}.db".format(mindrecord_file_name))


def test_write_with_multi_bytes_and_MindDataset():
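    """
    Write records whose payload fields are all bytes (plus a string file name
    and an int32 label), then read them back with MindDataset under several
    columns_list selections.
    """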
    mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    try:
        data = [{"file_name": "001.jpg", "label": 43,
                 "image1": bytes("image1 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image1 bytes def", encoding='UTF-8'),
                 "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image1 bytes mno", encoding='UTF-8')},
                {"file_name": "002.jpg", "label": 91,
                 "image1": bytes("image2 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image2 bytes def", encoding='UTF-8'),
                 "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image2 bytes mno", encoding='UTF-8')},
                {"file_name": "003.jpg", "label": 61,
                 "image1": bytes("image3 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image3 bytes def", encoding='UTF-8'),
                 "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image3 bytes mno", encoding='UTF-8')},
                {"file_name": "004.jpg", "label": 29,
                 "image1": bytes("image4 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image4 bytes def", encoding='UTF-8'),
                 "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image4 bytes mno", encoding='UTF-8')},
                {"file_name": "005.jpg", "label": 78,
                 "image1": bytes("image5 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image5 bytes def", encoding='UTF-8'),
                 "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image5 bytes mno", encoding='UTF-8')},
                {"file_name": "006.jpg", "label": 37,
                 "image1": bytes("image6 bytes abc", encoding='UTF-8'),
                 "image2": bytes("image6 bytes def", encoding='UTF-8'),
                 "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
                 "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
                 "image5": bytes("image6 bytes mno", encoding='UTF-8')}
                ]

        writer = FileWriter(mindrecord_file_name)
        schema = {"file_name": {"type": "string"},
                  "image1": {"type": "bytes"},
                  "image2": {"type": "bytes"},
                  "image3": {"type": "bytes"},
                  "label": {"type": "int32"},
                  "image4": {"type": "bytes"},
                  "image5": {"type": "bytes"}}
        writer.add_schema(schema, "data is so cool")
        writer.write_raw_data(data)
        writer.commit()

        # change data value to list
        data_value_to_list = []
        for item in data:
            new_data = {}
            new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
            new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
            new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
            new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
            new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
            new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
            new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
            data_value_to_list.append(new_data)

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 7
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["image1", "image2", "image5"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 3
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["image2", "image4"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 2
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["image5", "image2"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 2
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["image5", "image2", "label"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 3
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["image4", "image5",
                                                "image2", "image3", "file_name"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 5
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6
    except Exception as error:
        os.remove("{}".format(mindrecord_file_name))
        os.remove("{}.db".format(mindrecord_file_name))
        raise error
    else:
        os.remove("{}".format(mindrecord_file_name))
        os.remove("{}.db".format(mindrecord_file_name))


def test_write_with_multi_array_and_MindDataset():
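    """
    Write records whose fields are all variable-shape int64 arrays, then read
    them back with MindDataset under several columns_list selections.
    """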
    mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    try:
        data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
                 "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
                 "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
                {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
                 "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
                 "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
                {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
                 "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
                 "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
                {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
                 "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
                 "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
                {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
                 "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
                 "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
                {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
                 "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
                 "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
                 "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
                 "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
                 "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
                 "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
                 "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
                ]

        writer = FileWriter(mindrecord_file_name)
        schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
                  "source_sos_mask": {"type": "int64", "shape": [-1]},
                  "source_eos_ids": {"type": "int64", "shape": [-1]},
                  "source_eos_mask": {"type": "int64", "shape": [-1]},
                  "target_sos_ids": {"type": "int64", "shape": [-1]},
                  "target_sos_mask": {"type": "int64", "shape": [-1]},
                  "target_eos_ids": {"type": "int64", "shape": [-1]},
                  "target_eos_mask": {"type": "int64", "shape": [-1]}}
        writer.add_schema(schema, "data is so cool")
        writer.write_raw_data(data)
        writer.commit()

        # build the expected values - no conversion needed, arrays pass through unchanged
        data_value_to_list = []
        for item in data:
            new_data = {}
            new_data['source_sos_ids'] = item["source_sos_ids"]
            new_data['source_sos_mask'] = item["source_sos_mask"]
            new_data['source_eos_ids'] = item["source_eos_ids"]
            new_data['source_eos_mask'] = item["source_eos_mask"]
            new_data['target_sos_ids'] = item["target_sos_ids"]
            new_data['target_sos_mask'] = item["target_sos_mask"]
            new_data['target_eos_ids'] = item["target_eos_ids"]
            new_data['target_eos_mask'] = item["target_eos_mask"]
            data_value_to_list.append(new_data)

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 8
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["source_eos_ids", "source_eos_mask",
                                                "target_sos_ids", "target_sos_mask",
                                                "target_eos_ids", "target_eos_mask"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 6
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["source_sos_ids",
                                                "target_sos_ids",
                                                "target_eos_mask"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 3
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["target_eos_mask",
                                                "source_eos_mask",
                                                "source_sos_mask"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 3
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["target_eos_ids"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 1
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6

        num_readers = 1
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["target_eos_mask", "target_eos_ids",
                                                "target_sos_mask", "target_sos_ids",
                                                "source_eos_mask", "source_eos_ids",
                                                "source_sos_mask", "source_sos_ids"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 8
            for field in item:
                if isinstance(item[field], np.ndarray):
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 6
    except Exception as error:
        os.remove("{}".format(mindrecord_file_name))
        os.remove("{}.db".format(mindrecord_file_name))
        raise error
    else:
        os.remove("{}".format(mindrecord_file_name))
        os.remove("{}.db".format(mindrecord_file_name))


def test_numpy_generic():
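    """
    Write numpy generic scalars (int32/int64/float32/float64) across FILES_NUM
    shard files and read them back with MindDataset.
    """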
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(file_name, FILES_NUM)
        cv_schema_json = {"label1": {"type": "int32"}, "label2": {"type": "int64"},
                          "label3": {"type": "float32"}, "label4": {"type": "float64"}}
        data = []
        for idx in range(10):
            row = {}
            row['label1'] = np.int32(idx)
            row['label2'] = np.int64(idx * 10)
            row['label3'] = np.float32(idx + 0.12345)
            row['label4'] = np.float64(idx + 0.12345789)
            data.append(row)
        writer.add_schema(cv_schema_json, "img_schema")
        writer.write_raw_data(data)
        writer.commit()

        num_readers = 4
        data_set = ds.MindDataset(file_name + "0", None, num_readers, shuffle=False)
        assert data_set.get_dataset_size() == 10
        idx = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            # compare each read-back field with the value that was written
            assert item['label1'] == data[idx]['label1']
            assert item['label2'] == data[idx]['label2']
            assert item['label3'] == data[idx]['label3']
            assert item['label4'] == data[idx]['label4']
            idx += 1
        assert idx == 10
    except Exception as error:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))


def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset():
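    """
    Write float32/float64 scalars and arrays together with int32/int64 fields,
    then read them back, comparing float64 scalars with a relative tolerance.
    """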
    mindrecord_file_name = "test_write_with_float32_float64_float32_array_float64_array_and_MindDataset.mindrecord"
    try:
        data = [{"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
                 "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
                                            123414314.2141243, 87.1212122], dtype=np.float64),
                 "float32": 3456.12345,
                 "float64": 1987654321.123456785,
                 "int32_array": np.array([1, 2, 3, 4, 5], dtype=np.int32),
                 "int64_array": np.array([48, 49, 50, 51, 123414314, 87], dtype=np.int64),
                 "int32": 3456,
                 "int64": 947654321123},
                {"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
                 "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
                                            123414314.2141243, 87.1212122], dtype=np.float64),
                 "float32": 3456.12445,
                 "float64": 1987654321.123456786,
                 "int32_array": np.array([11, 21, 31, 41, 51], dtype=np.int32),
                 "int64_array": np.array([481, 491, 501, 511, 1234143141, 871], dtype=np.int64),
                 "int32": 3466,
                 "int64": 957654321123},
                {"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
                 "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
                                            123414314.2141243, 87.1212122], dtype=np.float64),
                 "float32": 3456.12545,
                 "float64": 1987654321.123456787,
                 "int32_array": np.array([12, 22, 32, 42, 52], dtype=np.int32),
                 "int64_array": np.array([482, 492, 502, 512, 1234143142, 872], dtype=np.int64),
                 "int32": 3476,
                 "int64": 967654321123},
                {"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
                 "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
                                            123414314.2141243, 87.1212122], dtype=np.float64),
                 "float32": 3456.12645,
                 "float64": 1987654321.123456788,
                 "int32_array": np.array([13, 23, 33, 43, 53], dtype=np.int32),
                 "int64_array": np.array([483, 493, 503, 513, 1234143143, 873], dtype=np.int64),
                 "int32": 3486,
                 "int64": 977654321123},
                {"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
                 "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
                                            123414314.2141243, 87.1212122], dtype=np.float64),
                 "float32": 3456.12745,
                 "float64": 1987654321.123456789,
                 "int32_array": np.array([14, 24, 34, 44, 54], dtype=np.int32),
                 "int64_array": np.array([484, 494, 504, 514, 1234143144, 874], dtype=np.int64),
                 "int32": 3496,
                 "int64": 987654321123},
                ]

        writer = FileWriter(mindrecord_file_name)
        schema = {"float32_array": {"type": "float32", "shape": [-1]},
                  "float64_array": {"type": "float64", "shape": [-1]},
                  "float32": {"type": "float32"},
                  "float64": {"type": "float64"},
                  "int32_array": {"type": "int32", "shape": [-1]},
                  "int64_array": {"type": "int64", "shape": [-1]},
                  "int32": {"type": "int32"},
                  "int64": {"type": "int64"}}
        writer.add_schema(schema, "data is so cool")
        writer.write_raw_data(data)
        writer.commit()

        # build the expected values - no conversion needed, values pass through unchanged
        data_value_to_list = []
        for item in data:
            new_data = {}
            new_data['float32_array'] = item["float32_array"]
            new_data['float64_array'] = item["float64_array"]
            new_data['float32'] = item["float32"]
            new_data['float64'] = item["float64"]
            new_data['int32_array'] = item["int32_array"]
            new_data['int64_array'] = item["int64_array"]
            new_data['int32'] = item["int32"]
            new_data['int64'] = item["int64"]
            data_value_to_list.append(new_data)

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 8
            for field in item:
                if isinstance(item[field], np.ndarray):
                    if item[field].dtype == np.float32:
                        assert (item[field] ==
                                np.array(data_value_to_list[num_iter][field], np.float32)).all()
                    else:
                        assert (item[field] ==
                                data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 5

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["float32", "int32"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 2
            for field in item:
                if isinstance(item[field], np.ndarray):
                    if item[field].dtype == np.float32:
                        assert (item[field] ==
                                np.array(data_value_to_list[num_iter][field], np.float32)).all()
                    else:
                        assert (item[field] ==
                                data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 5

        num_readers = 2
        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
                                  columns_list=["float64", "int64"],
                                  num_parallel_workers=num_readers,
                                  shuffle=False)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 2
            for field in item:
                if isinstance(item[field], np.ndarray):
                    if item[field].dtype == np.float32:
                        assert (item[field] ==
                                np.array(data_value_to_list[num_iter][field], np.float32)).all()
                    elif item[field].dtype == np.float64:
                        assert math.isclose(item[field],
                                            np.array(data_value_to_list[num_iter][field], np.float64),
                                            rel_tol=1e-14)
                    else:
                        assert (item[field] ==
                                data_value_to_list[num_iter][field]).all()
                else:
                    assert item[field] == data_value_to_list[num_iter][field]
            num_iter += 1
        assert num_iter == 5
    except Exception as error:
        os.remove("{}".format(mindrecord_file_name))
        os.remove("{}.db".format(mindrecord_file_name))
        raise error
    else:
        os.remove("{}".format(mindrecord_file_name))
        os.remove("{}.db".format(mindrecord_file_name))


@pytest.fixture
def create_multi_mindrecord_files():
    """files: {0.mindrecord : 10, 1.mindrecord : 14, 2.mindrecord : 8, 3.mindrecord : 20}"""
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    files = [file_name + str(idx) for idx in range(4)]
    items = [10, 14, 8, 20]
    file_items = {files[0]: items[0], files[1]: items[1], files[2]: items[2], files[3]: items[3]}
    try:
        index = 0
        for key in file_items:
            if os.path.exists(key):
                os.remove("{}".format(key))
                os.remove("{}.db".format(key))
            value = file_items[key]
            data_list = []
            for i in range(value):
                data = {}
                data['id'] = i + index
                data_list.append(data)
            index += value

            writer = FileWriter(key)
            schema = {"id": {"type": "int32"}}
            writer.add_schema(schema, "data is so cool")
            writer.write_raw_data(data_list)
            writer.commit()
        yield "yield_create_multi_mindrecord_files"
    except Exception as error:
        for filename in file_items:
            if os.path.exists(filename):
                os.remove("{}".format(filename))
                os.remove("{}.db".format(filename))
        raise error
    else:
        for filename in file_items:
            if os.path.exists(filename):
                os.remove("{}".format(filename))
                os.remove("{}.db".format(filename))


def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
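    """
    Iterate four MindRecord files under the default, False, True, GLOBAL,
    INFILE and FILES shuffle settings and compare the output order against
    the per-file sample order produced by the fixture.
    """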
    ds.config.set_seed(1)
    datas_all = []
    index = 0
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    files = [file_name + str(idx) for idx in range(4)]
    items = [10, 14, 8, 20]
    file_items = {files[0]: items[0], files[1]: items[1], files[2]: items[2], files[3]: items[3]}
    for filename in file_items:
        value = file_items[filename]
        data_list = []
        for i in range(value):
            data = {}
            data['id'] = np.array(i + index, dtype=np.int32)
            data_list.append(data)
        index += value
        datas_all.append(data_list)

    # no shuffle parameter
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers)
    assert data_set.get_dataset_size() == 52
    num_iter = 0
    datas_all_minddataset = []
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        if num_iter == 9:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 23:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 31:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 51:
            datas_all_minddataset.append(data_list)
            data_list = []
        num_iter += 1
    assert data_set.get_dataset_size() == 52

    assert len(datas_all) == len(datas_all_minddataset)
    for i, _ in enumerate(datas_all):
        assert len(datas_all[i]) == len(datas_all_minddataset[i])
        assert datas_all[i] != datas_all_minddataset[i]

    # shuffle=False
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              shuffle=False)
    assert data_set.get_dataset_size() == 52
    num_iter = 0
    datas_all_minddataset = []
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        if num_iter == 9:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 23:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 31:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 51:
            datas_all_minddataset.append(data_list)
            data_list = []
        num_iter += 1
    assert data_set.get_dataset_size() == 52

    assert len(datas_all) == len(datas_all_minddataset)
    for i, _ in enumerate(datas_all):
        assert len(datas_all[i]) == len(datas_all_minddataset[i])
        assert datas_all[i] == datas_all_minddataset[i]

    # shuffle=True
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              shuffle=True)
    assert data_set.get_dataset_size() == 52
    num_iter = 0
    datas_all_minddataset = []
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        if num_iter == 9:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 23:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 31:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 51:
            datas_all_minddataset.append(data_list)
            data_list = []
        num_iter += 1
    assert data_set.get_dataset_size() == 52

    assert len(datas_all) == len(datas_all_minddataset)
    for i, _ in enumerate(datas_all):
        assert len(datas_all[i]) == len(datas_all_minddataset[i])
        assert datas_all[i] != datas_all_minddataset[i]

    # shuffle=Shuffle.GLOBAL
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              shuffle=ds.Shuffle.GLOBAL)
    assert data_set.get_dataset_size() == 52
    num_iter = 0
    datas_all_minddataset = []
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        if num_iter == 9:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 23:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 31:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 51:
            datas_all_minddataset.append(data_list)
            data_list = []
        num_iter += 1
    assert data_set.get_dataset_size() == 52

    assert len(datas_all) == len(datas_all_minddataset)
    for i, _ in enumerate(datas_all):
        assert len(datas_all[i]) == len(datas_all_minddataset[i])
        assert datas_all[i] != datas_all_minddataset[i]

    # shuffle=Shuffle.INFILE
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              shuffle=ds.Shuffle.INFILE)
    assert data_set.get_dataset_size() == 52
    num_iter = 0
    datas_all_minddataset = []
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        if num_iter == 9:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 23:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 31:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 51:
            datas_all_minddataset.append(data_list)
            data_list = []
        num_iter += 1
    assert data_set.get_dataset_size() == 52

    def sort_list_with_dict(dict_in_list):
        keys = []
        for item in dict_in_list:
            for key in item:
                keys.append(int(item[key]))
        keys.sort()
        data_list = []
        for item in keys:
            data = {}
            data['id'] = np.array(item, dtype=np.int32)
            data_list.append(data)
        return data_list

    assert len(datas_all) == len(datas_all_minddataset)
    for i, _ in enumerate(datas_all):
        assert len(datas_all[i]) == len(datas_all_minddataset[i])
        assert datas_all[i] != datas_all_minddataset[i]
        # order the datas_all_minddataset
        new_datas_all_minddataset = sort_list_with_dict(datas_all_minddataset[i])
        assert datas_all[i] == new_datas_all_minddataset

    # shuffle=Shuffle.FILES
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              shuffle=ds.Shuffle.FILES)
    assert data_set.get_dataset_size() == 52
    num_iter = 0
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        num_iter += 1
    assert data_set.get_dataset_size() == 52

    current_shard_size = 0
    current_shard_index = 0
    shard_count = 0
    datas_index = 0
    origin_index = [i for i in range(len(items))]
    current_index = []
    while shard_count < len(items):
        if data_list[datas_index]['id'] < 10:
            current_shard_index = 0
        elif data_list[datas_index]['id'] < 24:
            current_shard_index = 1
        elif data_list[datas_index]['id'] < 32:
            current_shard_index = 2
        elif data_list[datas_index]['id'] < 52:
            current_shard_index = 3
        else:
            raise ValueError("Index out of range")
        current_shard_size = items[current_shard_index]

        tmp_datas = data_list[datas_index:datas_index + current_shard_size]
        current_index.append(current_shard_index)
        assert len(datas_all[current_shard_index]) == len(tmp_datas)
        assert datas_all[current_shard_index] == tmp_datas

        datas_index += current_shard_size
        shard_count += 1
    assert origin_index != current_index


def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_files):
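    """
    Combine each shuffle setting with num_shards=4/shard_id sharding over four
    MindRecord files and check the 13-sample shards against the written order.
    """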
    ds.config.set_seed(1)
    datas_all = []
    datas_all_samples = []
    index = 0
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    files = [file_name + str(idx) for idx in range(4)]
    items = [10, 14, 8, 20]
    file_items = {files[0]: items[0], files[1]: items[1], files[2]: items[2], files[3]: items[3]}
    for filename in file_items:
        value = file_items[filename]
        data_list = []
        for i in range(value):
            data = {}
            data['id'] = np.array(i + index, dtype=np.int32)
            data_list.append(data)
            datas_all_samples.append(data)
        index += value
        datas_all.append(data_list)

    # no shuffle parameter
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              num_shards=4,
                              shard_id=3)
    assert data_set.get_dataset_size() == 13
    num_iter = 0
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        num_iter += 1
    assert num_iter == 13
    assert data_list != datas_all_samples[3*13:]

    # shuffle=False
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              shuffle=False,
                              num_shards=4,
                              shard_id=2)
    assert data_set.get_dataset_size() == 13
    num_iter = 0
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        num_iter += 1
    assert num_iter == 13
    assert data_list == datas_all_samples[2*13:3*13]

    # shuffle=True
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              shuffle=True,
                              num_shards=4,
                              shard_id=1)
    assert data_set.get_dataset_size() == 13
    num_iter = 0
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        num_iter += 1
    assert num_iter == 13
    assert data_list != datas_all_samples[1*13:2*13]

    # shuffle=Shuffle.GLOBAL
    num_readers = 2
    data_set = ds.MindDataset(dataset_file=files,
                              num_parallel_workers=num_readers,
                              shuffle=ds.Shuffle.GLOBAL,
                              num_shards=4,
                              shard_id=0)
    assert data_set.get_dataset_size() == 13
    num_iter = 0
    data_list = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 1
        data_list.append(item)
        num_iter += 1
    assert num_iter == 13
    assert data_list != datas_all_samples[0:1*13]

    # shuffle=Shuffle.INFILE
    output_datas = []
    for shard_id in range(4):
        num_readers = 2
        data_set = ds.MindDataset(dataset_file=files,
                                  num_parallel_workers=num_readers,
                                  shuffle=ds.Shuffle.INFILE,
                                  num_shards=4,
                                  shard_id=shard_id)
        assert data_set.get_dataset_size() == 13
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 1
            output_datas.append(item)
            num_iter += 1
        assert num_iter == 13

    num_iter = 0
    datas_all_minddataset = []
    data_list = []
    for item in output_datas:
        assert len(item) == 1
        data_list.append(item)
        if num_iter == 9:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 23:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 31:
            datas_all_minddataset.append(data_list)
            data_list = []
        elif num_iter == 51:
            datas_all_minddataset.append(data_list)
            data_list = []
        num_iter += 1
    assert num_iter == 52

    def sort_list_with_dict(dict_in_list):
        keys = []
        for item in dict_in_list:
            for key in item:
                keys.append(int(item[key]))
        keys.sort()
        data_list = []
        for item in keys:
            data = {}
            data['id'] = np.array(item, dtype=np.int32)
            data_list.append(data)
        return data_list

    assert len(datas_all) == len(datas_all_minddataset)
    for i, _ in enumerate(datas_all):
        assert len(datas_all[i]) == len(datas_all_minddataset[i])
        assert datas_all[i] != datas_all_minddataset[i]
        # order the datas_all_minddataset
        new_datas_all_minddataset = sort_list_with_dict(datas_all_minddataset[i])
        assert datas_all[i] == new_datas_all_minddataset

    # shuffle=Shuffle.FILES
    data_list = []
    for shard_id in range(4):
        num_readers = 2
        data_set = ds.MindDataset(dataset_file=files,
                                  num_parallel_workers=num_readers,
                                  shuffle=ds.Shuffle.FILES,
                                  num_shards=4,
                                  shard_id=shard_id)
        assert data_set.get_dataset_size() == 13
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            assert len(item) == 1
            data_list.append(item)
            num_iter += 1
        assert num_iter == 13
    assert len(data_list) == 52

    current_shard_size = 0
    current_shard_index = 0
    shard_count = 0
    datas_index = 0
    origin_index = [i for i in range(len(items))]
    current_index = []
    while shard_count < len(items):
        if data_list[datas_index]['id'] < 10:
            current_shard_index = 0
        elif data_list[datas_index]['id'] < 24:
            current_shard_index = 1
        elif data_list[datas_index]['id'] < 32:
            current_shard_index = 2
        elif data_list[datas_index]['id'] < 52:
            current_shard_index = 3
        else:
            raise ValueError("Index out of range")
        current_shard_size = items[current_shard_index]

        tmp_datas = data_list[datas_index:datas_index + current_shard_size]
        current_index.append(current_shard_index)
        assert len(datas_all[current_shard_index]) == len(tmp_datas)
        assert datas_all[current_shard_index] == tmp_datas

        datas_index += current_shard_size
        shard_count += 1
    assert origin_index != current_index


def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
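    """
    Run each shuffle setting for three epochs over sharded MindRecord files and
    check that shuffled orders differ between epochs while shuffle=False keeps
    the written order.
    """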
  2235. ds.config.set_seed(1)
  2236. datas_all = []
  2237. datas_all_samples = []
  2238. index = 0
  2239. file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
  2240. files = [file_name + str(idx) for idx in range(4)]
  2241. items = [10, 14, 8, 20]
  2242. file_items = {files[0]: items[0], files[1]: items[1], files[2]: items[2], files[3]: items[3]}
  2243. for filename in file_items:
  2244. value = file_items[filename]
  2245. data_list = []
  2246. for i in range(value):
  2247. data = {}
  2248. data['id'] = np.array(i + index, dtype=np.int32)
  2249. data_list.append(data)
  2250. datas_all_samples.append(data)
  2251. index += value
  2252. datas_all.append(data_list)
  2253. epoch_size = 3

    # no shuffle parameter
    for shard_id in range(4):
        num_readers = 2
        data_set = ds.MindDataset(dataset_file=files,
                                  num_parallel_workers=num_readers,
                                  num_shards=4,
                                  shard_id=shard_id)
        assert data_set.get_dataset_size() == 13
        data_list = []
        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
        for epoch in range(epoch_size):  # 3 epochs
            num_iter = 0
            new_datas = []
            for item in dataset_iter:
                assert len(item) == 1
                new_datas.append(item)
                num_iter += 1
            assert num_iter == 13
            assert new_datas != datas_all_samples[shard_id*13:(shard_id+1)*13]
            assert data_list != new_datas
            data_list = new_datas
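
    # With no shuffle argument MindDataset defaults to a global shuffle, so
    # every epoch yields a fresh permutation; hence consecutive epochs are
    # asserted unequal above.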

    # shuffle=False
    for shard_id in range(4):
        num_readers = 2
        data_set = ds.MindDataset(dataset_file=files,
                                  num_parallel_workers=num_readers,
                                  shuffle=False,
                                  num_shards=4,
                                  shard_id=shard_id)
        assert data_set.get_dataset_size() == 13
        data_list = []
        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
        for epoch in range(epoch_size):  # 3 epochs
            num_iter = 0
            new_datas = []
            for item in dataset_iter:
                assert len(item) == 1
                new_datas.append(item)
                num_iter += 1
            assert num_iter == 13
            assert new_datas == datas_all_samples[shard_id*13:(shard_id+1)*13]
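
    # shuffle=False is fully deterministic: shard k reads rows
    # [13*k, 13*(k+1)) in written order and repeats the same sequence every
    # epoch.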

    # shuffle=True
    for shard_id in range(4):
        num_readers = 2
        data_set = ds.MindDataset(dataset_file=files,
                                  num_parallel_workers=num_readers,
                                  shuffle=True,
                                  num_shards=4,
                                  shard_id=shard_id)
        assert data_set.get_dataset_size() == 13
        data_list = []
        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
        for epoch in range(epoch_size):  # 3 epochs
            num_iter = 0
            new_datas = []
            for item in dataset_iter:
                assert len(item) == 1
                new_datas.append(item)
                num_iter += 1
            assert num_iter == 13
            assert new_datas != datas_all_samples[shard_id*13:(shard_id+1)*13]
            assert data_list != new_datas
            data_list = new_datas
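
    # shuffle=True is documented to behave like ds.Shuffle.GLOBAL, so the
    # expectations here match the Shuffle.GLOBAL block below.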

    # shuffle=Shuffle.GLOBAL
    for shard_id in range(4):
        num_readers = 2
        data_set = ds.MindDataset(dataset_file=files,
                                  num_parallel_workers=num_readers,
                                  shuffle=ds.Shuffle.GLOBAL,
                                  num_shards=4,
                                  shard_id=shard_id)
        assert data_set.get_dataset_size() == 13
        data_list = []
        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
        for epoch in range(epoch_size):  # 3 epochs
            num_iter = 0
            new_datas = []
            for item in dataset_iter:
                assert len(item) == 1
                new_datas.append(item)
                num_iter += 1
            assert num_iter == 13
            assert new_datas != datas_all_samples[shard_id*13:(shard_id+1)*13]
            assert data_list != new_datas
            data_list = new_datas

    # shuffle=Shuffle.INFILE
    for shard_id in range(4):
        num_readers = 2
        data_set = ds.MindDataset(dataset_file=files,
                                  num_parallel_workers=num_readers,
                                  shuffle=ds.Shuffle.INFILE,
                                  num_shards=4,
                                  shard_id=shard_id)
        assert data_set.get_dataset_size() == 13
        data_list = []
        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
        for epoch in range(epoch_size):  # 3 epochs
            num_iter = 0
            new_datas = []
            for item in dataset_iter:
                assert len(item) == 1
                new_datas.append(item)
                num_iter += 1
            assert num_iter == 13
            assert new_datas != datas_all_samples[shard_id*13:(shard_id+1)*13]
            assert data_list != new_datas
            data_list = new_datas
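
    # Shuffle.INFILE keeps the file order fixed but shuffles the rows within
    # each file, so the output still differs from the unshuffled sequence and
    # from one epoch to the next.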

    # shuffle=Shuffle.FILES
    datas_epoch1 = []
    datas_epoch2 = []
    datas_epoch3 = []
    for shard_id in range(4):
        num_readers = 2
        data_set = ds.MindDataset(dataset_file=files,
                                  num_parallel_workers=num_readers,
                                  shuffle=ds.Shuffle.FILES,
                                  num_shards=4,
                                  shard_id=shard_id)
        assert data_set.get_dataset_size() == 13
        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
        for epoch in range(epoch_size):  # 3 epochs
            num_iter = 0
            for item in dataset_iter:
                assert len(item) == 1
                if epoch == 0:
                    datas_epoch1.append(item)
                elif epoch == 1:
                    datas_epoch2.append(item)
                elif epoch == 2:
                    datas_epoch3.append(item)
                num_iter += 1
            assert num_iter == 13
    assert datas_epoch1 not in (datas_epoch2, datas_epoch3)
    assert datas_epoch2 not in (datas_epoch1, datas_epoch3)
    assert datas_epoch3 not in (datas_epoch2, datas_epoch1)
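
    # The `not in` membership checks are a compact way of asserting that the
    # three epoch traces are pairwise different: Shuffle.FILES draws a new
    # file-order permutation each epoch.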


def test_field_is_null_numpy():
    """Write records whose array_d field is an empty array, then read them back."""
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(file_name, FILES_NUM)
    data = []
    # field array_d is null; array_a/array_b hold boundary values around the
    # int8/int16/int32/int64 limits
    for row_id in range(16):
        data.append({
            "label": row_id,
            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                            2147483647], dtype=np.int32), [-1]),
            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
                                            256, -32768, 32767, -32769, 32768,
                                            -2147483648, 2147483647, -2147483649, 2147483649,
                                            -9223372036854775808, 9223372036854775807]), [1, -1]),
            "array_d": np.array([], dtype=np.int64)
        })
    nlp_schema_json = {"label": {"type": "int32"},
                       "array_a": {"type": "int32",
                                   "shape": [-1]},
                       "array_b": {"type": "int64",
                                   "shape": [1, -1]},
                       "array_d": {"type": "int64",
                                   "shape": [-1]}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.write_raw_data(data)
    writer.commit()

    data_set = ds.MindDataset(dataset_file=file_name + "0",
                              columns_list=["label", "array_a", "array_b", "array_d"],
                              num_parallel_workers=2,
                              shuffle=False)
    assert data_set.get_dataset_size() == 16
    # the scalar label and the empty array_d both report an empty shape
    assert data_set.output_shapes() == [[], [15], [1, 19], []]
    assert data_set.output_types()[0] == np.int32
    assert data_set.output_types()[1] == np.int32
    assert data_set.output_types()[2] == np.int64
    assert data_set.output_types()[3] == np.int64

    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))
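

# A minimal usage sketch, based on the assumption (not asserted by the test
# above) that a field written as an empty array reads back as a zero-size
# NumPy array. `_iter_nonempty_array_d` is a hypothetical helper named only
# for illustration.
def _iter_nonempty_array_d(data_set):
    """Yield only the array_d values that actually contain data."""
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        if item["array_d"].size > 0:
            yield item["array_d"]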


def test_for_loop_dataset_iterator(add_and_remove_nlp_compress_file):
    """Test creating dict iterators in a for loop, across epochs, and as multiple concurrent iterators."""
    data = []
    for row_id in range(16):
        data.append({
            "label": row_id,
            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
                                            2147483647], dtype=np.int32), [-1]),
            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
                                            256, -32768, 32767, -32769, 32768,
                                            -2147483648, 2147483647, -2147483649, 2147483649,
                                            -9223372036854775808, 9223372036854775807]), [1, -1]),
            "array_c": str.encode("nlp data"),
            "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
        })
    num_readers = 1
    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
    data_set = ds.MindDataset(
        file_name + "0", None, num_readers, shuffle=False)
    assert data_set.get_dataset_size() == 16

    # create_dict_iterator in for loop
    for _ in range(10):
        num_iter = 0
        for x, item in zip(data, data_set.create_dict_iterator(num_epochs=1, output_numpy=True)):
            assert (item["array_a"] == x["array_a"]).all()
            assert (item["array_b"] == x["array_b"]).all()
            assert item["array_c"].tobytes() == x["array_c"]
            assert (item["array_d"] == x["array_d"]).all()
            assert item["label"] == x["label"]
            num_iter += 1
        assert num_iter == 16
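
    # An iterator created with num_epochs=10 can be looped over ten times;
    # each `for` pass stops at an epoch boundary after 16 rows, which is why
    # zipping against `data * 10` below still yields 16 items per pass.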
    # create_dict_iterator beyond for loop
    dataset_iter = data_set.create_dict_iterator(num_epochs=10, output_numpy=True)
    new_data = data * 10
    for _ in range(10):
        num_iter = 0
        for x, item in zip(new_data, dataset_iter):
            assert (item["array_a"] == x["array_a"]).all()
            assert (item["array_b"] == x["array_b"]).all()
            assert item["array_c"].tobytes() == x["array_c"]
            assert (item["array_d"] == x["array_d"]).all()
            assert item["label"] == x["label"]
            num_iter += 1
        assert num_iter == 16

    # create multiple iterators by user
    dataset_iter2 = data_set.create_dict_iterator(num_epochs=1, output_numpy=True)
    assert (next(dataset_iter2)["array_a"] == data[0]["array_a"]).all()
    assert (next(dataset_iter2)["array_a"] == data[1]["array_a"]).all()
    dataset_iter3 = data_set.create_dict_iterator(num_epochs=1, output_numpy=True)
    assert (next(dataset_iter3)["array_a"] == data[0]["array_a"]).all()
    assert (next(dataset_iter3)["array_a"] == data[1]["array_a"]).all()
    assert (next(dataset_iter3)["array_a"] == data[2]["array_a"]).all()
    assert (next(dataset_iter2)["array_a"] == data[2]["array_a"]).all()
    assert (next(dataset_iter2)["array_a"] == data[3]["array_a"]).all()
    dataset_iter4 = data_set.create_dict_iterator(num_epochs=1, output_numpy=True)
    assert (next(dataset_iter4)["array_a"] == data[0]["array_a"]).all()
    assert (next(dataset_iter4)["array_a"] == data[1]["array_a"]).all()
    assert (next(dataset_iter4)["array_a"] == data[2]["array_a"]).all()
    assert (next(dataset_iter3)["array_a"] == data[3]["array_a"]).all()
    assert (next(dataset_iter3)["array_a"] == data[4]["array_a"]).all()
    assert (next(dataset_iter3)["array_a"] == data[5]["array_a"]).all()
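
    # Iterators created from the same dataset advance independently: the
    # interleaved next() calls on dataset_iter2/3/4 above each walk their own
    # iterator through rows 0, 1, 2, ... in order.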


if __name__ == '__main__':
    test_nlp_compress_data(add_and_remove_nlp_compress_file)
    test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file)
    test_cv_minddataset_writer_tutorial()
    test_cv_minddataset_partition_tutorial(add_and_remove_cv_file)
    test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file)
    test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file)
    test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file)
    test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file)
    test_cv_minddataset_partition_tutorial_check_whole_reshuffle_result_per_epoch(add_and_remove_cv_file)
    test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file)
    test_cv_minddataset_dataset_size(add_and_remove_cv_file)
    test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file)
    test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file)
    test_cv_minddataset_issue_888(add_and_remove_cv_file)
    test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file)
    test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file)
    test_cv_minddataset_reader_file_list(add_and_remove_cv_file)
    test_cv_minddataset_reader_one_partition(add_and_remove_cv_file)
    test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file)
    test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file)
    test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file)
    test_nlp_minddataset_reader_basic_tutorial(add_and_remove_cv_file)
    test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file)
    test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_cv_file)
    test_cv_minddataset_reader_no_columns(add_and_remove_cv_file)
    test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file)
    test_write_with_multi_bytes_and_array_and_read_by_MindDataset()
    test_write_with_multi_bytes_and_MindDataset()
    test_write_with_multi_array_and_MindDataset()
    test_numpy_generic()
    test_write_with_float32_float64_float32_array_float64_array_and_MindDataset()
    test_shuffle_with_global_infile_files(create_multi_mindrecord_files)
    test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_files)
    test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files)
    test_field_is_null_numpy()
    test_for_loop_dataset_iterator(add_and_remove_nlp_compress_file)