

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """
  16. datasets.py supports various formats of datasets, including ImageNet, TFData,
  17. MNIST, Cifar10/100, Manifest, MindRecord, etc. This module could load data in
  18. high performance and parse data precisely. It also provides the following
  19. operations for users to preprocess data: shuffle, batch, repeat, map, and zip.
  20. """
import glob
import json
import math
import os
import uuid
import multiprocessing
import queue
from enum import Enum
from importlib import import_module
import threading
import copy
import numpy as np

from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
    MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
    check_rename, check_numpyslicesdataset, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
    check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:
    context = import_module("mindspore.context")
except ModuleNotFoundError:
    context = None


class Shuffle(str, Enum):
    GLOBAL: str = "global"
    FILES: str = "file"


@check_zip
def zip(datasets):
    """
    Zip the datasets in the input tuple of datasets.

    Args:
        datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
            The number of datasets should be more than 1.

    Returns:
        DatasetOp, ZipDataset.

    Raises:
        ValueError: If the number of datasets is 1.
        TypeError: If datasets is not a tuple.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> dataset_dir1 = "path/to/imagefolder_directory1"
        >>> dataset_dir2 = "path/to/imagefolder_directory2"
        >>> ds1 = ds.ImageFolderDatasetV2(dataset_dir1, num_parallel_workers=8)
        >>> ds2 = ds.ImageFolderDatasetV2(dataset_dir2, num_parallel_workers=8)
        >>>
        >>> # creates a dataset which is the combination of ds1 and ds2
        >>> data = ds.zip((ds1, ds2))
    """
    if len(datasets) <= 1:
        raise ValueError(
            "Can't zip empty or just one dataset!")
    return ZipDataset(datasets)


def get_num_rows(num_rows, num_shards):
    """
    Get the number of rows of the dataset according to the shards.

    Args:
        num_rows (int): The number of rows of the dataset. It should be more than 0.
        num_shards (int or None): Number of shards that the dataset should be divided into.
            The number of shards should be None or more than 1.

    Returns:
        Int, number of rows.

    Raises:
        ValueError: If num_rows is invalid (< 0).
        ValueError: If num_shards is invalid (<= 0).
    """
    if num_rows < 0:
        raise ValueError("num_rows is invalid (< 0)")
    if num_shards is not None:
        if num_shards <= 0:
            raise ValueError("num_shards is invalid (<= 0)")
        if num_rows % num_shards == 0:
            num_rows = num_rows // num_shards
        else:
            num_rows = num_rows // num_shards + 1
    return num_rows
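
# A worked instance of the sharding arithmetic above (hypothetical numbers):
# 10 rows divided across 3 shards gives 10 // 3 + 1 = 4 rows per shard, i.e.
# the per-shard row count is the ceiling of num_rows / num_shards.
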


class Dataset:
    """
    Abstract class to represent a dataset in DataEngine's data pipeline.

    This class is the base class of SourceDataset and DatasetOp, and represents
    a node in the data flow graph.

    Args:
        num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
            (default=None).
    """

    def __init__(self, num_parallel_workers=None):
        self.children = []
        self.parent = []
        self.num_parallel_workers = num_parallel_workers
        self._device_iter = 0
        self._input_indexs = ()
        self._output_types = None
        self._output_shapes = None
        self._dataset_size = None
        self._batch_size = None
        self._num_classes = None
        self._repeat_count = None
        self._sync = False
        self.ms_role = os.getenv("MS_ROLE")

    def _noop_mode(self):
        if self.ms_role in ("MS_PSERVER", "MS_SCHED"):
            return True
        return False

    def __add__(self, datasets):
        return self.concat(datasets)

    def get_args(self):
        """
        Return attributes (member variables) related to the current class.

        Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.

        Returns:
            Python dictionary.
        """
        args = dict()
        args["num_parallel_workers"] = self.num_parallel_workers
        return args

    @check_bucket_batch_by_length
    def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes,
                               element_length_function=None, pad_info=None,
                               pad_to_bucket_boundary=False, drop_remainder=False):
        """
        Bucket elements according to their lengths, and pad and batch the buckets when
        they are full.

        A length function is called on each row in the dataset. The row is then
        bucketed based on its length and bucket_boundaries. When a bucket reaches its
        corresponding size specified in bucket_batch_sizes, the entire bucket will be
        padded according to pad_info, and then batched. Each batch will be full,
        except possibly for the last batch of each bucket.

        Args:
            column_names (list of string): Columns passed to element_length_function.
            bucket_boundaries (list of int): A list consisting of the upper boundaries
                of the buckets. Must be strictly increasing. If there are n boundaries,
                n+1 buckets are created: one bucket for [0, bucket_boundaries[0]), one
                bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
                0<i<n, and one bucket for [bucket_boundaries[n-1], inf).
            bucket_batch_sizes (list of int): A list consisting of the batch sizes for
                each bucket. Must contain len(bucket_boundaries)+1 elements.
            element_length_function (Callable, optional): A function that takes in
                len(column_names) arguments and returns an int. If no value is
                provided, then len(column_names) must be 1, and the size of the first
                dimension of that column will be taken as the length (default=None).
            pad_info (dict, optional): Represents how to pad each column. The key
                corresponds to the column name, the value must be a tuple of 2 elements.
                The first element corresponds to the shape to pad to, and the second
                element corresponds to the value to pad with. If a column is not
                specified, then that column will be padded to the longest in the current
                batch, and 0 will be used as the padding value. Any None dimensions will
                be padded to the longest in the current batch, unless
                pad_to_bucket_boundary is True. If no padding is wanted, set pad_info
                to None (default=None).
            pad_to_bucket_boundary (bool, optional): If True, will pad each None
                dimension in pad_info to the bucket_boundary minus 1. If there are any
                elements that fall into the last bucket, an error will occur
                (default=False).
            drop_remainder (bool, optional): If True, will drop the last batch for each
                bucket if it is not a full batch (default=False).

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>>
            >>> # Bucket rows by max(len(col1), len(col2)), pad them, and batch each
            >>> # bucket once it reaches its batch size.
            >>> column_names = ["col1", "col2"]
            >>> bucket_boundaries = [5, 10]
            >>> bucket_batch_sizes = [5, 1, 1]
            >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
            >>>
            >>> # will pad col1 to shape [2, bucket_boundaries[i]] where i is the
            >>> # index of the bucket that is currently being batched.
            >>> # will pad col2 to a shape where each dimension is the longest in all
            >>> # the elements currently being batched.
            >>> pad_info = {"col1": ([2, None], -1)}
            >>> pad_to_bucket_boundary = True
            >>>
            >>> data = data.bucket_batch_by_length(column_names, bucket_boundaries,
            >>>                                    bucket_batch_sizes,
            >>>                                    element_length_function, pad_info,
            >>>                                    pad_to_bucket_boundary)
        """
        return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes,
                                          element_length_function, pad_info,
                                          pad_to_bucket_boundary, drop_remainder)
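
    # Worked instance of the boundary rule documented above: bucket_boundaries=[5, 10]
    # creates three buckets, [0, 5), [5, 10) and [10, inf), so bucket_batch_sizes
    # must contain exactly three batch sizes.
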

    @check_batch
    def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
              input_columns=None, pad_info=None):
        """
        Combine batch_size number of consecutive rows into batches.

        For any child node, a batch is treated as a single row.
        For any column, all the elements within that column must have the same shape.
        If a per_batch_map callable is provided, it will be applied to the batches of tensors.

        Note:
            The order of using repeat and batch affects the number of batches. It is recommended
            that the repeat operation be used after the batch operation.

        Args:
            batch_size (int or function): The number of rows each batch is created with. An
                int or callable which takes exactly 1 parameter, BatchInfo.
            drop_remainder (bool, optional): Determines whether or not to drop the last
                possibly incomplete batch (default=False). If True, and if there are less
                than batch_size rows available to make the last batch, then those rows will
                be dropped and not propagated to the child node.
            num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None).
            per_batch_map (callable, optional): Per batch map callable. A callable which takes
                (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch
                of Tensors on a given column. The number of lists should match the number of entries in input_columns.
                The last parameter of the callable should always be a BatchInfo object.
            input_columns (list of string, optional): List of names of the input columns. The size of the list should
                match the signature of the per_batch_map callable.
            pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
                would pad the column with name "col1" to a tensor of size [224,224] and fill the missing values with 0.

        Returns:
            BatchDataset, dataset batched.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> # creates a dataset where every 100 rows is combined into a batch
            >>> # and drops the last incomplete batch if there is one.
            >>> data = data.batch(100, True)
        """
        return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
                            pad_info)

    @check_sync_wait
    def sync_wait(self, condition_name, num_batch=1, callback=None):
        """
        Add a blocking condition to the input Dataset.

        Args:
            condition_name (str): The condition name that is used to toggle sending the next row.
            num_batch (int): The number of batches without blocking at the start of each epoch.
            callback (function): The callback function that will be invoked when sync_update is called.

        Raises:
            RuntimeError: If the condition name already exists.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> data = data.sync_wait("callback1")
            >>> data = data.batch(batch_size)
            >>> for batch_data in data.create_dict_iterator():
            >>>     data = data.sync_update("callback1")
        """
        return SyncWaitDataset(self, condition_name, num_batch, callback)

    @check_shuffle
    def shuffle(self, buffer_size):
        """
        Randomly shuffle the rows of this dataset using the following algorithm:

        1. Make a shuffle buffer that contains the first buffer_size rows.
        2. Randomly select an element from the shuffle buffer to be the next row
           propagated to the child node.
        3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
        4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.

        A seed can be provided to be used on the first epoch. In every subsequent
        epoch, the seed is changed to a new, randomly generated value.

        Args:
            buffer_size (int): The size of the buffer (must be larger than 1) for
                shuffling. Setting buffer_size equal to the number of rows in the entire
                dataset will result in a global shuffle.

        Returns:
            ShuffleDataset, dataset shuffled.

        Raises:
            RuntimeError: If a sync operation exists in the pipeline before the shuffle.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object
            >>> # optionally set the seed for the first epoch
            >>> ds.config.set_seed(58)
            >>>
            >>> # creates a shuffled dataset using a shuffle buffer of size 4
            >>> data = data.shuffle(4)
        """
        return ShuffleDataset(self, buffer_size)
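
    # An illustrative pure-Python sketch of the buffer-shuffle algorithm described
    # in the docstring above (the real shuffle runs inside the C++ data engine;
    # this is only a model of its behaviour):
    #
    #   import random
    #   def buffer_shuffle(rows, buffer_size, seed):
    #       rng, buf = random.Random(seed), []
    #       for row in rows:
    #           buf.append(row)                               # steps 1 and 3: keep the buffer filled
    #           if len(buf) >= buffer_size:
    #               yield buf.pop(rng.randrange(len(buf)))    # step 2: emit a random element
    #       while buf:
    #           yield buf.pop(rng.randrange(len(buf)))        # step 4: drain the buffer
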

    def flat_map(self, func):
        """
        Map `func` to each row in the dataset and flatten the result.

        The specified `func` is a function that must take one 'Ndarray' as input
        and return a 'Dataset'.

        Args:
            func (function): A function that must take one 'Ndarray' as an argument and
                return a 'Dataset'.

        Returns:
            Dataset, applied by the function.

        Examples:
            >>> import mindspore.dataset as ds
            >>> import mindspore.dataset.text as text
            >>> # declare a function which returns a Dataset object
            >>> def flat_map_func(x):
            >>>     data_dir = text.to_str(x[0])
            >>>     d = ds.ImageFolderDatasetV2(data_dir)
            >>>     return d
            >>> # data is a Dataset object
            >>> data = ds.TextFileDataset(DATA_FILE)
            >>> data = data.flat_map(flat_map_func)

        Raises:
            TypeError: If `func` is not a function.
            TypeError: If `func` doesn't return a Dataset.
        """
        dataset = None
        if not hasattr(func, '__call__'):
            logger.error("func must be a function.")
            raise TypeError("func must be a function.")
        for row_data in self:
            if dataset is None:
                dataset = func(row_data)
            else:
                dataset += func(row_data)
        if not isinstance(dataset, Dataset):
            logger.error("flat_map must return a Dataset object.")
            raise TypeError("flat_map must return a Dataset object.")
        return dataset

    @check_map
    def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
            num_parallel_workers=None, python_multiprocessing=False, cache=None):
        """
        Apply each operation in operations to this dataset.

        The order of operations is determined by the position of each operation in operations.
        operations[0] will be applied first, then operations[1], then operations[2], etc.

        Each operation will be passed one or more columns from the dataset as input, and zero or
        more columns will be outputted. The first operation will be passed the columns specified
        in input_columns as input. If there is more than one operator in operations, the outputted
        columns of the previous operation are used as the input columns for the next operation.
        The columns outputted by the very last operation will be assigned names specified by
        output_columns.

        Only the columns specified in columns_order will be propagated to the child node. These
        columns will be in the same order as specified in columns_order.

        Args:
            input_columns (list[str]): List of the names of the columns that will be passed to
                the first operation as input. The size of this list must match the number of
                input columns expected by the first operator. (default=None, the first
                operation will be passed however many columns are required, starting from
                the first column).
            operations (list[TensorOp] or Python list[functions]): List of operations to be
                applied on the dataset. Operations are applied in the order they appear in this list.
            output_columns (list[str], optional): List of names assigned to the columns outputted by
                the last operation. This parameter is mandatory if len(input_columns) !=
                len(output_columns). The size of this list must match the number of output
                columns of the last operation. (default=None, output columns will have the same
                name as the input columns, i.e., the columns will be replaced).
            columns_order (list[str], optional): List of all the desired columns to propagate to the
                child node. This list must be a subset of all the columns in the dataset after
                all operations are applied. The order of the columns in each row propagated to the
                child node follows the order they appear in this list. The parameter is mandatory
                if len(input_columns) != len(output_columns). (default=None, all columns
                will be propagated to the child node, the order of the columns will remain the
                same).
            num_parallel_workers (int, optional): Number of threads used to process the dataset in
                parallel (default=None, the value from the config will be used).
            python_multiprocessing (bool, optional): Parallelize python operations with multiple worker
                processes. This option could be beneficial if the python operation is computationally
                heavy (default=False).
            cache (DatasetCache, optional): Tensor cache to use. (default=None, which means no cache is used)

        Returns:
            MapDataset, dataset after mapping operation.

        Examples:
            >>> import mindspore.dataset as ds
            >>> import mindspore.dataset.transforms.vision.c_transforms as c_transforms
            >>>
            >>> # data is an instance of Dataset which has 2 columns, "image" and "label".
            >>> # ds_pyfunc is an instance of Dataset which has 3 columns, "col0", "col1", and "col2". Each column is
            >>> # a 2d array of integers.
            >>>
            >>> # This config is a global setting, meaning that all future operations which
            >>> # use this config value will use 2 worker threads, unless otherwise specified
            >>> # in their constructor. set_num_parallel_workers can be called again later
            >>> # if a different number of worker threads is needed.
            >>> ds.config.set_num_parallel_workers(2)
            >>>
            >>> # Two operations, each of which takes 1 column as input and outputs 1 column.
            >>> decode_op = c_transforms.Decode(rgb_format=True)
            >>> random_jitter_op = c_transforms.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
            >>>
            >>> # 1) Simple map example
            >>>
            >>> operations = [decode_op]
            >>> input_columns = ["image"]
            >>>
            >>> # Applies decode_op on column "image". This column will be replaced by the outputted
            >>> # column of decode_op. Since columns_order is not provided, both columns "image"
            >>> # and "label" will be propagated to the child node in their original order.
            >>> ds_decoded = data.map(input_columns, operations)
            >>>
            >>> # Rename column "image" to "decoded_image"
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(input_columns, operations, output_columns)
            >>>
            >>> # Specify the order of the columns.
            >>> columns_order = ["label", "image"]
            >>> ds_decoded = data.map(input_columns, operations, None, columns_order)
            >>>
            >>> # Rename column "image" to "decoded_image" and also specify the order of the columns.
            >>> columns_order = ["label", "decoded_image"]
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(input_columns, operations, output_columns, columns_order)
            >>>
            >>> # Rename column "image" to "decoded_image" and keep only this column.
            >>> columns_order = ["decoded_image"]
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(input_columns, operations, output_columns, columns_order)
            >>>
            >>> # Simple example using pyfunc. Renaming columns and specifying column order
            >>> # work in the same way as the previous examples.
            >>> input_columns = ["col0"]
            >>> operations = [(lambda x: x + 1)]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations)
            >>>
            >>> # 2) Map example with more than one operation
            >>>
            >>> # If this list of operations is used with map, decode_op will be applied
            >>> # first, then random_jitter_op will be applied.
            >>> operations = [decode_op, random_jitter_op]
            >>>
            >>> input_columns = ["image"]
            >>>
            >>> # Creates a dataset where the images are decoded, then randomly color jittered.
            >>> # decode_op takes column "image" as input and outputs one column. The column
            >>> # outputted by decode_op is passed as input to random_jitter_op.
            >>> # random_jitter_op will output one column. Column "image" will be replaced by
            >>> # the column outputted by random_jitter_op (the very last operation). All other
            >>> # columns are unchanged. Since columns_order is not specified, the order of the
            >>> # columns will remain the same.
            >>> ds_mapped = data.map(input_columns, operations)
            >>>
            >>> # Creates a dataset that is identical to ds_mapped, except the column "image"
            >>> # that is outputted by random_jitter_op is renamed to "image_transformed".
            >>> # Specifying column order works in the same way as examples in 1).
            >>> output_columns = ["image_transformed"]
            >>> ds_mapped_and_renamed = data.map(input_columns, operations, output_columns)
            >>>
            >>> # Multiple operations using pyfunc. Renaming columns and specifying column order
            >>> # work in the same way as examples in 1).
            >>> input_columns = ["col0"]
            >>> operations = [(lambda x: x + x), (lambda x: x - 1)]
            >>> output_columns = ["col0_mapped"]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns)
            >>>
            >>> # 3) Example where number of input columns is not equal to number of output columns
            >>>
            >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns.
            >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column.
            >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns.
            >>> #
            >>> # Note: the number of output columns of operation[i] must equal the number of
            >>> # input columns of operation[i+1]. Otherwise, this map call will also result
            >>> # in an error.
            >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
            >>>               (lambda x, y, z: x * y * z),
            >>>               (lambda x: (x % 2, x % 3, x % 5, x % 7))]
            >>>
            >>> # Note: because the number of input columns is not the same as the number of
            >>> # output columns, the output_columns and columns_order parameters must be
            >>> # specified. Otherwise, this map call will also result in an error.
            >>> input_columns = ["col2", "col0"]
            >>> output_columns = ["mod2", "mod3", "mod5", "mod7"]
            >>>
            >>> # Propagate all columns to the child node in this order:
            >>> columns_order = ["col0", "col2", "mod2", "mod3", "mod5", "mod7", "col1"]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
            >>>
            >>> # Propagate some columns to the child node in this order:
            >>> columns_order = ["mod7", "mod3", "col1"]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
        """
        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
                          python_multiprocessing, cache)

    @check_filter
    def filter(self, predicate, input_columns=None, num_parallel_workers=1):
        """
        Filter dataset by predicate.

        Note:
            If input_columns is not provided or is empty, all columns will be used.

        Args:
            predicate (callable): Python callable which returns a boolean value. Rows for which it
                returns False are filtered out.
            input_columns (list[str], optional): List of names of the input columns. If left as the
                default None, the predicate will be applied to all columns in the dataset.
            num_parallel_workers (int, optional): Number of workers to process the Dataset
                in parallel (default=1).

        Returns:
            FilterDataset, dataset filtered.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # generator data(0 ~ 63)
            >>> # filter out the data that is greater than or equal to 11
            >>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns=["data"])
        """
        return FilterDataset(self, predicate, input_columns, num_parallel_workers)

    @check_repeat
    def repeat(self, count=None):
        """
        Repeat this dataset count times. Repeat indefinitely if the count is None or -1.

        Note:
            The order of using repeat and batch affects the number of batches. It is recommended
            that the repeat operation be used after the batch operation.
            If dataset_sink_mode is False, the repeat operation here is invalid.
            If dataset_sink_mode is True, the repeat count should be equal to the number of training
            epochs. Otherwise, errors could occur since the amount of data is not the amount that
            training requires.

        Args:
            count (int): Number of times the dataset should be repeated (default=None).

        Returns:
            RepeatDataset, dataset repeated.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> # creates a dataset where the dataset is repeated for 50 epochs
            >>> repeated = data.repeat(50)
            >>>
            >>> # creates a dataset where each epoch is shuffled individually
            >>> shuffled_and_repeated = data.shuffle(10)
            >>> shuffled_and_repeated = shuffled_and_repeated.repeat(50)
            >>>
            >>> # creates a dataset where the dataset is first repeated for
            >>> # 50 epochs before shuffling. The shuffle operator will treat
            >>> # the entire 50 epochs as one big dataset.
            >>> repeat_and_shuffle = data.repeat(50)
            >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
        """
        if count == 1:
            return self
        return RepeatDataset(self, count)
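
    # Arithmetic illustration of the ordering note above (hypothetical row count):
    # with 100 rows, data.batch(32).repeat(2) yields ceil(100 / 32) * 2 = 8 batches,
    # while data.repeat(2).batch(32) yields ceil(200 / 32) = 7 batches.
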

    @check_skip
    def skip(self, count):
        """
        Skip the first N elements of this dataset.

        Args:
            count (int): Number of elements in the dataset to be skipped.

        Returns:
            SkipDataset, dataset skipped.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> # creates a dataset which skips the first 3 elements from data
            >>> data = data.skip(3)
        """
        return SkipDataset(self, count)

    @check_take
    def take(self, count=-1):
        """
        Take at most the given number of elements from the dataset.

        Note:
            1. If count is greater than the number of elements in the dataset or equal to -1,
               all the elements in the dataset will be taken.
            2. The order of using take and batch matters. If take comes before the batch operation,
               it takes the given number of rows; otherwise it takes the given number of batches.

        Args:
            count (int, optional): Number of elements to be taken from the dataset (default=-1).

        Returns:
            TakeDataset, dataset taken.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> # creates a dataset containing 50 elements.
            >>> data = data.take(50)
        """
        if count == -1:
            return self
        return TakeDataset(self, count)
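
    # Illustration of the take/batch ordering note above: data.take(50).batch(10)
    # yields 5 batches of 10 rows, whereas data.batch(10).take(50) takes up to
    # 50 batches, because take counts batches once it follows batch.
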

    def _get_absolute_split_sizes(self, sizes):
        """
        Internal method called by split to calculate absolute split sizes and to
        do some error checking after calculating absolute split sizes.
        """
        # Call get_dataset_size and validate the input here, so that the work is not
        # done once in check_split and then again here.
        dataset_size = self.get_dataset_size()
        if dataset_size is None or dataset_size <= 0:
            raise RuntimeError("dataset_size is unknown, unable to split.")

        if not isinstance(sizes, list):
            raise RuntimeError("sizes should be a list.")

        all_int = all(isinstance(item, int) for item in sizes)
        if all_int:
            sizes_sum = sum(sizes)
            if sizes_sum != dataset_size:
                raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
                                   .format(sizes_sum, dataset_size))
            return sizes

        absolute_sizes = []
        for item in sizes:
            absolute_size = int(round(item * dataset_size))
            if absolute_size == 0:
                raise RuntimeError("Split percentage {} is too small.".format(item))
            absolute_sizes.append(absolute_size)

        absolute_sizes_sum = sum(absolute_sizes)

        # if we still need more rows, give them to the first split.
        # if we have too many rows, remove the extras from the first split that has
        # enough rows.
        size_difference = int(dataset_size - absolute_sizes_sum)
        if size_difference > 0:
            absolute_sizes[0] += size_difference
        else:
            for i, _ in enumerate(absolute_sizes):
                if absolute_sizes[i] + size_difference > 0:
                    absolute_sizes[i] += size_difference
                    break

        if sum(absolute_sizes) != dataset_size:
            raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
                               .format(sum(absolute_sizes), dataset_size))

        return absolute_sizes
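
    # Worked example of the float branch above (hypothetical sizes): sizes=[0.7, 0.3]
    # on a 15-row dataset rounds to absolute sizes [10, 4], which sum to 14; the
    # 1-row shortfall is then added to the first split, giving [11, 4].
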

    @check_split
    def split(self, sizes, randomize=True):
        """
        Split the dataset into smaller, non-overlapping datasets.

        This is a general purpose split function which can be called from any operator in the pipeline.
        There is another, optimized split function, which will be called automatically if ds.split is
        called where ds is a MappableDataset.

        Args:
            sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is
                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
                respectively. If the sum of all sizes does not equal the original dataset size, an
                error will occur.
                If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
                and must sum to 1, otherwise an error will occur. The dataset will be split into n
                Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
                original dataset.
                If after rounding:

                - Any size equals 0, an error will occur.
                - The sum of split sizes < K, the difference will be added to the first split.
                - The sum of split sizes > K, the difference will be removed from the first large
                  enough split such that it will have at least 1 row after removing the difference.

            randomize (bool, optional): Determines whether or not to split the data randomly (default=True).
                If True, the data will be randomly split. Otherwise, each split will be created with
                consecutive rows from the dataset.

        Note:
            1. Dataset cannot be sharded if split is going to be called.
            2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
               Shuffling the dataset may not be deterministic, which means the data in each split
               will be different in each epoch.

        Raises:
            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
            RuntimeError: If sizes is a list of integers and the sum of all elements in sizes does not
                equal the dataset size.
            RuntimeError: If sizes is a list of floats and there is a split with size 0 after calculations.
            RuntimeError: If the dataset is sharded prior to calling split.
            ValueError: If sizes is a list of floats and not all floats are between 0 and 1, or if the
                floats don't sum to 1.

        Returns:
            tuple(Dataset), a tuple of datasets that have been split.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> dataset_dir = "/path/to/text_file.txt"
            >>>
            >>> # TextFileDataset is not a mappable dataset, so this non-optimized split will be called.
            >>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
            >>> data = ds.TextFileDataset(dataset_dir, shuffle=False)
            >>> train, test = data.split([0.9, 0.1])
        """
        if self.is_shuffled():
            logger.warning("Dataset is shuffled before split.")

        if self.is_sharded():
            raise RuntimeError("Dataset should not be sharded before split.")

        absolute_sizes = self._get_absolute_split_sizes(sizes)
        splits = []
        rows_to_skip = 0
        for size in absolute_sizes:
            ds = copy.deepcopy(self)
            if randomize:
                # want to shuffle the same way every epoch before split
                # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
                ds = ds.shuffle(10000)
                ds.reshuffle_each_epoch = False

            if rows_to_skip > 0:
                ds = ds.skip(rows_to_skip)

            ds = ds.take(size)
            splits.append(ds)

            rows_to_skip += size

        return tuple(splits)

    @check_zip_dataset
    def zip(self, datasets):
        """
        Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.

        Args:
            datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset
                to be zipped together with this dataset.

        Returns:
            ZipDataset, dataset zipped.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # ds1 and ds2 are instances of Dataset object
            >>> # creates a dataset which is the combination of ds1 and ds2
            >>> data = ds1.zip(ds2)
        """
        if isinstance(datasets, tuple):
            datasets = (self, *datasets)
        elif isinstance(datasets, Dataset):
            datasets = (self, datasets)
        else:
            raise TypeError("Invalid datasets type for zip(): %s." % (datasets))
        return ZipDataset(datasets)
  711. @check_concat
  712. def concat(self, datasets):
  713. """
  714. Concat the datasets in the input list of datasets, supported using "+" to reload concat operation.
  715. Note:
  716. The column name,column data type and rank of column data should be the same in input datasets.
  717. Args:
  718. datasets (list or class Dataset): A list of datasets or a single class Dataset
  719. to be concatenated together with this dataset.
  720. Returns:
  721. ConcatDataset, dataset concatenated.
  722. Examples:
  723. >>> import mindspore.dataset as ds
  724. >>> # ds1 and ds2 are instances of Dataset object
725. >>> # creates a dataset by concatenating ds1 and ds2 with "+" operation
726. >>> data1 = ds1 + ds2
727. >>> # creates a dataset by concatenating ds1 and ds2 with concat operation
  728. >>> data1 = ds1.concat(ds2)
  729. """
  730. if isinstance(datasets, Dataset):
  731. datasets = [self] + [datasets]
  732. elif isinstance(datasets, list):
  733. datasets = [self] + datasets
  734. else:
735. raise TypeError("Invalid datasets argument for concat(): %s." % (datasets,))
  736. return ConcatDataset(datasets)
  737. @check_rename
  738. def rename(self, input_columns, output_columns):
  739. """
  740. Rename the columns in input datasets.
  741. Args:
  742. input_columns (list[str]): list of names of the input columns.
  743. output_columns (list[str]): list of names of the output columns.
  744. Returns:
  745. RenameDataset, dataset renamed.
  746. Examples:
  747. >>> import mindspore.dataset as ds
  748. >>> # data is an instance of Dataset object.
  749. >>> input_columns = ["input_col1", "input_col2", "input_col3"]
  750. >>> output_columns = ["output_col1", "output_col2", "output_col3"]
  751. >>>
  752. >>> # creates a dataset where input_col1 is renamed to output_col1, and
  753. >>> # input_col2 is renamed to output_col2, and input_col3 is renamed
  754. >>> # to output_col3.
  755. >>> data = data.rename(input_columns=input_columns, output_columns=output_columns)
  756. """
  757. return RenameDataset(self, input_columns, output_columns)
  758. @check_project
  759. def project(self, columns):
  760. """
  761. Project certain columns in input datasets.
  762. The specified columns will be selected from the dataset and passed down
  763. the pipeline in the order specified. The other columns are discarded.
  764. Args:
  765. columns(list[str]): list of names of the columns to project.
  766. Returns:
  767. ProjectDataset, dataset projected.
  768. Examples:
  769. >>> import mindspore.dataset as ds
  770. >>> # data is an instance of Dataset object
  771. >>> columns_to_project = ["column3", "column1", "column2"]
  772. >>>
773. >>> # creates a dataset that consists of column3, column1, column2
  774. >>> # in that order, regardless of the original order of columns.
  775. >>> data = data.project(columns=columns_to_project)
  776. """
  777. return ProjectDataset(self, columns)
  778. def build_vocab(self, vocab, columns, freq_range, top_k, special_tokens, special_first):
  779. return BuildVocabDataset(self, vocab, columns, freq_range, top_k, special_tokens, special_first)
  780. def apply(self, apply_func):
  781. """
782. Apply a function to this dataset.
783. The specified apply_func is a function that must take one 'Dataset' as an argument
784. and return a preprocessed 'Dataset'.
785. Args:
786. apply_func (function): A function that must take one 'Dataset' as an argument and
787. return a preprocessed 'Dataset'.
  788. Returns:
  789. Dataset, applied by the function.
  790. Examples:
  791. >>> import mindspore.dataset as ds
  792. >>> # data is an instance of Dataset object
  793. >>> # declare an apply_func function which returns a Dataset object
  794. >>> def apply_func(ds):
  795. >>> ds = ds.batch(2)
  796. >>> return ds
  797. >>> # use apply to call apply_func
  798. >>> data = data.apply(apply_func)
  799. Raises:
  800. TypeError: If apply_func is not a function.
  801. TypeError: If apply_func doesn't return a Dataset.
  802. """
  803. if not hasattr(apply_func, '__call__'):
  804. raise TypeError("apply_func must be a function.")
  805. dataset = apply_func(self)
  806. if not isinstance(dataset, Dataset):
  807. raise TypeError("apply_func must return a dataset.")
  808. return dataset
  809. def device_que(self, prefetch_size=None, send_epoch_end=True):
  810. """
811. Return a TransferDataset that transfers data to the device.
812. Args:
813. prefetch_size (int, optional): prefetch number of records ahead of the
814. user's request (default=None).
815. send_epoch_end (bool, optional): whether to send an end-of-sequence flag to the device (default=True).
816. Note:
817. If the device is Ascend, data features will be transferred one by one, and each
818. transfer is limited to 256M.
819. Returns:
  820. TransferDataset, dataset for transferring.
  821. """
  822. return self.to_device(send_epoch_end=send_epoch_end)
  823. def to_device(self, send_epoch_end=True):
  824. """
  825. Transfer data through CPU, GPU or Ascend devices.
  826. Args:
827. send_epoch_end (bool, optional): whether to send an end-of-sequence flag to the device (default=True).
828. Note:
829. If the device is Ascend, data features will be transferred one by one, and each
830. transfer is limited to 256M.
  831. Returns:
  832. TransferDataset, dataset for transferring.
  833. Raises:
  834. TypeError: If device_type is empty.
  835. ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
  836. RuntimeError: If dataset is unknown.
  837. RuntimeError: If distribution file path is given but failed to read.
  838. """
  839. queue_name = str(uuid.uuid1())
  840. if context:
  841. device_type = context.get_context("device_target")
  842. else:
  843. device_type = "CPU"
  844. if device_type == "":
  845. raise TypeError("Please set device_type in context")
  846. if device_type not in ('Ascend', 'GPU', 'CPU'):
  847. raise ValueError("Only support CPU, Ascend, GPU")
  848. def get_distribution(output_dataset):
  849. dev_id = 0
  850. if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
  851. ManifestDataset, MnistDataset, VOCDataset, CocoDataset, CelebADataset,
  852. MindDataset)):
  853. sampler = output_dataset.sampler
  854. if isinstance(sampler, samplers.DistributedSampler):
  855. dev_id = sampler.shard_id
  856. return "", dev_id
  857. if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)):
  858. if output_dataset.shard_id is not None:
  859. dev_id = output_dataset.shard_id
  860. return "", dev_id
  861. if not output_dataset.children:
  862. raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
  863. input_dataset = output_dataset.children[0]
  864. return get_distribution(input_dataset)
  865. distribution_path, device_id = get_distribution(self)
  866. if distribution_path == "":
  867. return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
  868. try:
  869. with open(distribution_path, 'r') as distribution_f:
  870. dist = json.load(distribution_f)
  871. device_id = dist["deviceId"]
  872. except json.decoder.JSONDecodeError:
873. raise RuntimeError("JSON decode error when loading distribution file")
874. except Exception:
875. raise RuntimeError("Failed to read distribution file")
  876. return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
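# A minimal usage sketch of device_que/to_device, assuming a pipeline named `data` already exists
# and that device_target is set in the context; send() (defined on TransferDataset below) starts
# pushing rows to the device queue.
# >>> import mindspore.dataset as ds
# >>> # data is an instance of Dataset object
# >>> transfer = data.device_que()
# >>> transfer.send()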
  877. @check_save
  878. def save(self, file_name, num_files=1, file_type='mindrecord'):
  879. """
880. Save the data processed by the dataset pipeline in a common dataset format. Currently only 'mindrecord' is supported.
881. Note:
882. 1. To save the samples in order, set the dataset's shuffle to False and num_files to 1.
883. 2. Before calling this function, do not use the batch or repeat operators, or data augmentation
884. operators with a random attribute in the map operator.
885. 3. MindRecord does not support np.uint64, multi-dimensional np.uint8 (drop dimension) or
886. multi-dimensional strings.
887. Args:
888. file_name (str): Path to dataset file.
889. num_files (int, optional): Number of dataset files (default=1).
890. file_type (str, optional): dataset format (default='mindrecord').
  891. """
  892. if num_files == 1:
  893. file_names = [file_name]
  894. else:
  895. suffix = len(str(num_files - 1))
  896. file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0'))
  897. for x in range(num_files)]
  898. return SaveOp(self).save(file_names, file_type)
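# A minimal usage sketch of save(), assuming a hypothetical non-shuffled pipeline `data`; with
# num_files=3 the suffix logic above produces output.mindrecord0, output.mindrecord1 and
# output.mindrecord2.
# >>> import mindspore.dataset as ds
# >>> # data is an instance of Dataset object with shuffle disabled
# >>> data.save("/path/to/output.mindrecord")
# >>> data.save("/path/to/output.mindrecord", num_files=3)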
  899. def create_tuple_iterator(self, columns=None, num_epochs=-1):
  900. """
901. Create an Iterator over the dataset. The data retrieved will be a list of ndarrays.
902. To specify which columns to list and the order needed, use the columns parameter. If columns
903. is not provided, the order of the columns will not be changed.
  904. Args:
  905. columns (list[str], optional): List of columns to be used to specify the order of columns
  906. (default=None, means all columns).
  907. Returns:
  908. Iterator, list of ndarray.
  909. Examples:
  910. >>> import mindspore.dataset as ds
  911. >>> # data is an instance of Dataset object
  912. >>> # creates an iterator. The columns in the data obtained by the
  913. >>> # iterator will not be changed.
  914. >>> iterator = data.create_tuple_iterator()
  915. >>> for item in iterator:
  916. >>> # convert the returned tuple to a list and print
  917. >>> print(list(item))
  918. """
  919. if self._noop_mode():
  920. return DummyIterator(self, 'tuple')
  921. return TupleIterator(self, columns, num_epochs)
  922. def create_dict_iterator(self, num_epochs=-1):
  923. """
  924. Create an Iterator over the dataset.
  925. The data retrieved will be a dictionary. The order
  926. of the columns in the dictionary may not be the same as the original order.
  927. Returns:
  928. Iterator, dictionary of column_name-ndarray pair.
  929. Examples:
  930. >>> import mindspore.dataset as ds
  931. >>> # data is an instance of Dataset object
  932. >>> # creates an iterator. The columns in the data obtained by the
  933. >>> # iterator might be changed.
  934. >>> iterator = data.create_dict_iterator()
  935. >>> for item in iterator:
  936. >>> # print the data in column1
  937. >>> print(item["column1"])
  938. """
  939. if self._noop_mode():
  940. return DummyIterator(self, 'dict')
  941. return DictIterator(self, num_epochs)
  942. def __iter__(self):
  943. """Create an Iterator over the dataset."""
  944. return self.create_tuple_iterator()
  945. @property
  946. def input_indexs(self):
  947. return self._input_indexs
  948. @input_indexs.setter
  949. def input_indexs(self, value):
  950. self._input_indexs = value
  951. def _get_pipeline_info(self):
  952. """
  953. Get pipeline information.
  954. """
  955. device_iter = TupleIterator(self)
  956. self._output_shapes = device_iter.get_output_shapes()
  957. self._output_types = device_iter.get_output_types()
  958. if self._dataset_size is None:
  959. self._dataset_size = device_iter.get_dataset_size()
  960. self._batch_size = device_iter.get_batch_size()
  961. self._num_classes = device_iter.num_classes()
  962. self._repeat_count = device_iter.get_repeat_count()
  963. device_iter.stop()
  964. def output_shapes(self):
  965. """
  966. Get the shapes of output data.
  967. Return:
  968. List, list of shape of each column.
  969. """
  970. if self._output_shapes is None:
  971. self._get_pipeline_info()
  972. return self._output_shapes
  973. def output_types(self):
  974. """
  975. Get the types of output data.
  976. Return:
  977. List of data type.
  978. """
  979. if self._output_types is None:
  980. self._get_pipeline_info()
  981. return self._output_types
  982. def get_dataset_size(self):
  983. """
  984. Get the number of batches in an epoch.
  985. Return:
  986. Number, number of batches.
  987. """
  988. if self.children:
  989. return self.children[0].get_dataset_size()
  990. return None
  991. def num_classes(self):
  992. """
  993. Get the number of classes in a dataset.
  994. Return:
  995. Number, number of classes.
  996. """
  997. if self.children:
  998. return self.children[0].num_classes()
  999. return None
  1000. def get_sync_notifiers(self):
  1001. if self.children:
  1002. return self.children[0].get_sync_notifiers()
  1003. return {}
  1004. def disable_sync(self):
  1005. if self.children:
  1006. return self.children[0].disable_sync()
  1007. return {}
  1008. def is_sync(self):
  1009. if self.children:
  1010. return self.children[0].is_sync()
  1011. return False
  1012. def sync_update(self, condition_name, num_batch=None, data=None):
  1013. """
  1014. Release a blocking condition and trigger callback with given data.
  1015. Args:
  1016. condition_name (str): The condition name that is used to toggle sending next row.
  1017. num_batch (int or None): The number of batches(rows) that are released.
  1018. When num_batch is None, it will default to the number specified by the
  1019. sync_wait operator (default=None).
  1020. data (dict or None): The data passed to the callback (default=None).
  1021. """
  1022. if isinstance(num_batch, int) and num_batch <= 0:
  1023. # throwing exception, disable all sync_wait in pipeline
  1024. self.disable_sync()
  1025. raise RuntimeError("Sync_update batch size can only be positive, got : {}".format(num_batch))
  1026. notifiers_dict = self.get_sync_notifiers()
  1027. if condition_name not in notifiers_dict:
  1028. # throwing exception, disable all sync_wait in pipeline
  1029. self.disable_sync()
  1030. raise RuntimeError("Condition name not found")
  1031. if num_batch is not None:
  1032. num_batch *= self.get_batch_size()
  1033. notifiers_dict[condition_name](num_batch, data)
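# A minimal flow-control sketch, assuming the pipeline exposes the sync_wait operator (which
# builds the SyncWaitDataset defined later in this file) and a hypothetical callback update_cb:
# >>> import mindspore.dataset as ds
# >>> # data is an instance of Dataset object
# >>> data = data.sync_wait(condition_name="policy", num_batch=1, callback=update_cb)
# >>> data = data.batch(batch_size=4)
# >>> for item in data.create_dict_iterator():
# >>>     # consume one batch, then release the next one through the blocking condition
# >>>     data.sync_update(condition_name="policy", data={"loss": 0.1})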
  1034. def get_batch_size(self):
  1035. """
  1036. Get the size of a batch.
  1037. Return:
  1038. Number, the number of data in a batch.
  1039. """
  1040. if self.children:
  1041. return self.children[0].get_batch_size()
  1042. return 1
  1043. def get_repeat_count(self):
  1044. """
1045. Get the number of times the dataset is repeated in RepeatDataset; otherwise return 1.
  1046. Return:
  1047. Number, the count of repeat.
  1048. """
  1049. if self.children:
  1050. return self.children[0].get_repeat_count()
  1051. return 1
  1052. def get_class_indexing(self):
  1053. """
  1054. Get the class index.
  1055. Return:
1056. Dict, a str-to-int mapping from label name to index.
  1057. """
  1058. if self.children:
  1059. return self.children[0].get_class_indexing()
  1060. raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))
  1061. def reset(self):
  1062. """Reset the dataset for next epoch."""
  1063. def is_shuffled(self):
  1064. for input_dataset in self.children:
  1065. if input_dataset.is_shuffled():
  1066. return True
  1067. return False
  1068. def is_sharded(self):
  1069. for input_dataset in self.children:
  1070. if input_dataset.is_sharded():
  1071. return True
  1072. return False
  1073. class SourceDataset(Dataset):
  1074. """
1075. Abstract class to represent a source dataset which produces content for the data pipeline.
  1076. """
  1077. # No need for __init__ since it is the same as the super's init
  1078. @staticmethod
  1079. def _find_files(patterns):
  1080. """
  1081. Utility function to search for files with the given glob patterns.
  1082. Args:
  1083. patterns (str or list[str]): string or list of patterns to be searched.
  1084. Returns:
  1085. List, files.
  1086. """
  1087. if not isinstance(patterns, list):
  1088. patterns = [patterns]
  1089. file_list = []
  1090. unmatched_patterns = []
  1091. for pattern in patterns:
  1092. matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]
  1093. if matches:
  1094. file_list.extend(matches)
  1095. else:
  1096. unmatched_patterns.append(pattern)
  1097. if unmatched_patterns:
1098. raise ValueError("The following patterns did not match any files: {}".format(unmatched_patterns))
  1099. if file_list: # not empty
  1100. return file_list
  1101. raise ValueError("The list of path names matching the patterns is empty.")
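# A brief sketch of the glob patterns accepted above (recursive=True also enables "**"),
# using hypothetical paths:
# >>> SourceDataset._find_files("/path/to/data/*.tfrecord")
# >>> SourceDataset._find_files(["/path/to/data/**/*.txt", "/path/to/extra/file.txt"])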
  1102. def is_shuffled(self):
  1103. raise NotImplementedError("SourceDataset must implement is_shuffled.")
  1104. def is_sharded(self):
  1105. raise NotImplementedError("SourceDataset must implement is_sharded.")
  1106. class MappableDataset(SourceDataset):
  1107. """
  1108. Abstract class to represent a source dataset which supports use of samplers.
  1109. """
  1110. def __init__(self, num_parallel_workers=None):
  1111. # check if all subclasses use this name
  1112. super().__init__(num_parallel_workers)
  1113. self.sampler = None
  1114. def add_sampler(self, new_sampler):
  1115. # note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
  1116. # after first passing through the current samplers attached to this dataset.
  1117. new_sampler.add_child(self.sampler)
  1118. self.sampler = new_sampler
  1119. def use_sampler(self, new_sampler):
  1120. """
1121. Make the current dataset use the new_sampler provided.
  1122. Args:
  1123. new_sampler (Sampler): the sampler to use for the current dataset.
  1124. Returns:
  1125. Dataset, that uses new_sampler.
  1126. Examples:
  1127. >>> import mindspore.dataset as ds
  1128. >>>
  1129. >>> dataset_dir = "/path/to/imagefolder_directory"
  1130. >>> # a SequentialSampler is created by default
  1131. >>> data = ds.ImageFolderDatasetV2(dataset_dir)
  1132. >>>
  1133. >>> # use a DistributedSampler instead of the SequentialSampler
  1134. >>> new_sampler = ds.DistributedSampler(10, 2)
  1135. >>> data.use_sampler(new_sampler)
  1136. """
  1137. if new_sampler is None:
  1138. raise TypeError("Input sampler can not be None.")
  1139. if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
  1140. raise TypeError("Input sampler is not an instance of a sampler.")
  1141. self.sampler = self.sampler.child_sampler
  1142. self.add_sampler(new_sampler)
  1143. def is_shuffled(self):
  1144. raise NotImplementedError("MappableDataset must implement is_shuffled.")
  1145. def is_sharded(self):
  1146. raise NotImplementedError("MappableDataset must implement is_sharded.")
  1147. def _get_sampler_dataset_size(self):
  1148. if self.sampler is not None:
  1149. if hasattr(self.sampler, 'get_num_samples'):
  1150. return self.sampler.get_num_samples()
  1151. if hasattr(self.sampler, '__len__'):
  1152. return len(self.sampler)
  1153. return None
  1154. @check_split
  1155. def split(self, sizes, randomize=True):
  1156. """
  1157. Split the dataset into smaller, non-overlapping datasets.
1158. This is the optimized split function, which is called automatically when the dataset
1159. that calls this function is a MappableDataset.
  1160. Args:
  1161. sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is
  1162. provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
1163. respectively. If the sum of all sizes does not equal the original dataset size,
1164. an error will occur.
  1165. If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
  1166. and must sum to 1, otherwise an error will occur. The dataset will be split into n
  1167. Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
  1168. original dataset.
  1169. If after rounding:
  1170. - Any size equals 0, an error will occur.
  1171. - The sum of split sizes < K, the difference will be added to the first split.
  1172. - The sum of split sizes > K, the difference will be removed from the first large
1173. enough split such that it will have at least 1 row after removing the difference.
  1174. randomize (bool, optional): determines whether or not to split the data randomly (default=True).
  1175. If true, the data will be randomly split. Otherwise, each split will be created with
  1176. consecutive rows from the dataset.
  1177. Note:
  1178. 1. Dataset should not be sharded if split is going to be called. Instead, create a
  1179. DistributedSampler and specify a split to shard after splitting. If dataset is
  1180. sharded after a split, it is strongly recommended to set the same seed in each instance
  1181. of execution, otherwise each shard may not be part of the same split (see Examples).
  1182. 2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
  1183. Shuffling the dataset may not be deterministic, which means the data in each split
  1184. will be different in each epoch. Furthermore, if sharding occurs after split, each
  1185. shard may not be part of the same split.
  1186. Raises:
  1187. RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
  1188. RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
  1189. equal the dataset size.
  1190. RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
  1191. RuntimeError: If the dataset is sharded prior to calling split.
1192. ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
1193. floats do not sum to 1.
1194. Returns:
  1195. tuple(Dataset), a tuple of datasets that have been split.
  1196. Examples:
  1197. >>> import mindspore.dataset as ds
  1198. >>>
  1199. >>> dataset_dir = "/path/to/imagefolder_directory"
  1200. >>>
  1201. >>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
  1202. >>> data = ds.ImageFolderDatasetV2(dataset_dir, shuffle=False)
  1203. >>>
  1204. >>> # sets the seed, and tells split to use this seed when randomizing. This
  1205. >>> # is needed because we are sharding later
  1206. >>> ds.config.set_seed(58)
  1207. >>> train, test = data.split([0.9, 0.1])
  1208. >>>
  1209. >>> # if we want to shard the train dataset, we can use a DistributedSampler
  1210. >>> train_sampler = ds.DistributedSampler(10, 2)
  1211. >>> train.use_sampler(train_sampler)
  1212. """
  1213. if self.is_shuffled():
  1214. logger.warning("Dataset is shuffled before split.")
  1215. if self.is_sharded():
  1216. raise RuntimeError("Dataset should not be sharded before split.")
  1217. absolute_sizes = self._get_absolute_split_sizes(sizes)
  1218. splits = []
  1219. current_split_start_index = 0
  1220. for size in absolute_sizes:
  1221. ds = copy.deepcopy(self)
  1222. if randomize:
  1223. # want to shuffle the same way every epoch before split, we are assuming
  1224. # that the user will call set_seed
  1225. random_sampler = samplers.RandomSampler()
  1226. random_sampler.reshuffle_each_epoch = False
  1227. ds.add_sampler(random_sampler)
  1228. subset_sampler = samplers.SequentialSampler(current_split_start_index, size)
  1229. ds.add_sampler(subset_sampler)
  1230. # add sequential sampler, so that if user calls use_sampler, we will
  1231. # get rid of the sequential sampler instead of something we need
  1232. ds.add_sampler(samplers.SequentialSampler())
  1233. splits.append(ds)
  1234. current_split_start_index += size
  1235. return tuple(splits)
  1236. class DatasetOp(Dataset):
  1237. """
1238. Abstract class to represent an operation on a dataset.
  1239. """
  1240. # No need for __init__ since it is the same as the super's init
  1241. class BucketBatchByLengthDataset(DatasetOp):
  1242. """
  1243. The result of applying BucketBatchByLength operator to the input dataset.
  1244. """
  1245. def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes,
  1246. element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder):
  1247. super().__init__()
  1248. self.column_names = column_names
  1249. self.bucket_boundaries = bucket_boundaries
  1250. self.bucket_batch_sizes = bucket_batch_sizes
  1251. self.element_length_function = element_length_function
  1252. self.pad_info = pad_info
  1253. self.pad_to_bucket_boundary = pad_to_bucket_boundary
  1254. self.drop_remainder = drop_remainder
  1255. self.children.append(input_dataset)
  1256. input_dataset.parent.append(self)
  1257. self._input_indexs = input_dataset.input_indexs
  1258. def get_args(self):
  1259. args = super().get_args()
  1260. args["length_dependent_columns"] = self.column_names
  1261. args["bucket_boundaries"] = self.bucket_boundaries
  1262. args["bucket_batch_sizes"] = self.bucket_batch_sizes
  1263. args["element_length_function"] = self.element_length_function
  1264. args["pad_info"] = self.pad_info
  1265. args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary
  1266. args["drop_remainder"] = self.drop_remainder
  1267. return args
  1268. def get_dataset_size(self):
  1269. """
  1270. Get the number of batches in an epoch.
  1271. Return:
  1272. Number, number of batches.
  1273. """
  1274. return None
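# A minimal usage sketch of bucket batching, assuming the Dataset class exposes a
# bucket_batch_by_length operator that builds the node above, and a hypothetical pipeline whose
# "text" column holds variable-length sequences:
# >>> import mindspore.dataset as ds
# >>> # data is an instance of Dataset object with a "text" column
# >>> def element_length_function(text):
# >>>     return len(text)
# >>> data = data.bucket_batch_by_length(column_names=["text"],
# >>>                                    bucket_boundaries=[10, 20, 30],
# >>>                                    bucket_batch_sizes=[32, 16, 8, 4],
# >>>                                    element_length_function=element_length_function,
# >>>                                    pad_info={"text": (None, 0)},
# >>>                                    pad_to_bucket_boundary=False,
# >>>                                    drop_remainder=False)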
  1275. class BatchDataset(DatasetOp):
  1276. """
  1277. The result of applying Batch operator to the input dataset.
  1278. Args:
  1279. input_dataset (Dataset): Input Dataset to be batched.
  1280. batch_size (int or function): The number of rows each batch is created with. An
  1281. int or callable which takes exactly 1 parameter, BatchInfo.
  1282. drop_remainder (bool, optional): Determines whether or not to drop the last
  1283. possibly incomplete batch (default=False). If True, and if there are less
  1284. than batch_size rows available to make the last batch, then those rows will
  1285. be dropped and not propagated to the child node.
  1286. num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None).
  1287. per_batch_map (callable, optional): Per batch map callable. A callable which takes
1288. (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of
1289. Tensors on a given column. The number of lists should match the number of entries in input_columns. The
1290. last parameter of the callable should always be a BatchInfo object.
1291. input_columns (list of string, optional): List of names of the input columns. The size of the list should
1292. match the signature of the per_batch_map callable.
1293. pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
1294. would pad the column with name "col1" to a tensor of size [224,224] and fill the missing values with 0.
  1295. """
  1296. def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None,
  1297. per_batch_map=None, input_columns=None, pad_info=None):
  1298. super().__init__(num_parallel_workers)
  1299. if BatchDataset._is_ancestor_of_repeat(input_dataset):
  1300. logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
  1301. BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
  1302. self.batch_size = batch_size
  1303. self.drop_remainder = drop_remainder
  1304. self.per_batch_map = per_batch_map
  1305. self.input_columns = input_columns
  1306. self.pad_info = pad_info
  1307. self.children.append(input_dataset)
  1308. input_dataset.parent.append(self)
  1309. self._input_indexs = input_dataset.input_indexs
  1310. def get_args(self):
  1311. args = super().get_args()
  1312. args["batch_size"] = self.batch_size
  1313. args["drop_remainder"] = self.drop_remainder
  1314. args["per_batch_map"] = self.per_batch_map
  1315. args["input_columns"] = self.input_columns
  1316. args["pad_info"] = self.pad_info
  1317. return args
  1318. def get_dataset_size(self):
  1319. """
  1320. Get the number of batches in an epoch.
  1321. Return:
  1322. Number, number of batches.
  1323. """
  1324. child_size = self.children[0].get_dataset_size()
  1325. if child_size is not None and isinstance(self.batch_size, int):
  1326. if self.drop_remainder:
  1327. return math.floor(child_size / self.batch_size)
  1328. return math.ceil(child_size / self.batch_size)
  1329. return None
  1330. def get_batch_size(self):
  1331. """
  1332. Get the size of a batch.
  1333. Return:
  1334. Number, the number of data in a batch.
  1335. """
  1336. return self.batch_size
  1337. @staticmethod
  1338. def _is_ancestor_of_repeat(dataset):
  1339. """
  1340. Utility function to find the case where repeat is used before batch.
  1341. Args:
  1342. dataset (Dataset): dataset to be checked.
  1343. Return:
  1344. True or False.
  1345. """
  1346. if isinstance(dataset, RepeatDataset):
  1347. return True
  1348. flag = False
  1349. for input_dataset in dataset.children:
  1350. flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
  1351. return flag
  1352. @staticmethod
  1353. def _update_batch_size_for_syncwait(dataset, batch_size):
  1354. """
  1355. Utility function to notify batch size to sync_wait.
  1356. Args:
  1357. dataset (Dataset): dataset to be checked.
  1358. batch_size (int): batch size to notify.
  1359. """
  1360. if isinstance(dataset, SyncWaitDataset):
  1361. dataset.update_sync_batch_size(batch_size)
  1362. for input_dataset in dataset.children:
  1363. BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
  1364. class BatchInfo(CBatchInfo):
  1365. """
1366. The information object associated with the current batch of tensors.
  1367. """
  1368. def get_batch_num(self):
  1369. """
  1370. Return the batch number of the current batch.
  1371. Return:
  1372. Number, number of the current batch.
  1373. """
  1374. return
  1375. def get_epoch_num(self):
  1376. """
  1377. Return the epoch number of the current batch.
  1378. Return:
  1379. Number, number of the current epoch.
  1380. """
  1381. return
  1382. class BlockReleasePair:
  1383. """
  1384. The blocking condition class used by SyncWaitDataset.
  1385. Args:
1386. init_release_rows (int): Number of rows to allow through the pipeline.
  1387. callback (function): The callback function that will be called when release is called.
  1388. """
  1389. def __init__(self, init_release_rows, callback=None):
  1390. if isinstance(init_release_rows, int) and init_release_rows <= 0:
  1391. raise ValueError("release_rows need to be greater than 0.")
  1392. self.row_count = -init_release_rows
  1393. self.cv = threading.Condition()
  1394. self.callback = callback
  1395. self.default_rows = init_release_rows
  1396. self.disable = False
  1397. def __deepcopy__(self, memodict):
  1398. if id(self) in memodict:
  1399. return memodict[id(self)]
  1400. memodict[id(self)] = self
  1401. # condition variable and callback are the same, but reset the counter
  1402. self.reset()
  1403. return self
  1404. def reset(self):
  1405. with self.cv:
  1406. self.row_count = -self.default_rows
  1407. self.cv.notify_all()
  1408. def update_batched_size(self, batch_size):
  1409. # sanity check
  1410. if isinstance(batch_size, int) and batch_size <= 0:
  1411. raise ValueError("batch_size need to be greater than 0.")
  1412. # should only use before the pipeline creates
  1413. self.row_count *= batch_size
  1414. self.default_rows *= batch_size
  1415. def block_func(self):
  1416. with self.cv:
1417. # if disable is true, the predicate always evaluates to true
  1418. self.cv.wait_for(lambda: (self.row_count < 0 or self.disable))
  1419. self.row_count += 1
  1420. return True
  1421. def release_func(self, pass_rows=None, data=None):
  1422. with self.cv:
  1423. if pass_rows is None:
  1424. pass_rows = self.default_rows
  1425. self.row_count -= pass_rows
  1426. if self.callback is not None:
  1427. self.callback(data)
  1428. self.cv.notify_all()
  1429. def disable_lock(self):
  1430. with self.cv:
  1431. self.disable = True
  1432. self.cv.notify_all()
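# A minimal standalone sketch of the block/release protocol above, assuming a hypothetical
# producer thread that calls block_func before emitting each row:
# >>> import threading
# >>> pair = BlockReleasePair(init_release_rows=2)
# >>> def producer():
# >>>     for _ in range(4):
# >>>         pair.block_func()  # blocks after 2 rows pass, until release_func is called
# >>> t = threading.Thread(target=producer)
# >>> t.start()
# >>> pair.release_func(pass_rows=2)  # lets the remaining 2 rows through
# >>> t.join()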
  1433. class SyncWaitDataset(DatasetOp):
  1434. """
  1435. The result of adding a blocking condition to the input Dataset.
  1436. Args:
  1437. input_dataset (Dataset): Input dataset to apply flow control.
  1438. num_batch (int): the number of batches without blocking at the start of each epoch.
  1439. condition_name (str): The condition name that is used to toggle sending next row.
  1440. callback (function): The callback function that will be invoked when sync_update is called.
  1441. Raises:
  1442. RuntimeError: If condition name already exists.
  1443. """
  1444. def __init__(self, input_dataset, condition_name, num_batch, callback=None):
  1445. super().__init__()
  1446. self.children.append(input_dataset)
  1447. input_dataset.parent.append(self)
  1448. # set to the default value, waiting for the batch to update it
  1449. self._condition_name = condition_name
  1450. if isinstance(num_batch, int) and num_batch <= 0:
  1451. raise ValueError("num_batch need to be greater than 0.")
  1452. self._pair = BlockReleasePair(num_batch, callback)
  1453. if self._condition_name in self.children[0].get_sync_notifiers():
  1454. raise RuntimeError("Condition name is already in use")
  1455. logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging",
  1456. condition_name)
  1457. def get_sync_notifiers(self):
  1458. return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
  1459. def is_sync(self):
  1460. return True
  1461. def get_args(self):
  1462. args = super().get_args()
  1463. args["condition_name"] = self._condition_name
  1464. args["condition_func"] = self._pair.block_func
  1465. return args
  1466. def update_sync_batch_size(self, batch_size):
  1467. if isinstance(batch_size, int) and batch_size <= 0:
  1468. raise ValueError("num_batch need to be greater than 0.")
  1469. self._pair.update_batched_size(batch_size)
  1470. def disable_sync(self):
  1471. logger.info("Disabling Sync")
  1472. self._pair.disable_lock()
  1473. @staticmethod
  1474. def _is_ancestor_of_batch(dataset):
  1475. """
  1476. Utility function to find the case where sync_wait is used before batch.
  1477. Args:
  1478. dataset (Dataset): dataset to be checked.
  1479. Return:
  1480. True or False.
  1481. """
  1482. if isinstance(dataset, BatchDataset):
  1483. return True
  1484. flag = False
  1485. for input_dataset in dataset.children:
  1486. flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
  1487. return flag
  1488. class ShuffleDataset(DatasetOp):
  1489. """
  1490. The result of applying Shuffle operator to the input Dataset.
  1491. Args:
  1492. input_dataset (Dataset): Input Dataset to be shuffled.
  1493. buffer_size (int): The size of the buffer.
  1494. Raises:
1495. RuntimeError: If sync operators exist before the shuffle.
  1496. """
  1497. def __init__(self, input_dataset, buffer_size):
  1498. super().__init__()
  1499. self.buffer_size = buffer_size
  1500. self.children.append(input_dataset)
  1501. self.reshuffle_each_epoch = None
  1502. input_dataset.parent.append(self)
  1503. self._input_indexs = input_dataset.input_indexs
  1504. if self.is_sync():
  1505. raise RuntimeError("No shuffle after sync operators")
  1506. def get_args(self):
  1507. args = super().get_args()
  1508. args["buffer_size"] = self.buffer_size
  1509. if self.reshuffle_each_epoch is not None:
  1510. args["reshuffle_each_epoch"] = self.reshuffle_each_epoch
  1511. return args
  1512. def is_shuffled(self):
  1513. return True
  1514. # Pyfunc collection for multiprocess pyfunc
  1515. # This global variable will only be used within subprocesses
  1516. _GLOBAL_PYFUNC_LIST = []
  1517. # Pyfunc worker init function
1518. # The Python multiprocessing library forbids sending lambda functions through pipes.
1519. # This init function allows us to add all python functions to a global collection and then fork afterwards.
  1520. def _pyfunc_worker_init(pyfunc_list):
  1521. global _GLOBAL_PYFUNC_LIST
  1522. _GLOBAL_PYFUNC_LIST = pyfunc_list
  1523. # Pyfunc worker execution function
  1524. # All exceptions will be raised to main processes
  1525. def _pyfunc_worker_exec(index, *args):
  1526. try:
  1527. return _GLOBAL_PYFUNC_LIST[index](*args)
  1528. except KeyboardInterrupt:
  1529. raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
  1530. # PythonCallable wrapper for multiprocess pyfunc
  1531. class _PythonCallable:
  1532. """
  1533. Internal python function wrapper for multiprocessing pyfunc.
  1534. """
  1535. def __init__(self, py_callable, idx, pool=None):
  1536. # Original python callable from user.
  1537. self.py_callable = py_callable
  1538. # Process pool created for current iterator.
  1539. self.pool = pool
  1540. # Python callable index for subprocess _GLOBAL_PYFUNC_LIST
  1541. self.idx = idx
  1542. def __call__(self, *args):
  1543. if self.pool is not None:
  1544. try:
  1545. # This call will send the tensors along with Python callable index to the process pool.
  1546. # Block, yield GIL. Current thread will reacquire GIL once result is returned.
  1547. return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args])
  1548. except KeyboardInterrupt:
  1549. self.pool.terminate()
  1550. self.pool.join()
  1551. raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
  1552. # Invoke original python callable in master process in case the pool is gone.
  1553. return self.py_callable(*args)
  1554. class MapDataset(DatasetOp):
  1555. """
  1556. The result of applying Map operator to the input Dataset.
  1557. Args:
  1558. input_dataset (Dataset): Input Dataset to be mapped.
  1559. input_columns (list[str]): List of names of the input columns
  1560. (default=None, the operations will be applied on the first columns in the dataset).
  1561. The size of the list should match the number of inputs of the first operator.
  1562. operations (TensorOp): A function mapping a nested structure of tensors
1563. to another nested structure of tensors (default=None).
  1564. output_columns (list[str], optional): list of names of the output columns.
  1565. The size of the list should match the number of outputs of the last operator
  1566. (default=None, output columns will be the input columns, i.e., the columns will
  1567. be replaced).
  1568. columns_order (list[str], optional): list of all the desired columns of the dataset (default=None).
  1569. The argument is mandatory if len(input_columns) != len(output_columns).
  1570. num_parallel_workers (int, optional): Number of workers to process the Dataset
  1571. in parallel (default=None).
1572. python_multiprocessing (bool, optional): Parallelize python operations with multiple worker processes. This
1573. option could be beneficial if the python operations are computationally heavy (default=False).
1574. cache (DatasetCache, optional): Tensor cache to use (default=None, which means no cache is used).
  1575. Raises:
  1576. ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
  1577. """
  1578. def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
  1579. num_parallel_workers=None, python_multiprocessing=False, cache=None):
  1580. super().__init__(num_parallel_workers)
  1581. self.children.append(input_dataset)
  1582. if input_columns is not None and not isinstance(input_columns, list):
  1583. input_columns = [input_columns]
  1584. self.input_columns = input_columns
  1585. if operations is not None and not isinstance(operations, list):
  1586. operations = [operations]
  1587. self.operations = operations
  1588. if output_columns is not None and not isinstance(output_columns, list):
  1589. output_columns = [output_columns]
  1590. self.output_columns = output_columns
  1591. self.cache = cache
  1592. self.columns_order = columns_order
  1593. if self.input_columns and self.output_columns \
  1594. and len(self.input_columns) != len(self.output_columns) \
  1595. and self.columns_order is None:
  1596. raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.")
  1597. input_dataset.parent.append(self)
  1598. self._input_indexs = input_dataset.input_indexs
  1599. self.python_multiprocessing = python_multiprocessing
  1600. self.process_pool = None
  1601. def get_args(self):
  1602. args = super().get_args()
  1603. args["input_columns"] = self.input_columns
  1604. args["operations"] = self.operations
  1605. args["output_columns"] = self.output_columns
  1606. args["columns_order"] = self.columns_order
  1607. args["cache"] = self.cache.cache_client if self.cache is not None else None
  1608. return args
  1609. def get_dataset_size(self):
  1610. """
  1611. Get the number of batches in an epoch.
  1612. Return:
  1613. Number, number of batches.
  1614. """
  1615. return self.children[0].get_dataset_size()
  1616. def __deepcopy__(self, memodict):
  1617. if id(self) in memodict:
  1618. return memodict[id(self)]
  1619. cls = self.__class__
  1620. new_op = cls.__new__(cls)
  1621. memodict[id(self)] = new_op
  1622. new_op.children = copy.deepcopy(self.children, memodict)
  1623. new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
  1624. new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
  1625. new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
  1626. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  1627. new_op.parent = copy.deepcopy(self.parent, memodict)
  1628. new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
  1629. new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
  1630. new_op.cache = copy.deepcopy(self.cache, memodict)
  1631. new_op.operations = self.operations
  1632. return new_op
  1633. # Iterator bootstrap will be called on iterator construction.
1634. # A deep copy of the Dataset object is created prior to iterator_bootstrap.
1635. # This method will create a per-iterator process pool and bind pyfunc execution to the pool.
  1636. def iterator_bootstrap(self):
  1637. """
  1638. Per iterator bootstrap callback.
  1639. """
  1640. if self.python_multiprocessing:
  1641. iter_specific_operations = []
  1642. callable_list = []
  1643. # Pass #1, look for python callables and build list
  1644. for op in self.operations:
  1645. if callable(op):
  1646. callable_list.append(op)
  1647. if callable_list:
  1648. # Construct pool with the callable list
  1649. # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses
  1650. self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
  1651. initializer=_pyfunc_worker_init,
  1652. initargs=(callable_list,))
  1653. # Pass #2
  1654. idx = 0
  1655. for op in self.operations:
  1656. if callable(op):
  1657. # Wrap python callable into _PythonCallable
  1658. iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
  1659. idx += 1
  1660. else:
  1661. # CPP ops remain the same
  1662. iter_specific_operations.append(op)
  1663. self.operations = iter_specific_operations
  1664. def __del__(self):
  1665. if hasattr(self, 'process_pool') and self.process_pool is not None:
  1666. self.process_pool.terminate()
  1667. class FilterDataset(DatasetOp):
  1668. """
  1669. The result of applying filter predicate to the input Dataset.
  1670. Args:
1671. input_dataset (Dataset): Input Dataset to be filtered.
1672. predicate (callable): python callable which returns a boolean value; if False, the element is filtered out.
1673. input_columns (list[str], optional): List of names of the input columns
1674. (default=None, the predicate will be applied to all columns in the dataset).
  1675. num_parallel_workers (int, optional): Number of workers to process the Dataset
  1676. in parallel (default=None).
  1677. """
  1678. def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
  1679. super().__init__(num_parallel_workers)
  1680. self.predicate = lambda *args: bool(predicate(*args))
  1681. self.children.append(input_dataset)
  1682. input_dataset.parent.append(self)
  1683. if input_columns is not None and not isinstance(input_columns, list):
  1684. input_columns = [input_columns]
  1685. self.input_columns = input_columns
  1686. def get_args(self):
  1687. args = super().get_args()
  1688. args["predicate"] = self.predicate
  1689. args["input_columns"] = self.input_columns
  1690. return args
  1691. def get_dataset_size(self):
  1692. """
  1693. Get the number of batches in an epoch.
1694. The size cannot be determined before the pipeline is run.
  1695. Return:
  1696. 0
  1697. """
  1698. return 0
  1699. class RepeatDataset(DatasetOp):
  1700. """
  1701. The result of applying Repeat operator to the input Dataset.
  1702. Args:
  1703. input_dataset (Dataset): Input Dataset to be repeated.
  1704. count (int): Number of times the dataset should be repeated.
  1705. """
  1706. def __init__(self, input_dataset, count):
  1707. super().__init__()
  1708. if count is None:
  1709. self.count = -1
  1710. else:
  1711. self.count = count
  1712. self.children.append(input_dataset)
  1713. input_dataset.parent.append(self)
  1714. self._input_indexs = input_dataset.input_indexs
  1715. def get_args(self):
  1716. args = super().get_args()
  1717. args["count"] = self.count
  1718. return args
  1719. def get_dataset_size(self):
  1720. """
  1721. Get the number of batches in an epoch.
  1722. Return:
  1723. Number, number of batches.
  1724. """
  1725. child_size = self.children[0].get_dataset_size()
  1726. if child_size is not None:
  1727. return child_size * self.count
  1728. return None
  1729. def get_repeat_count(self):
  1730. """
  1731. Get the replication times in RepeatDataset.
  1732. Return:
  1733. Number, the count of repeat.
  1734. """
  1735. return self.count
  1736. class SkipDataset(DatasetOp):
  1737. """
  1738. The result of applying Skip operator to the input Dataset.
  1739. Args:
1740. input_dataset (Dataset): Input Dataset from which rows will be skipped.
1741. count (int): Number of rows to skip from the dataset.
  1742. """
  1743. def __init__(self, input_dataset, count):
  1744. super().__init__()
  1745. self.count = count
  1746. self.children.append(input_dataset)
  1747. input_dataset.parent.append(self)
  1748. self._input_indexs = input_dataset.input_indexs
  1749. def get_args(self):
  1750. args = super().get_args()
  1751. args["count"] = self.count
  1752. return args
  1753. def get_dataset_size(self):
  1754. """
  1755. Get the number of batches in an epoch.
  1756. Return:
  1757. Number, number of batches.
  1758. """
  1759. child_size = self.children[0].get_dataset_size()
  1760. output_size = 0
  1761. if self.count >= 0 and self.count < child_size:
  1762. output_size = child_size - self.count
  1763. return output_size
  1764. class TakeDataset(DatasetOp):
  1765. """
  1766. The result of applying Take operator to the input Dataset.
  1767. Args:
1768. input_dataset (Dataset): Input Dataset from which elements will be taken.
  1769. count (int): Number of elements to be taken from the dataset.
  1770. """
  1771. def __init__(self, input_dataset, count):
  1772. super().__init__()
  1773. self.count = count
  1774. self.children.append(input_dataset)
  1775. input_dataset.parent.append(self)
  1776. self._input_indexs = input_dataset.input_indexs
  1777. def get_args(self):
  1778. args = super().get_args()
  1779. args["count"] = self.count
  1780. return args
  1781. def get_dataset_size(self):
  1782. """
  1783. Get the number of batches in an epoch.
  1784. Return:
  1785. Number, number of batches.
  1786. """
  1787. child_size = self.children[0].get_dataset_size()
  1788. if child_size < self.count:
  1789. return child_size
  1790. return self.count
  1791. class ZipDataset(DatasetOp):
  1792. """
  1793. The result of applying Zip operator to the input Dataset.
  1794. Args:
  1795. datasets (tuple): A tuple of datasets to be zipped together.
  1796. Raises:
  1797. TypeError: If dataset is not an instance of Dataset.
  1798. """
  1799. def __init__(self, datasets):
  1800. super().__init__()
  1801. for dataset in datasets:
  1802. if not isinstance(dataset, Dataset):
  1803. raise TypeError("The parameter %s of zip has type error!" % (dataset))
  1804. self.datasets = datasets
  1805. for data in datasets:
  1806. self.children.append(data)
  1807. data.parent.append(self)
  1808. def get_dataset_size(self):
  1809. """
  1810. Get the number of batches in an epoch.
  1811. Return:
  1812. Number, number of batches.
  1813. """
  1814. children_sizes = [c.get_dataset_size() for c in self.children]
  1815. if all(c is not None for c in children_sizes):
  1816. return min(children_sizes)
  1817. return None
  1818. def num_classes(self):
  1819. """
  1820. Get the number of classes in a dataset.
  1821. Return:
  1822. Number, number of classes.
  1823. """
  1824. return None
  1825. def is_sync(self):
  1826. return any([c.is_sync() for c in self.children])
  1827. def get_args(self):
  1828. args = super().get_args()
  1829. return args
  1830. class ConcatDataset(DatasetOp):
  1831. """
  1832. The result of applying concat dataset operator to the input Dataset.
  1833. Args:
  1834. datasets (list): A list of datasets to be concatenated together.
  1835. Raises:
  1836. TypeError: If dataset is not an instance of Dataset.
  1837. """
  1838. def __init__(self, datasets):
  1839. super().__init__()
  1840. for dataset in datasets:
  1841. if not isinstance(dataset, Dataset):
  1842. raise TypeError("The parameter %s of concat has type error!" % (dataset))
  1843. self.datasets = datasets
  1844. for data in datasets:
  1845. self.children.append(data)
  1846. data.parent.append(self)
  1847. def get_dataset_size(self):
  1848. """
  1849. Get the number of batches in an epoch.
  1850. Return:
  1851. Number, number of batches.
  1852. """
  1853. children_sizes = [c.get_dataset_size() for c in self.children]
  1854. dataset_size = sum(children_sizes)
  1855. return dataset_size
  1856. class RenameDataset(DatasetOp):
  1857. """
  1858. The result of applying Rename operator to the input Dataset.
  1859. Args:
1860. input_dataset (Dataset): Input Dataset to be renamed.
  1861. input_columns (list[str]): list of names of the input columns.
  1862. output_columns (list[str]): list of names of the output columns.
  1863. """
  1864. def __init__(self, input_dataset, input_columns, output_columns):
  1865. super().__init__()
  1866. if not isinstance(input_columns, list):
  1867. input_columns = [input_columns]
  1868. if not isinstance(output_columns, list):
  1869. output_columns = [output_columns]
  1870. self.input_column_names = input_columns
  1871. self.output_column_names = output_columns
  1872. self.children.append(input_dataset)
  1873. input_dataset.parent.append(self)
  1874. self._input_indexs = input_dataset.input_indexs
  1875. def get_args(self):
  1876. args = super().get_args()
  1877. args["input_columns"] = self.input_column_names
  1878. args["output_columns"] = self.output_column_names
  1879. return args
  1880. class ProjectDataset(DatasetOp):
  1881. """
  1882. The result of applying Project operator to the input Dataset.
  1883. Args:
1884. input_dataset (Dataset): Input Dataset to be projected.
  1885. columns (list[str]): List of names of the columns to project.
  1886. prefetch_size (int, optional): Prefetch number of records ahead of the
  1887. user's request (default=None).
  1888. """
  1889. def __init__(self, input_dataset, columns, prefetch_size=None):
  1890. super().__init__()
  1891. if not isinstance(columns, list):
  1892. columns = [columns]
  1893. self.columns = columns
  1894. self.children.append(input_dataset)
  1895. self.prefetch_size = prefetch_size
  1896. input_dataset.parent.append(self)
  1897. self._input_indexs = input_dataset.input_indexs
  1898. def get_args(self):
  1899. args = super().get_args()
  1900. args["columns"] = self.columns
  1901. args["prefetch_size"] = self.prefetch_size
  1902. return args
  1903. class TransferDataset(DatasetOp):
  1904. """
  1905. The result of applying TDT operator to the input Dataset.
  1906. Args:
  1907. input_dataset (Dataset): Input Dataset to be transferred.
  1908. queue_name (str): Name of device queue.
  1909. device_id (int): Id of device.
  1910. device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
1911. send_epoch_end (bool, optional): Whether to send an end-of-sequence flag to the device (default=True).
  1912. """
  1913. def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True):
  1914. super().__init__()
  1915. self.children.append(input_dataset)
  1916. input_dataset.parent.append(self)
  1917. self.queue_name = queue_name
  1918. self._input_indexs = input_dataset.input_indexs
  1919. self._device_type = device_type
  1920. self._device_id = device_id
  1921. self._send_epoch_end = send_epoch_end
  1922. self.iterator = None
  1923. def get_args(self):
  1924. args = super().get_args()
  1925. args["queue_name"] = self.queue_name
  1926. args["device_type"] = self._device_type
  1927. args["device_id"] = self._device_id
  1928. args["send_epoch_end"] = self._send_epoch_end
  1929. return args
  1930. def create_dict_iterator(self, num_epochs=-1):
  1931. raise RuntimeError("TransferDataset is not iterable")
  1932. def create_tuple_iterator(self, columns=None, num_epochs=-1):
  1933. raise RuntimeError("TransferDataset is not iterable")
  1934. def __iter__(self):
  1935. raise RuntimeError("TransferDataset is not iterable")
  1936. def output_shapes(self):
  1937. raise RuntimeError("TransferDataset does not support output_shapes")
  1938. def output_types(self):
  1939. raise RuntimeError("TransferDataset does not support output_types")
  1940. def send(self, num_epochs=-1):
  1941. # need to keep iterator alive so the executionTree is not destroyed
  1942. if self._noop_mode():
  1943. return
1944. self.iterator = TupleIterator(self, num_epochs=num_epochs)
  1945. def stop_send(self):
  1946. self.iterator.depipeline.StopSend()
  1947. class RangeDataset(MappableDataset):
  1948. """
1949. A source dataset that generates a sequence of numbers over a range.
  1950. Args:
  1951. start (int): starting index.
  1952. stop (int): ending index.
  1953. step (int): step size in a range.
  1954. """
  1955. def __init__(self, start, stop, step):
  1956. super().__init__()
  1957. self.start = start
  1958. self.stop = stop
  1959. self.step = step
  1960. def get_args(self):
  1961. args = super().get_args()
  1962. args["start"] = self.start
  1963. args["stop"] = self.stop
  1964. args["step"] = self.step
  1965. return args
  1966. def is_shuffled(self):
  1967. return False
  1968. def is_sharded(self):
  1969. return False
  1970. def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
  1971. """
  1972. Create sampler based on user input.
  1973. Args:
  1974. num_samples (int): Number of samples.
  1975. input_sampler (Iterable / Sampler): Sampler from user.
  1976. shuffle (bool): Shuffle.
  1977. num_shards (int): Number of shard for sharding.
  1978. shard_id (int): Shard ID.
  1979. non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
  1980. """
  1981. if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
  1982. return None
  1983. if input_sampler is not None:
  1984. # If the user provided a sampler, then it doesn't matter what the other args are because
  1985. # we are being asked specifically to use the given sampler.
  1986. # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
  1987. # be None. Consider this example:
  1988. # sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
  1989. # data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
  1990. # In this case, the user has given different sample-related arguments that contradict each other.
  1991. # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
  1992. if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  1993. samplers.RandomSampler, samplers.SubsetRandomSampler,
  1994. samplers.WeightedRandomSampler, samplers.Sampler)) and
  1995. (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
  1996. raise ValueError(
  1997. 'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
  1998. ' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle))
  1999. return input_sampler
  2000. if shuffle is None:
  2001. if num_shards is not None:
  2002. # If shuffle is not specified, sharding enabled, use distributed random sampler
  2003. shuffle = True
  2004. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
  2005. # If shuffle is not specified, sharding disabled, use random sampler
  2006. if num_samples is not None:
  2007. return samplers.RandomSampler(replacement=True, num_samples=num_samples)
  2008. return samplers.RandomSampler(num_samples=num_samples)
  2009. if shuffle is True:
  2010. if num_shards is not None:
  2011. # If shuffle enabled, sharding enabled, use distributed random sampler
  2012. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
  2013. # If shuffle enabled, sharding disabled, use random sampler
  2014. if num_samples is not None:
  2015. return samplers.RandomSampler(replacement=True, num_samples=num_samples)
  2016. return samplers.RandomSampler(num_samples=num_samples)
  2017. if num_shards is not None:
  2018. # If shuffle disabled, sharding enabled, use distributed sequential sampler
  2019. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
  2020. # If shuffle disabled, sharding disabled, use sequential sampler
  2021. return samplers.SequentialSampler(num_samples=num_samples)
  2022. class ImageFolderDatasetV2(MappableDataset):
  2023. """
  2024. A source dataset that reads images from a tree of directories.
  2025. All images within one folder have the same label.
  2026. The generated dataset has two columns ['image', 'label'].
  2027. The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
  2028. otherwise.
  2029. The type of the image tensor is uint8. The label is just a scalar uint64
  2030. tensor.
  2031. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2032. below shows what input args are allowed and their expected behavior.
  2033. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2034. :widths: 25 25 50
  2035. :header-rows: 1
  2036. * - Parameter 'sampler'
  2037. - Parameter 'shuffle'
  2038. - Expected Order Behavior
  2039. * - None
  2040. - None
  2041. - random order
  2042. * - None
  2043. - True
  2044. - random order
  2045. * - None
  2046. - False
  2047. - sequential order
  2048. * - Sampler object
  2049. - None
  2050. - order defined by sampler
  2051. * - Sampler object
  2052. - True
  2053. - not allowed
  2054. * - Sampler object
  2055. - False
  2056. - not allowed
  2057. Args:
  2058. dataset_dir (str): Path to the root directory that contains the dataset.
  2059. num_samples (int, optional): The number of images to be included in the dataset
  2060. (default=None, all images).
  2061. num_parallel_workers (int, optional): Number of workers to read the data
  2062. (default=None, set in the config).
  2063. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  2064. (default=None, expected order behavior shown in the table).
  2065. sampler (Sampler, optional): Object used to choose samples from the
  2066. dataset (default=None, expected order behavior shown in the table).
  2067. extensions (list[str], optional): List of file extensions to be
  2068. included in the dataset (default=None).
  2069. class_indexing (dict, optional): A str-to-int mapping from folder name to index
  2070. (default=None, the folder names will be sorted
  2071. alphabetically and each class will be given a
  2072. unique index starting from 0).
  2073. decode (bool, optional): decode the images after reading (default=False).
  2074. num_shards (int, optional): Number of shards that the dataset should be divided
  2075. into (default=None).
  2076. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2077. argument should be specified only when num_shards is also specified.
  2078. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
  2079. Raises:
  2080. RuntimeError: If sampler and shuffle are specified at the same time.
  2081. RuntimeError: If sampler and sharding are specified at the same time.
  2082. RuntimeError: If num_shards is specified but shard_id is None.
  2083. RuntimeError: If shard_id is specified but num_shards is None.
  2084. RuntimeError: If class_indexing is not a dictionary.
  2085. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2086. Examples:
  2087. >>> import mindspore.dataset as ds
  2088. >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
  2089. >>> dataset_dir = "/path/to/imagefolder_directory"
  2090. >>> # 1) read all samples (image files) in dataset_dir with 8 threads
  2091. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8)
  2092. >>> # 2) read all samples (image files) from folder cat and folder dog with label 0 and 1
  2093. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir,class_indexing={"cat":0,"dog":1})
  2094. >>> # 3) read all samples (image files) in dataset_dir with extensions .JPEG and .png (case sensitive)
  2095. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir, extensions={".JPEG",".png"})
  2096. """
  2097. @check_imagefolderdatasetv2
  2098. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  2099. shuffle=None, sampler=None, extensions=None, class_indexing=None,
  2100. decode=False, num_shards=None, shard_id=None, cache=None):
  2101. super().__init__(num_parallel_workers)
  2102. self.dataset_dir = dataset_dir
  2103. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2104. self.num_samples = num_samples
  2105. self.shuffle_level = shuffle
  2106. self.extensions = extensions
  2107. self.class_indexing = class_indexing
  2108. self.decode = decode
  2109. self.num_shards = num_shards
  2110. self.shard_id = shard_id
  2111. self.cache = cache
  2112. def get_args(self):
  2113. args = super().get_args()
  2114. args["dataset_dir"] = self.dataset_dir
  2115. args["num_samples"] = self.num_samples
  2116. args["sampler"] = self.sampler
  2117. args["shuffle"] = self.shuffle_level
  2118. args["extensions"] = self.extensions
  2119. args["class_indexing"] = self.class_indexing
  2120. args["decode"] = self.decode
  2121. args["num_shards"] = self.num_shards
  2122. args["shard_id"] = self.shard_id
  2123. args["cache"] = self.cache.cache_client if self.cache is not None else None
  2124. return args
  2125. def get_dataset_size(self):
  2126. """
  2127. Get the number of batches in an epoch.
  2128. Return:
  2129. Number, number of batches.
  2130. """
  2131. num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
  2132. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  2133. rows_from_sampler = self._get_sampler_dataset_size()
  2134. if rows_from_sampler is None:
  2135. return rows_per_shard
  2136. return min(rows_from_sampler, rows_per_shard)
  2137. def num_classes(self):
  2138. """
  2139. Get the number of classes in dataset.
  2140. Return:
  2141. Number, number of classes.
  2142. """
  2143. return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[1]
  2144. def is_shuffled(self):
  2145. if self.shuffle_level is None:
  2146. return True
  2147. return self.shuffle_level or self.sampler.is_shuffled()
  2148. def is_sharded(self):
  2149. if self.num_shards is not None:
  2150. return self.num_shards > 1
  2151. return self.sampler.is_sharded()
  2152. class MnistDataset(MappableDataset):
  2153. """
  2154. A source dataset for reading and parsing the Mnist dataset.
  2155. The generated dataset has two columns ['image', 'label'].
  2156. The type of the image tensor is uint8. The label is just a scalar uint32 tensor.
  2157. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2158. below shows what input args are allowed and their expected behavior.
  2159. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2160. :widths: 25 25 50
  2161. :header-rows: 1
  2162. * - Parameter 'sampler'
  2163. - Parameter 'shuffle'
  2164. - Expected Order Behavior
  2165. * - None
  2166. - None
  2167. - random order
  2168. * - None
  2169. - True
  2170. - random order
  2171. * - None
  2172. - False
  2173. - sequential order
  2174. * - Sampler object
  2175. - None
  2176. - order defined by sampler
  2177. * - Sampler object
  2178. - True
  2179. - not allowed
  2180. * - Sampler object
  2181. - False
  2182. - not allowed
  2183. Citation of Mnist dataset.
  2184. .. code-block::
  2185. @article{lecun2010mnist,
  2186. title = {MNIST handwritten digit database},
  2187. author = {LeCun, Yann and Cortes, Corinna and Burges, CJ},
  2188. journal = {ATT Labs [Online]},
  2189. volume = {2},
  2190. year = {2010},
  2191. howpublished = {http://yann.lecun.com/exdb/mnist},
  2192. description = {The MNIST database of handwritten digits has a training set of 60,000 examples,
  2193. and a test set of 10,000 examples. It is a subset of a larger set available from
  2194. NIST. The digits have been size-normalized and centered in a fixed-size image.}
  2195. }
  2196. Args:
  2197. dataset_dir (str): Path to the root directory that contains the dataset.
  2198. num_samples (int, optional): The number of images to be included in the dataset
  2199. (default=None, all images).
  2200. num_parallel_workers (int, optional): Number of workers to read the data
2201. (default=None, set in the config).
  2202. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  2203. (default=None, expected order behavior shown in the table).
  2204. sampler (Sampler, optional): Object used to choose samples from the
  2205. dataset (default=None, expected order behavior shown in the table).
  2206. num_shards (int, optional): Number of shards that the dataset should be divided
  2207. into (default=None).
  2208. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2209. argument should be specified only when num_shards is also specified.
  2210. Raises:
  2211. RuntimeError: If sampler and shuffle are specified at the same time.
  2212. RuntimeError: If sampler and sharding are specified at the same time.
  2213. RuntimeError: If num_shards is specified but shard_id is None.
  2214. RuntimeError: If shard_id is specified but num_shards is None.
  2215. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2216. Examples:
  2217. >>> import mindspore.dataset as ds
  2218. >>> dataset_dir = "/path/to/mnist_folder"
  2219. >>> # 1) read 3 samples from mnist_dataset
  2220. >>> mnist_dataset = ds.MnistDataset(dataset_dir=dataset_dir, num_samples=3)
  2221. >>> # in mnist_dataset dataset, each dictionary has keys "image" and "label"
  2222. """
  2223. @check_mnist_cifar_dataset
  2224. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  2225. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  2226. super().__init__(num_parallel_workers)
  2227. self.dataset_dir = dataset_dir
  2228. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2229. self.num_samples = num_samples
  2230. self.shuffle_level = shuffle
  2231. self.num_shards = num_shards
  2232. self.shard_id = shard_id
  2233. def get_args(self):
  2234. args = super().get_args()
  2235. args["dataset_dir"] = self.dataset_dir
  2236. args["num_samples"] = self.num_samples
  2237. args["shuffle"] = self.shuffle_level
  2238. args["sampler"] = self.sampler
  2239. args["num_shards"] = self.num_shards
  2240. args["shard_id"] = self.shard_id
  2241. return args
  2242. def get_dataset_size(self):
  2243. """
  2244. Get the number of batches in an epoch.
  2245. Return:
  2246. Number, number of batches.
  2247. """
  2248. num_rows = MnistOp.get_num_rows(self.dataset_dir)
  2249. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  2250. rows_from_sampler = self._get_sampler_dataset_size()
  2251. if rows_from_sampler is None:
  2252. return rows_per_shard
  2253. return min(rows_from_sampler, rows_per_shard)
  2254. def is_shuffled(self):
  2255. if self.shuffle_level is None:
  2256. return True
  2257. return self.shuffle_level or self.sampler.is_shuffled()
  2258. def is_sharded(self):
  2259. if self.num_shards is not None:
  2260. return self.num_shards > 1
  2261. return self.sampler.is_sharded()
  2262. class MindDataset(MappableDataset):
  2263. """
2264. A source dataset that reads and parses MindRecord dataset files.
  2265. Args:
2266. dataset_file (str or list[str]): One file name, or a list of file names, of the dataset to read.
  2267. columns_list (list[str], optional): List of columns to be read (default=None).
  2268. num_parallel_workers (int, optional): The number of readers (default=None).
  2269. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  2270. (default=None, performs shuffle).
  2271. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
  2272. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2273. argument should be specified only when num_shards is also specified.
2274. block_reader (bool, optional): Whether to read data in block mode (default=False).
2275. sampler (Sampler, optional): Object used to choose samples from the
2276. dataset (default=None, sampler is exclusive
2277. with shuffle and block_reader). Supported samplers: SubsetRandomSampler,
2278. PKSampler, RandomSampler, SequentialSampler, DistributedSampler.
2279. padded_sample (dict, optional): Sample to be appended to the dataset; its
2280. keys must be the same as columns_list.
2281. num_padded (int, optional): Number of padding samples. The dataset size
2282. plus num_padded should be divisible by num_shards.
  2283. num_samples (int, optional): The number of samples to be included in the dataset
  2284. (default=None, all samples).
  2285. Raises:
  2286. ValueError: If num_shards is specified but shard_id is None.
  2287. ValueError: If shard_id is specified but num_shards is None.
2288. ValueError: If block_reader is True but num_shards is specified.
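Examples:
>>> import mindspore.dataset as ds
>>> # A minimal usage sketch; the file path and column names below are placeholders.
>>> dataset_file = "/path/to/mindrecord_file"
>>> mind_dataset = ds.MindDataset(dataset_file=dataset_file, columns_list=["data", "label"])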
  2289. """
  2290. @check_minddataset
  2291. def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
  2292. shuffle=None, num_shards=None, shard_id=None,
  2293. block_reader=False, sampler=None, padded_sample=None,
  2294. num_padded=None, num_samples=None):
  2295. super().__init__(num_parallel_workers)
  2296. if isinstance(dataset_file, list):
  2297. self.load_dataset = False
  2298. else:
  2299. self.load_dataset = True
  2300. self.dataset_file = dataset_file
  2301. self.columns_list = columns_list
  2302. self.shuffle_option = shuffle
  2303. self.num_shards = num_shards
  2304. self.shard_id = shard_id
2305. if block_reader is True and num_shards is not None:
2306. raise ValueError("block_reader is not allowed when num_shards is specified")
2307. if block_reader is True and shuffle is True:
2308. raise ValueError("block_reader is not allowed when shuffle is enabled")
  2309. if block_reader is True:
  2310. logger.warning("WARN: global shuffle is not used.")
  2311. if sampler is not None:
  2312. if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
  2313. samplers.DistributedSampler, samplers.RandomSampler,
  2314. samplers.SequentialSampler)) is False:
  2315. raise ValueError("The sampler is not supported yet.")
  2316. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2317. self.num_samples = num_samples
2318. # sampler and block_reader are mutually exclusive
2319. if block_reader is True and sampler is not None:
2320. raise ValueError("block_reader is not allowed when a sampler is specified")
  2321. if num_padded is None:
  2322. num_padded = 0
  2323. self.block_reader = block_reader
  2324. self.padded_sample = padded_sample
  2325. self.num_padded = num_padded
  2326. def get_args(self):
  2327. args = super().get_args()
  2328. padded_sample = None
  2329. if self.padded_sample:
  2330. padded_sample = {}
  2331. for k, v in self.padded_sample.items():
  2332. if isinstance(v, np.ndarray):
  2333. padded_sample[k] = v.tobytes()
  2334. else:
  2335. padded_sample[k] = v
  2336. args["dataset_file"] = self.dataset_file
  2337. args["load_dataset"] = self.load_dataset
  2338. args["columns_list"] = self.columns_list
  2339. args["shuffle_option"] = self.shuffle_option
  2340. args["num_samples"] = self.num_samples
  2341. args["block_reader"] = self.block_reader
  2342. args["num_padded"] = self.num_padded
  2343. args["padded_sample"] = padded_sample
  2344. args["sampler"] = self.sampler
  2345. return args
  2346. def get_dataset_size(self):
  2347. """
  2348. Get the number of batches in an epoch.
  2349. Return:
  2350. Number, number of batches.
  2351. """
  2352. if self._dataset_size is None:
  2353. if self.load_dataset:
  2354. dataset_file = [self.dataset_file]
  2355. else:
  2356. dataset_file = self.dataset_file
  2357. num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
  2358. return num_rows
  2359. return self._dataset_size
2360. # manually set dataset_size as a temporary solution.
  2361. def set_dataset_size(self, value):
  2362. logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
  2363. if value >= 0:
  2364. self._dataset_size = value
  2365. else:
  2366. raise ValueError('Set dataset_size with negative value {}'.format(value))
  2367. def is_shuffled(self):
  2368. if self.shuffle_option is None:
  2369. return True
  2370. return self.shuffle_option or self.sampler.is_shuffled()
  2371. def is_sharded(self):
  2372. if self.num_shards is not None:
  2373. return self.num_shards > 1
  2374. return self.sampler.is_sharded()
  2375. def _iter_fn(dataset, num_samples):
  2376. """
  2377. Generator function wrapper for iterable dataset.
  2378. """
  2379. if num_samples is not None:
  2380. ds_iter = iter(dataset)
  2381. for _ in range(num_samples):
  2382. try:
  2383. val = next(ds_iter)
  2384. except StopIteration:
  2385. return
  2386. # convert output tensors to ndarrays
  2387. yield tuple([np.array(x, copy=False) for x in val])
  2388. else:
  2389. for val in dataset:
  2390. # convert output tensors to ndarrays
  2391. yield tuple([np.array(x, copy=False) for x in val])
  2392. def _generator_fn(generator, num_samples):
  2393. """
  2394. Generator function wrapper for generator function dataset.
  2395. """
  2396. if num_samples is not None:
  2397. gen_iter = generator()
  2398. for _ in range(num_samples):
  2399. try:
  2400. val = next(gen_iter)
  2401. except StopIteration:
  2402. return
  2403. yield val
  2404. else:
  2405. gen_iter = generator()
  2406. for val in gen_iter:
  2407. yield val
  2408. def _py_sampler_fn(sampler, num_samples, dataset):
  2409. """
  2410. Generator function wrapper for mappable dataset with python sampler.
  2411. """
  2412. if num_samples is not None:
  2413. sampler_iter = iter(sampler)
  2414. for _ in range(num_samples):
  2415. try:
  2416. idx = next(sampler_iter)
  2417. except StopIteration:
  2418. return
  2419. val = dataset[idx]
  2420. # convert output tensors to ndarrays
  2421. yield tuple([np.array(x, copy=False) for x in val])
  2422. else:
  2423. for i in sampler:
  2424. val = dataset[i]
  2425. # convert output tensors to ndarrays
  2426. yield tuple([np.array(x, copy=False) for x in val])
  2427. def _cpp_sampler_fn(sampler, dataset):
  2428. """
  2429. Generator function wrapper for mappable dataset with cpp sampler.
  2430. """
  2431. indices = sampler.get_indices()
  2432. for i in indices:
  2433. val = dataset[i]
  2434. # convert output tensors to ndarrays
  2435. yield tuple([np.array(x, copy=False) for x in val])
  2436. def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
  2437. """
  2438. Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
  2439. """
  2440. indices = sampler.get_indices()
  2441. return _sampler_fn_mp(indices, dataset, num_worker)
  2442. def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
  2443. """
  2444. Multiprocessing generator function wrapper for mappable dataset with python sampler.
  2445. """
  2446. indices = _fetch_py_sampler_indices(sampler, num_samples)
  2447. return _sampler_fn_mp(indices, dataset, num_worker)
  2448. def _fetch_py_sampler_indices(sampler, num_samples):
  2449. """
2450. Index fetcher for python sampler.
  2451. """
  2452. if num_samples is not None:
  2453. sampler_iter = iter(sampler)
  2454. ret = []
  2455. for _ in range(num_samples):
  2456. try:
  2457. val = next(sampler_iter)
  2458. ret.append(val)
  2459. except StopIteration:
  2460. break
  2461. return ret
  2462. return [i for i in sampler]
  2463. def _fill_worker_indices(workers, indices, idx):
  2464. """
2465. Fill the worker index queues in round-robin order.
  2466. """
  2467. num_worker = len(workers)
  2468. while idx < len(indices):
  2469. try:
  2470. workers[idx % num_worker].put(indices[idx])
  2471. idx += 1
  2472. except queue.Full:
  2473. break
  2474. return idx
  2475. def _sampler_fn_mp(indices, dataset, num_worker):
  2476. """
2477. Master-process side of the multiprocessing generator function wrapper.
  2478. """
  2479. workers = []
  2480. # Event for end of epoch
  2481. eoe = multiprocessing.Event()
  2482. # Create workers
  2483. for _ in range(num_worker):
  2484. worker = _GeneratorWorker(dataset, eoe)
  2485. worker.daemon = True
  2486. workers.append(worker)
  2487. # Fill initial index queues
  2488. idx_cursor = 0
  2489. idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
  2490. # Start all workers
  2491. for w in workers:
  2492. w.start()
  2493. # Fetch results
  2494. for i in range(len(indices)):
  2495. # Fetch result and put index
  2496. try:
  2497. result = workers[i % num_worker].get()
  2498. except queue.Empty:
  2499. raise Exception("Generator worker process timeout")
  2500. except KeyboardInterrupt:
  2501. for w in workers:
  2502. w.terminate()
  2503. w.join()
  2504. raise Exception("Generator worker receives KeyboardInterrupt")
  2505. if idx_cursor < len(indices):
  2506. idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
  2507. # Set eoe event once all indices are sent
  2508. if idx_cursor == len(indices) and not eoe.is_set():
  2509. eoe.set()
  2510. yield tuple([np.array(x, copy=False) for x in result])
  2511. def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
  2512. """
  2513. Multiprocessing generator worker process loop.
  2514. """
  2515. while True:
  2516. # Fetch index, block
  2517. try:
  2518. idx = idx_queue.get()
  2519. except KeyboardInterrupt:
  2520. raise Exception("Generator worker receives KeyboardInterrupt")
  2521. if idx is None:
  2522. # When the queue is out of scope from master process, a None item can be fetched from the queue.
  2523. # Upon receiving None, worker process should check if EOE is set.
2524. assert eoe.is_set(), "Worker received a None index but the EOE event is not set."
  2525. return
  2526. # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
  2527. result = dataset[idx]
  2528. # Send data, block
  2529. try:
  2530. result_queue.put(result)
  2531. except KeyboardInterrupt:
  2532. raise Exception("Generator worker receives KeyboardInterrupt")
  2533. del result, idx
  2534. class _GeneratorWorker(multiprocessing.Process):
  2535. """
  2536. Worker process for multiprocess Generator.
  2537. """
  2538. def __init__(self, dataset, eoe):
  2539. self.idx_queue = multiprocessing.Queue(16)
  2540. self.res_queue = multiprocessing.Queue(16)
  2541. super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe))
  2542. def put(self, item):
  2543. """
  2544. Put function for worker index queue. Never block. Raise queue.Full on failure.
  2545. """
  2546. self.idx_queue.put_nowait(item)
  2547. def get(self):
  2548. """
  2549. Get function for worker result queue. Block with timeout.
  2550. """
  2551. return self.res_queue.get(timeout=5)
  2552. def __del__(self):
  2553. self.terminate()
  2554. class GeneratorDataset(MappableDataset):
  2555. """
2556. A source dataset that generates data from Python by invoking a Python data source each epoch.
  2557. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2558. below shows what input args are allowed and their expected behavior.
  2559. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2560. :widths: 25 25 50
  2561. :header-rows: 1
  2562. * - Parameter 'sampler'
  2563. - Parameter 'shuffle'
  2564. - Expected Order Behavior
  2565. * - None
  2566. - None
  2567. - random order
  2568. * - None
  2569. - True
  2570. - random order
  2571. * - None
  2572. - False
  2573. - sequential order
  2574. * - Sampler object
  2575. - None
  2576. - order defined by sampler
  2577. * - Sampler object
  2578. - True
  2579. - not allowed
  2580. * - Sampler object
  2581. - False
  2582. - not allowed
  2583. Args:
  2584. source (Callable/Iterable/Random Accessible):
  2585. A generator callable object, an iterable python object or a random accessible python object.
2586. A callable source is required to return a tuple of NumPy arrays as a row of the dataset on source().next().
2587. An iterable source is required to return a tuple of NumPy arrays as a row of the dataset on iter(source).next().
2588. A random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on
2589. source[idx].
  2590. column_names (list[str], optional): List of column names of the dataset (default=None). Users are required to
  2591. provide either column_names or schema.
  2592. column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
  2593. If provided, sanity check will be performed on generator output.
  2594. schema (Schema/str, optional): Path to the json schema file or schema object (default=None). Users are
  2595. required to provide either column_names or schema. If both are provided, schema will be used.
  2596. num_samples (int, optional): The number of samples to be included in the dataset
2597. (default=None, all samples).
  2598. num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
  2599. shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
  2600. (default=None, expected order behavior shown in the table).
  2601. sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
  2602. required (default=None, expected order behavior shown in the table).
  2603. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
2604. When this argument is specified, 'num_samples' has no effect. Random accessible input is required.
  2605. shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
  2606. when num_shards is also specified. Random accessible input is required.
  2607. Examples:
  2608. >>> import mindspore.dataset as ds
  2609. >>> # 1) Multidimensional generator function as callable input
  2610. >>> def generator_md():
  2611. >>> for i in range(64):
  2612. >>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
  2613. >>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
  2614. >>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"])
  2615. >>> # 2) Multi-column generator function as callable input
  2616. >>> def generator_mc(maxid = 64):
  2617. >>> for i in range(maxid):
  2618. >>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
  2619. >>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
  2620. >>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1", "col2"])
  2621. >>> # 3) Iterable dataset as iterable input
  2622. >>> class MyIterable():
  2623. >>> def __iter__(self):
  2624. >>> return # User implementation
  2625. >>> # create iterable_generator_dataset with MyIterable object
  2626. >>> iterable_generator_dataset = ds.GeneratorDataset(MyIterable(), ["col1"])
  2627. >>> # 4) Random accessible dataset as Random accessible input
  2628. >>> class MyRA():
  2629. >>> def __getitem__(self, index):
  2630. >>> return # User implementation
  2631. >>> # create ra_generator_dataset with MyRA object
  2632. >>> ra_generator_dataset = ds.GeneratorDataset(MyRA(), ["col1"])
  2633. >>> # List/Dict/Tuple is also random accessible
2634. >>> list_generator = ds.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"])
  2635. >>> # 5) Built-in Sampler
  2636. >>> my_generator = ds.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
  2637. >>>
  2638. """
  2639. @check_generatordataset
  2640. def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
  2641. num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
  2642. super().__init__(num_parallel_workers)
  2643. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2644. if self.sampler is not None and hasattr(source, "__getitem__"):
  2645. if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  2646. samplers.RandomSampler, samplers.SubsetRandomSampler,
  2647. samplers.WeightedRandomSampler, samplers.Sampler)):
  2648. sampler_instance = self.sampler.create()
  2649. sampler_instance.set_num_rows(len(source))
  2650. sampler_instance.initialize()
  2651. if num_parallel_workers > 1:
  2652. self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
  2653. else:
  2654. self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
  2655. else:
  2656. if num_parallel_workers > 1:
  2657. self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers))
  2658. else:
  2659. self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
  2660. else:
  2661. try:
  2662. iter(source)
  2663. except TypeError:
  2664. # Use generator function if input callable
  2665. self.source = (lambda: _generator_fn(source, num_samples))
  2666. else:
  2667. # Use iterator function if input is iterable
  2668. # Random accessible input is also iterable
  2669. self.source = (lambda: _iter_fn(source, num_samples))
  2670. if column_names is not None and not isinstance(column_names, list):
  2671. column_names = [column_names]
  2672. self.column_names = column_names
  2673. if column_types is not None:
  2674. self.column_types = mstypelist_to_detypelist(column_types)
  2675. else:
  2676. self.column_types = column_types
  2677. if schema is not None:
  2678. self.schema = schema
  2679. if not isinstance(schema, Schema):
  2680. self.schema = Schema(schema)
  2681. self.column_names = []
  2682. self.column_types = []
  2683. for col in self.schema.columns:
  2684. self.column_names.append(col["name"])
  2685. self.column_types.append(DataType(col["type"]))
  2686. if source is not None and hasattr(source, "__len__"):
  2687. self._dataset_size = len(source)
  2688. def get_args(self):
  2689. args = super().get_args()
  2690. args["source"] = self.source
  2691. args["column_names"] = self.column_names
  2692. args["column_types"] = self.column_types
  2693. return args
  2694. def get_dataset_size(self):
  2695. """
  2696. Get the number of batches in an epoch.
  2697. Return:
  2698. Number, number of batches.
  2699. """
  2700. rows_from_sampler = self._get_sampler_dataset_size()
  2701. if rows_from_sampler is None:
  2702. return self._dataset_size
  2703. if self._dataset_size is None:
  2704. return None
  2705. return min(rows_from_sampler, self._dataset_size)
  2706. # manually set dataset_size as a temporary solution.
  2707. def set_dataset_size(self, value):
  2708. if value >= 0:
  2709. self._dataset_size = value
  2710. else:
  2711. raise ValueError('Set dataset_size with negative value {}'.format(value))
  2712. def __deepcopy__(self, memodict):
  2713. if id(self) in memodict:
  2714. return memodict[id(self)]
  2715. cls = self.__class__
  2716. new_op = cls.__new__(cls)
  2717. memodict[id(self)] = new_op
  2718. new_op.children = copy.deepcopy(self.children, memodict)
  2719. new_op.parent = copy.deepcopy(self.parent, memodict)
  2720. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  2721. new_op.column_types = copy.deepcopy(self.column_types, memodict)
  2722. new_op.column_names = copy.deepcopy(self.column_names, memodict)
  2723. new_op.source = self.source
  2724. new_op.sampler = self.sampler
  2725. return new_op
  2726. def is_shuffled(self):
  2727. return self.sampler.is_shuffled()
  2728. def is_sharded(self):
  2729. return self.sampler.is_sharded()
  2730. class TFRecordDataset(SourceDataset):
  2731. """
  2732. A source dataset that reads and parses datasets stored on disk in TFData format.
  2733. Args:
  2734. dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
  2735. files. The list will be sorted in a lexicographical order.
  2736. schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
  2737. If the schema is not provided, the meta data from the TFData file is considered the schema.
  2738. columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
  2739. num_samples (int, optional): number of samples(rows) to read (default=None).
2740. If num_samples is None and numRows (parsed from schema) does not exist, read the full dataset;
2741. if num_samples is None and numRows (parsed from schema) is greater than 0, read numRows rows;
2742. if both num_samples and numRows (parsed from schema) are greater than 0, read num_samples rows.
  2743. num_parallel_workers (int, optional): number of workers to read the data
  2744. (default=None, number set in the config).
  2745. shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
  2746. If shuffle is False, no shuffling will be performed;
2747. If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
2748. Otherwise, there are two levels of shuffling:
  2749. - Shuffle.GLOBAL: Shuffle both the files and samples.
  2750. - Shuffle.FILES: Shuffle files only.
  2751. num_shards (int, optional): Number of shards that the dataset should be divided
  2752. into (default=None).
  2753. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2754. argument should be specified only when num_shards is also specified.
2755. shard_equal_rows (bool, optional): Whether to get equal rows for all shards (default=False). If shard_equal_rows
2756. is False, the number of rows in each shard may be unequal.
  2757. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
  2758. Examples:
  2759. >>> import mindspore.dataset as ds
  2760. >>> import mindspore.common.dtype as mstype
  2761. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple tf data files
  2762. >>> # 1) get all rows from dataset_files with no explicit schema:
  2763. >>> # The meta-data in the first row will be used as a schema.
  2764. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files)
  2765. >>> # 2) get all rows from dataset_files with user-defined schema:
  2766. >>> schema = ds.Schema()
2767. >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
  2768. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema=schema)
  2769. >>> # 3) get all rows from dataset_files with schema file "./schema.json":
  2770. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
  2771. """
  2772. @check_tfrecorddataset
  2773. def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
  2774. shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None):
  2775. super().__init__(num_parallel_workers)
  2776. self.dataset_files = self._find_files(dataset_files)
  2777. self.dataset_files.sort()
  2778. self.num_shards = num_shards
  2779. self.shard_id = shard_id
  2780. schema_obj = None
  2781. if (schema is not None) and (not isinstance(schema, Schema)):
  2782. schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
  2783. self.schema = schema
  2784. self.columns_list = columns_list
  2785. self.num_samples = num_samples
  2786. self.cache = cache
  2787. if schema_obj is not None and num_samples is None:
  2788. self.num_samples = schema_obj.num_rows
  2789. if not isinstance(shuffle, (bool, Shuffle)):
  2790. raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
  2791. if not isinstance(shuffle, Shuffle):
  2792. if shuffle:
  2793. self.shuffle_level = Shuffle.GLOBAL
  2794. self.shuffle_files = True
  2795. else:
  2796. self.shuffle_level = None
  2797. self.shuffle_files = False
  2798. else:
  2799. self.shuffle_level = shuffle
  2800. self.shuffle_files = True
  2801. # The TF record dataset does not directly support a sampler. It has provided sampling arguments
  2802. # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
  2803. # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
  2804. sampler_shuffle = self.shuffle_files
  2805. sampler = None
  2806. self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
  2807. non_mappable=True)
  2808. self.shard_equal_rows = shard_equal_rows
  2809. def get_args(self):
  2810. args = super().get_args()
  2811. args["dataset_files"] = self.dataset_files
  2812. if self.schema is not None:
  2813. if isinstance(self.schema, Schema):
  2814. self.schema.datasetType = 'TF'
  2815. if self.num_samples is not None:
  2816. self.schema.num_rows = self.num_samples
  2817. args["schema_json_string"] = self.schema.to_json()
  2818. else:
  2819. args["schema_file_path"] = self.schema
  2820. args["schema"] = self.schema
  2821. args["columns_list"] = self.columns_list
  2822. args["num_samples"] = self.num_samples
  2823. if self.shuffle_files is not None:
  2824. args["shuffle_files"] = self.shuffle_files
  2825. args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
  2826. args["shuffle"] = self.shuffle_level
  2827. args["num_shards"] = self.num_shards
  2828. args["shard_id"] = self.shard_id
  2829. args["shard_equal_rows"] = self.shard_equal_rows
  2830. args["cache"] = self.cache.cache_client if self.cache is not None else None
  2831. args["sampler"] = self.sampler
  2832. return args
  2833. def get_dataset_size(self, estimate=False):
  2834. """
  2835. Get the number of batches in an epoch.
  2836. Args:
  2837. estimate (bool, optional): Fast estimation of the dataset size instead of a full scan.
  2838. Return:
  2839. Number, number of batches.
  2840. """
  2841. if self._dataset_size is None:
  2842. num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
  2843. num_rows = get_num_rows(num_rows, self.num_shards)
  2844. if self.num_samples is None:
  2845. return num_rows
  2846. return min(self.num_samples, num_rows)
  2847. return self._dataset_size
2848. # manually set dataset_size as a temporary solution.
  2849. def set_dataset_size(self, value):
  2850. logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
  2851. if value >= 0:
  2852. self._dataset_size = value
  2853. else:
  2854. raise ValueError('Set dataset_size with negative value {}'.format(value))
  2855. def is_shuffled(self):
  2856. return self.shuffle_files
  2857. def is_sharded(self):
  2858. if self.num_shards is not None:
  2859. return self.num_shards > 1
  2860. return False
  2861. class ManifestDataset(MappableDataset):
  2862. """
  2863. A source dataset that reads images from a manifest file.
  2864. The generated dataset has two columns ['image', 'label'].
  2865. The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
  2866. otherwise.
  2867. The type of the image tensor is uint8. The label is just a scalar uint64
  2868. tensor.
  2869. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2870. below shows what input args are allowed and their expected behavior.
  2871. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2872. :widths: 25 25 50
  2873. :header-rows: 1
  2874. * - Parameter 'sampler'
  2875. - Parameter 'shuffle'
  2876. - Expected Order Behavior
  2877. * - None
  2878. - None
  2879. - random order
  2880. * - None
  2881. - True
  2882. - random order
  2883. * - None
  2884. - False
  2885. - sequential order
  2886. * - Sampler object
  2887. - None
  2888. - order defined by sampler
  2889. * - Sampler object
  2890. - True
  2891. - not allowed
  2892. * - Sampler object
  2893. - False
  2894. - not allowed
  2895. Args:
  2896. dataset_file (str): File to be read.
2897. usage (str, optional): Which data to load: "train", "eval" or "inference" (default="train").
  2898. num_samples (int, optional): The number of images to be included in the dataset.
  2899. (default=None, all images).
  2900. num_parallel_workers (int, optional): Number of workers to read the data
  2901. (default=None, number set in the config).
  2902. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  2903. order behavior shown in the table).
  2904. sampler (Sampler, optional): Object used to choose samples from the
  2905. dataset (default=None, expected order behavior shown in the table).
  2906. class_indexing (dict, optional): A str-to-int mapping from label name to index
  2907. (default=None, the folder names will be sorted alphabetically and each
  2908. class will be given a unique index starting from 0).
  2909. decode (bool, optional): decode the images after reading (default=False).
  2910. num_shards (int, optional): Number of shards that the dataset should be divided
  2911. into (default=None).
  2912. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2913. argument should be specified only when num_shards is also specified.
  2914. Raises:
  2915. RuntimeError: If sampler and shuffle are specified at the same time.
  2916. RuntimeError: If sampler and sharding are specified at the same time.
  2917. RuntimeError: If num_shards is specified but shard_id is None.
  2918. RuntimeError: If shard_id is specified but num_shards is None.
  2919. RuntimeError: If class_indexing is not a dictionary.
  2920. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2921. Examples:
  2922. >>> import mindspore.dataset as ds
  2923. >>> dataset_file = "/path/to/manifest_file.manifest"
  2924. >>> # 1) read all samples specified in manifest_file dataset with 8 threads for training:
  2925. >>> manifest_dataset = ds.ManifestDataset(dataset_file, usage="train", num_parallel_workers=8)
  2926. >>> # 2) reads samples (specified in manifest_file.manifest) for shard 0 in a 2-way distributed training setup:
  2927. >>> manifest_dataset = ds.ManifestDataset(dataset_file, num_shards=2, shard_id=0)
  2928. """
  2929. @check_manifestdataset
  2930. def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None,
  2931. shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None):
  2932. super().__init__(num_parallel_workers)
  2933. self.dataset_file = dataset_file
  2934. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2935. if class_indexing is not None and not isinstance(class_indexing, dict):
  2936. raise RuntimeError("class_indexing should be a dictionary.")
  2937. self.num_samples = num_samples
  2938. self.class_indexing = class_indexing
  2939. self.decode = decode
  2940. self.usage = usage
  2941. self.shuffle_level = shuffle
  2942. self.num_shards = num_shards
  2943. self.shard_id = shard_id
  2944. def get_args(self):
  2945. args = super().get_args()
  2946. args["dataset_file"] = self.dataset_file
  2947. args["usage"] = self.usage
  2948. args["num_samples"] = self.num_samples
  2949. args["shuffle"] = self.shuffle_level
  2950. args["sampler"] = self.sampler
  2951. args["class_indexing"] = self.class_indexing
  2952. args["decode"] = self.decode
  2953. args["num_shards"] = self.num_shards
  2954. args["shard_id"] = self.shard_id
  2955. return args
  2956. def get_dataset_size(self):
  2957. """
  2958. Get the number of batches in an epoch.
  2959. Return:
  2960. Number, number of batches.
  2961. """
  2962. if self.class_indexing is None:
  2963. class_indexing = dict()
  2964. else:
  2965. class_indexing = self.class_indexing
  2966. num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
  2967. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  2968. rows_from_sampler = self._get_sampler_dataset_size()
  2969. if rows_from_sampler is None:
  2970. return rows_per_shard
  2971. return min(rows_from_sampler, rows_per_shard)
  2972. def num_classes(self):
  2973. """
  2974. Get the number of classes in a dataset.
  2975. Return:
  2976. Number, number of classes.
  2977. """
  2978. if self.class_indexing is None:
  2979. class_indexing = dict()
  2980. else:
  2981. class_indexing = self.class_indexing
  2982. return ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[1]
  2983. def get_class_indexing(self):
  2984. """
  2985. Get the class index.
  2986. Return:
  2987. Dict, A str-to-int mapping from label name to index.
  2988. """
  2989. if self.class_indexing is None:
  2990. class_indexing = dict()
  2991. else:
  2992. class_indexing = self.class_indexing
  2993. return ManifestOp.get_class_indexing(self.dataset_file, class_indexing, self.usage)
  2994. def is_shuffled(self):
  2995. if self.shuffle_level is None:
  2996. return True
  2997. return self.shuffle_level or self.sampler.is_shuffled()
  2998. def is_sharded(self):
  2999. if self.num_shards is not None:
  3000. return self.num_shards > 1
  3001. return self.sampler.is_sharded()
  3002. class Cifar10Dataset(MappableDataset):
  3003. """
  3004. A source dataset that reads cifar10 data.
  3005. The generated dataset has two columns ['image', 'label'].
  3006. The type of the image tensor is uint8. The label is just a scalar uint32
  3007. tensor.
  3008. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  3009. below shows what input args are allowed and their expected behavior.
  3010. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3011. :widths: 25 25 50
  3012. :header-rows: 1
  3013. * - Parameter 'sampler'
  3014. - Parameter 'shuffle'
  3015. - Expected Order Behavior
  3016. * - None
  3017. - None
  3018. - random order
  3019. * - None
  3020. - True
  3021. - random order
  3022. * - None
  3023. - False
  3024. - sequential order
  3025. * - Sampler object
  3026. - None
  3027. - order defined by sampler
  3028. * - Sampler object
  3029. - True
  3030. - not allowed
  3031. * - Sampler object
  3032. - False
  3033. - not allowed
  3034. Citation of Cifar10 dataset.
  3035. .. code-block::
  3036. @techreport{Krizhevsky09,
  3037. author = {Alex Krizhevsky},
  3038. title = {Learning multiple layers of features from tiny images},
  3039. institution = {},
  3040. year = {2009},
  3041. howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html},
  3042. description = {The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
  3043. with 6000 images per class. There are 50000 training images and 10000 test images.}
  3044. }
  3045. Args:
  3046. dataset_dir (str): Path to the root directory that contains the dataset.
  3047. num_samples (int, optional): The number of images to be included in the dataset.
  3048. (default=None, all images).
  3049. num_parallel_workers (int, optional): Number of workers to read the data
  3050. (default=None, number set in the config).
  3051. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3052. order behavior shown in the table).
  3053. sampler (Sampler, optional): Object used to choose samples from the
  3054. dataset (default=None, expected order behavior shown in the table).
  3055. num_shards (int, optional): Number of shards that the dataset should be divided
  3056. into (default=None).
  3057. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3058. argument should be specified only when num_shards is also specified.
  3059. Raises:
  3060. RuntimeError: If sampler and shuffle are specified at the same time.
  3061. RuntimeError: If sampler and sharding are specified at the same time.
  3062. RuntimeError: If num_shards is specified but shard_id is None.
  3063. RuntimeError: If shard_id is specified but num_shards is None.
  3064. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3065. Examples:
  3066. >>> import mindspore.dataset as ds
  3067. >>> dataset_dir = "/path/to/cifar10_dataset_directory"
  3068. >>> # 1) get all samples from CIFAR10 dataset in sequence:
  3069. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir,shuffle=False)
  3070. >>> # 2) randomly select 350 samples from CIFAR10 dataset:
  3071. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir,num_samples=350, shuffle=True)
  3072. >>> # 3) get samples from CIFAR10 dataset for shard 0 in a 2 way distributed training:
  3073. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir,num_shards=2,shard_id=0)
  3074. >>> # in CIFAR10 dataset, each dictionary has keys "image" and "label"
  3075. """
  3076. @check_mnist_cifar_dataset
  3077. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  3078. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  3079. super().__init__(num_parallel_workers)
  3080. self.dataset_dir = dataset_dir
  3081. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3082. self.num_samples = num_samples
  3083. self.num_shards = num_shards
  3084. self.shard_id = shard_id
  3085. self.shuffle_level = shuffle
  3086. def get_args(self):
  3087. args = super().get_args()
  3088. args["dataset_dir"] = self.dataset_dir
  3089. args["num_samples"] = self.num_samples
  3090. args["sampler"] = self.sampler
  3091. args["num_shards"] = self.num_shards
  3092. args["shard_id"] = self.shard_id
  3093. args["shuffle"] = self.shuffle_level
  3094. return args
  3095. def get_dataset_size(self):
  3096. """
  3097. Get the number of batches in an epoch.
  3098. Return:
  3099. Number, number of batches.
  3100. """
  3101. num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
  3102. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  3103. rows_from_sampler = self._get_sampler_dataset_size()
  3104. if rows_from_sampler is None:
  3105. return rows_per_shard
  3106. return min(rows_from_sampler, rows_per_shard)
  3107. def is_shuffled(self):
  3108. if self.shuffle_level is None:
  3109. return True
  3110. return self.shuffle_level or self.sampler.is_shuffled()
  3111. def is_sharded(self):
  3112. if self.num_shards is not None:
  3113. return self.num_shards > 1
  3114. return self.sampler.is_sharded()
  3115. class Cifar100Dataset(MappableDataset):
  3116. """
  3117. A source dataset that reads cifar100 data.
  3118. The generated dataset has three columns ['image', 'coarse_label', 'fine_label'].
3119. The type of the image tensor is uint8. The coarse_label and fine_label are each a scalar uint32
3120. tensor.
  3121. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  3122. below shows what input args are allowed and their expected behavior.
  3123. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3124. :widths: 25 25 50
  3125. :header-rows: 1
  3126. * - Parameter 'sampler'
  3127. - Parameter 'shuffle'
  3128. - Expected Order Behavior
  3129. * - None
  3130. - None
  3131. - random order
  3132. * - None
  3133. - True
  3134. - random order
  3135. * - None
  3136. - False
  3137. - sequential order
  3138. * - Sampler object
  3139. - None
  3140. - order defined by sampler
  3141. * - Sampler object
  3142. - True
  3143. - not allowed
  3144. * - Sampler object
  3145. - False
  3146. - not allowed
  3147. Citation of Cifar100 dataset.
  3148. .. code-block::
  3149. @techreport{Krizhevsky09,
  3150. author = {Alex Krizhevsky},
  3151. title = {Learning multiple layers of features from tiny images},
  3152. institution = {},
  3153. year = {2009},
  3154. howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html},
  3155. description = {This dataset is just like the CIFAR-10, except it has 100 classes containing 600 images
  3156. each. There are 500 training images and 100 testing images per class. The 100 classes in
  3157. the CIFAR-100 are grouped into 20 superclasses. Each image comes with a "fine" label (the
  3158. class to which it belongs) and a "coarse" label (the superclass to which it belongs).}
  3159. }
  3160. Args:
  3161. dataset_dir (str): Path to the root directory that contains the dataset.
  3162. num_samples (int, optional): The number of images to be included in the dataset.
  3163. (default=None, all images).
  3164. num_parallel_workers (int, optional): Number of workers to read the data
  3165. (default=None, number set in the config).
  3166. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3167. order behavior shown in the table).
  3168. sampler (Sampler, optional): Object used to choose samples from the
  3169. dataset (default=None, expected order behavior shown in the table).
  3170. num_shards (int, optional): Number of shards that the dataset should be divided
  3171. into (default=None).
  3172. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3173. argument should be specified only when num_shards is also specified.
  3174. Raises:
  3175. RuntimeError: If sampler and shuffle are specified at the same time.
  3176. RuntimeError: If sampler and sharding are specified at the same time.
  3177. RuntimeError: If num_shards is specified but shard_id is None.
  3178. RuntimeError: If shard_id is specified but num_shards is None.
  3179. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3180. Examples:
  3181. >>> import mindspore.dataset as ds
  3182. >>> dataset_dir = "/path/to/cifar100_dataset_directory"
  3183. >>> # 1) get all samples from CIFAR100 dataset in sequence:
  3184. >>> cifar100_dataset = ds.Cifar100Dataset(dataset_dir=dataset_dir,shuffle=False)
  3185. >>> # 2) randomly select 350 samples from CIFAR100 dataset:
  3186. >>> cifar100_dataset = ds.Cifar100Dataset(dataset_dir=dataset_dir,num_samples=350, shuffle=True)
  3187. >>> # in CIFAR100 dataset, each dictionary has 3 keys: "image", "fine_label" and "coarse_label"
  3188. """
  3189. @check_mnist_cifar_dataset
  3190. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  3191. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  3192. super().__init__(num_parallel_workers)
  3193. self.dataset_dir = dataset_dir
  3194. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3195. self.num_samples = num_samples
  3196. self.num_shards = num_shards
  3197. self.shard_id = shard_id
  3198. self.shuffle_level = shuffle
  3199. def get_args(self):
  3200. args = super().get_args()
  3201. args["dataset_dir"] = self.dataset_dir
  3202. args["num_samples"] = self.num_samples
  3203. args["sampler"] = self.sampler
  3204. args["num_shards"] = self.num_shards
  3205. args["shard_id"] = self.shard_id
  3206. args["shuffle"] = self.shuffle_level
  3207. return args
  3208. def get_dataset_size(self):
  3209. """
  3210. Get the number of batches in an epoch.
  3211. Return:
  3212. Number, number of batches.
  3213. """
  3214. num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
  3215. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  3216. rows_from_sampler = self._get_sampler_dataset_size()
  3217. if rows_from_sampler is None:
  3218. return rows_per_shard
  3219. return min(rows_from_sampler, rows_per_shard)
  3220. def is_shuffled(self):
  3221. if self.shuffle_level is None:
  3222. return True
  3223. return self.shuffle_level or self.sampler.is_shuffled()
  3224. def is_sharded(self):
  3225. if self.num_shards is not None:
  3226. return self.num_shards > 1
  3227. return self.sampler.is_sharded()
  3228. class RandomDataset(SourceDataset):
  3229. """
  3230. A source dataset that generates random data.
  3231. Args:
  3232. total_rows (int): number of rows for the dataset to generate (default=None, number of rows is random)
  3233. schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
  3234. If the schema is not provided, the random dataset generates a random schema.
  3235. columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
  3236. num_samples (int): number of samples to draw from the total. (default=None, which means all rows)
  3237. num_parallel_workers (int, optional): number of workers to read the data
  3238. (default=None, number set in the config).
  3239. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used)
  3240. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  3241. (default=None, expected order behavior shown in the table).
  3242. num_shards (int, optional): Number of shards that the dataset should be divided
  3243. into (default=None).
  3244. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3245. argument should be specified only when num_shards is also specified.
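Examples:
A minimal illustrative sketch; the column name 'data', its type/shape and the row count below are arbitrary placeholders.
>>> import mindspore.dataset as ds
>>> import mindspore.common.dtype as mstype
>>> schema = ds.Schema()
>>> schema.add_column('data', de_type=mstype.uint8, shape=[2])
>>> # generate 4 random rows that follow the schema above
>>> random_dataset = ds.RandomDataset(total_rows=4, schema=schema)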
  3246. """
  3247. @check_random_dataset
  3248. def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
  3249. cache=None, shuffle=None, num_shards=None, shard_id=None):
  3250. super().__init__(num_parallel_workers)
  3251. schema_obj = None
  3252. if (schema is not None) and (not isinstance(schema, Schema)):
  3253. schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
  3254. self.schema = schema
  3255. self.columns_list = columns_list
  3256. sampler = None
  3257. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True)
  3258. self.num_samples = num_samples
  3259. self.cache = cache
  3260. if schema_obj is not None and total_rows is None:
  3261. self.total_rows = schema_obj.num_rows
  3262. elif total_rows is None:
  3263. self.total_rows = 0
  3264. else:
  3265. self.total_rows = total_rows
  3266. self.num_shards = num_shards
  3267. self.shard_id = shard_id
  3268. self.shuffle_level = shuffle
  3269. def get_args(self):
  3270. args = super().get_args()
  3271. if self.schema is not None:
  3272. if isinstance(self.schema, Schema):
3273. self.schema.dataset_type = 'Random'  # stored in dataset_type so that to_json() emits it as "datasetType"
  3274. if self.total_rows is not None:
  3275. self.schema.num_rows = self.total_rows
  3276. args["schema_json_string"] = self.schema.to_json()
  3277. else:
  3278. args["schema_file_path"] = self.schema
  3279. args["schema"] = self.schema
  3280. args["columns_list"] = self.columns_list
  3281. args["num_samples"] = self.num_samples
  3282. args["total_rows"] = self.total_rows
  3283. args["cache"] = self.cache.cache_client if self.cache is not None else None
  3284. args["sampler"] = self.sampler
  3285. return args
  3286. def get_dataset_size(self):
  3287. """
  3288. Get the number of batches in an epoch.
  3289. Return:
  3290. Number, number of batches.
  3291. """
3292. num_rows = self.total_rows  # RandomDataset generates its own rows; there is no dataset directory to scan
  3293. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  3294. rows_from_sampler = self._get_sampler_dataset_size()
  3295. if rows_from_sampler is None:
  3296. return rows_per_shard
  3297. return min(rows_from_sampler, rows_per_shard)
  3298. def is_shuffled(self):
  3299. if self.shuffle_level is None:
  3300. return True
  3301. return self.shuffle_level or self.sampler.is_shuffled()
  3302. def is_sharded(self):
  3303. if self.num_shards is not None:
  3304. return self.num_shards > 1
  3305. return self.sampler.is_sharded()
  3306. class Schema:
  3307. """
3308. Class to represent the schema of a dataset.
  3309. Args:
  3310. schema_file(str): Path of schema file (default=None).
  3311. Return:
  3312. Schema object, schema info about dataset.
  3313. Raises:
  3314. RuntimeError: If schema file failed to load.
  3315. Example:
  3316. >>> import mindspore.dataset as ds
  3317. >>> import mindspore.common.dtype as mstype
  3318. >>> # create schema, specify column name, mindspore.dtype and shape of the column
  3319. >>> schema = ds.Schema()
3320. >>> schema.add_column('col1', de_type=mstype.int64, shape=[2])
  3321. """
  3322. def __init__(self, schema_file=None):
  3323. self.num_rows = None
  3324. if schema_file is None:
  3325. self.columns = []
  3326. self.dataset_type = ''
  3327. else:
  3328. if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
  3329. raise ValueError("The file %s does not exist or permission denied!" % schema_file)
  3330. try:
  3331. with open(schema_file, 'r') as load_f:
  3332. json_obj = json.load(load_f)
  3333. except json.decoder.JSONDecodeError:
  3334. raise RuntimeError("Schema file failed to load.")
  3335. except UnicodeDecodeError:
  3336. raise RuntimeError("Schema file failed to decode.")
  3337. except Exception:
  3338. raise RuntimeError("Schema file failed to open.")
  3339. self.from_json(json_obj)
  3340. @check_add_column
  3341. def add_column(self, name, de_type, shape=None):
  3342. """
  3343. Add new column to the schema.
  3344. Args:
  3345. name (str): name of the column.
  3346. de_type (str): data type of the column.
  3347. shape (list[int], optional): shape of the column
  3348. (default=None, [-1] which is an unknown shape of rank 1).
  3349. Raises:
  3350. ValueError: If column type is unknown.
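Example:
An illustrative sketch; the column names, types and shapes below are placeholders.
>>> import mindspore.dataset as ds
>>> import mindspore.common.dtype as mstype
>>> schema = ds.Schema()
>>> schema.add_column('image', de_type=mstype.uint8, shape=[-1, -1, 3])
>>> schema.add_column('label', de_type=mstype.int32)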
  3351. """
  3352. new_column = dict()
  3353. new_column["name"] = name
  3354. if isinstance(de_type, typing.Type):
  3355. de_type = mstype_to_detype(de_type)
  3356. new_column["type"] = str(de_type)
  3357. else:
  3358. new_column["type"] = str(DataType(de_type))
  3359. if shape is not None:
  3360. new_column["shape"] = shape
  3361. new_column["rank"] = len(shape)
  3362. else:
  3363. new_column["rank"] = 1
  3364. self.columns.append(new_column)
  3365. def to_json(self):
  3366. """
  3367. Get a JSON string of the schema.
  3368. Returns:
  3369. Str, JSON string of the schema.
  3370. """
  3371. json_file = dict()
  3372. json_file["columns"] = self.columns
  3373. if self.dataset_type:
  3374. json_file["datasetType"] = self.dataset_type
  3375. if self.num_rows:
  3376. json_file["numRows"] = self.num_rows
  3377. return json.dumps(json_file, indent=2)
  3378. def parse_columns(self, columns):
  3379. """
3380. Parse the columns and add them to self.
  3381. Args:
  3382. columns (dict or list[dict]): dataset attribution information, decoded from schema file.
  3383. - list[dict], 'name' and 'type' must be in keys, 'shape' optional.
  3384. - dict, columns.keys() as name, columns.values() is dict, and 'type' inside, 'shape' optional.
  3385. Raises:
  3386. RuntimeError: If failed to parse columns.
  3387. RuntimeError: If unknown items in columns.
  3388. RuntimeError: If column's name field is missing.
  3389. RuntimeError: If column's type field is missing.
  3390. Example:
  3391. >>> schema = Schema()
  3392. >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
  3393. >>> {'name': 'label', 'type': 'int8', 'shape': [1]}]
  3394. >>> schema.parse_columns(columns1)
  3395. >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
  3396. >>> schema.parse_columns(columns2)
  3397. """
  3398. self.columns = []
  3399. if isinstance(columns, list):
  3400. for column in columns:
  3401. try:
  3402. name = column.pop("name")
  3403. except KeyError:
  3404. raise RuntimeError("Column's name is missing")
  3405. try:
  3406. de_type = column.pop("type")
  3407. except KeyError:
  3408. raise RuntimeError("Column' type is missing")
  3409. shape = column.pop("shape", None)
  3410. column.pop("t_impl", None)
  3411. column.pop("rank", None)
  3412. if column:
  3413. raise RuntimeError("Unknown field {}".format(",".join(column.keys())))
  3414. self.add_column(name, de_type, shape)
  3415. elif isinstance(columns, dict):
  3416. for key, value in columns.items():
  3417. name = key
  3418. try:
  3419. de_type = value.pop("type")
  3420. except KeyError:
  3421. raise RuntimeError("Column' type is missing")
  3422. shape = value.pop("shape", None)
  3423. value.pop("t_impl", None)
  3424. value.pop("rank", None)
  3425. if value:
  3426. raise RuntimeError("Unknown field {}".format(",".join(value.keys())))
  3427. self.add_column(name, de_type, shape)
  3428. else:
  3429. raise RuntimeError("columns must be dict or list, columns contain name, type, shape(optional).")
  3430. def from_json(self, json_obj):
  3431. """
3432. Populate the schema from a parsed JSON object.
3433. Args:
3434. json_obj (dict): parsed JSON object.
3435. Raises:
3436. RuntimeError: If there is an unknown item in the object.
3437. RuntimeError: If the dataset type is missing in the object.
3438. RuntimeError: If columns are missing in the object.
  3439. """
  3440. if not isinstance(json_obj, dict) or json_obj is None:
  3441. raise ValueError("Expected non-empty dict.")
  3442. for k, v in json_obj.items():
  3443. if k == "datasetType":
  3444. self.dataset_type = v
  3445. elif k == "numRows":
  3446. self.num_rows = v
  3447. elif k == "columns":
  3448. self.parse_columns(v)
  3449. else:
  3450. raise RuntimeError("Unknown field %s" % k)
  3451. if self.columns is None:
  3452. raise RuntimeError("Columns are missing.")
  3453. if self.num_rows is not None:
  3454. if not isinstance(self.num_rows, int) or self.num_rows <= 0:
  3455. raise ValueError("numRows must be greater than 0")
  3456. def __str__(self):
  3457. return self.to_json()
  3458. class VOCDataset(MappableDataset):
  3459. """
  3460. A source dataset for reading and parsing VOC dataset.
3461. The generated dataset has multiple columns:
  3462. - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
  3463. ['difficult', dtype=uint32], ['truncate', dtype=uint32]].
  3464. - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
  3465. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  3466. below shows what input args are allowed and their expected behavior.
  3467. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3468. :widths: 25 25 50
  3469. :header-rows: 1
  3470. * - Parameter 'sampler'
  3471. - Parameter 'shuffle'
  3472. - Expected Order Behavior
  3473. * - None
  3474. - None
  3475. - random order
  3476. * - None
  3477. - True
  3478. - random order
  3479. * - None
  3480. - False
  3481. - sequential order
  3482. * - Sampler object
  3483. - None
  3484. - order defined by sampler
  3485. * - Sampler object
  3486. - True
  3487. - not allowed
  3488. * - Sampler object
  3489. - False
  3490. - not allowed
  3491. Citation of VOC dataset.
  3492. .. code-block::
  3493. @article{Everingham10,
  3494. author = {Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.},
  3495. title = {The Pascal Visual Object Classes (VOC) Challenge},
  3496. journal = {International Journal of Computer Vision},
  3497. volume = {88},
  3498. year = {2010},
  3499. number = {2},
  3500. month = {jun},
  3501. pages = {303--338},
  3502. biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex},
  3503. howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc{year}/index.html},
  3504. description = {The PASCAL Visual Object Classes (VOC) challenge is a benchmark in visual
  3505. object category recognition and detection, providing the vision and machine
  3506. learning communities with a standard dataset of images and annotation, and
  3507. standard evaluation procedures.}
  3508. }
  3509. Args:
  3510. dataset_dir (str): Path to the root directory that contains the dataset.
3511. task (str): Set the task type of reading VOC data; currently only "Segmentation" and "Detection" are supported
3512. (default="Segmentation").
3513. mode (str): Set the data list txt file to be read (default="train").
  3514. class_indexing (dict, optional): A str-to-int mapping from label name to index
  3515. (default=None, the folder names will be sorted alphabetically and each
  3516. class will be given a unique index starting from 0).
  3517. num_samples (int, optional): The number of images to be included in the dataset
  3518. (default=None, all images).
  3519. num_parallel_workers (int, optional): Number of workers to read the data
  3520. (default=None, number set in the config).
  3521. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3522. order behavior shown in the table).
  3523. decode (bool, optional): Decode the images after reading (default=False).
  3524. sampler (Sampler, optional): Object used to choose samples from the dataset
  3525. (default=None, expected order behavior shown in the table).
  3526. num_shards (int, optional): Number of shards that the dataset should be divided
  3527. into (default=None).
  3528. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3529. argument should be specified only when num_shards is also specified.
  3530. Raises:
3531. RuntimeError: If the XML of Annotations is in an invalid format.
3532. RuntimeError: If the XML of Annotations is missing the "object" attribute.
3533. RuntimeError: If the XML of Annotations is missing the "bndbox" attribute.
3534. RuntimeError: If sampler and shuffle are specified at the same time.
3535. RuntimeError: If sampler and sharding are specified at the same time.
3536. RuntimeError: If num_shards is specified but shard_id is None.
3537. RuntimeError: If shard_id is specified but num_shards is None.
3538. ValueError: If task is not 'Segmentation' or 'Detection'.
3539. ValueError: If task is 'Segmentation' but class_indexing is not None.
3540. ValueError: If the txt file related to mode does not exist.
  3541. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3542. Examples:
  3543. >>> import mindspore.dataset as ds
  3544. >>> dataset_dir = "/path/to/voc_dataset_directory"
3545. >>> # 1) read VOC data for segmentation training
3546. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", mode="train")
3547. >>> # 2) read VOC data for detection training
  3548. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train")
  3549. >>> # 3) read all VOC dataset samples in dataset_dir with 8 threads in random order:
  3550. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", num_parallel_workers=8)
  3551. >>> # 4) read then decode all VOC dataset samples in dataset_dir in sequence:
  3552. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", decode=True, shuffle=False)
  3553. >>> # in VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
  3554. >>> # in VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
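>>> # 5) (illustrative sketch) read Detection data as shard 0 of 2 shards; the shard values are placeholders:
>>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", num_shards=2, shard_id=0)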
  3555. """
  3556. @check_vocdataset
  3557. def __init__(self, dataset_dir, task="Segmentation", mode="train", class_indexing=None, num_samples=None,
  3558. num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
  3559. super().__init__(num_parallel_workers)
  3560. self.dataset_dir = dataset_dir
  3561. self.task = task
  3562. self.mode = mode
  3563. self.class_indexing = class_indexing
  3564. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3565. self.num_samples = num_samples
  3566. self.decode = decode
  3567. self.shuffle_level = shuffle
  3568. self.num_shards = num_shards
  3569. self.shard_id = shard_id
  3570. def get_args(self):
  3571. args = super().get_args()
  3572. args["dataset_dir"] = self.dataset_dir
  3573. args["task"] = self.task
  3574. args["mode"] = self.mode
  3575. args["class_indexing"] = self.class_indexing
  3576. args["num_samples"] = self.num_samples
  3577. args["sampler"] = self.sampler
  3578. args["decode"] = self.decode
  3579. args["shuffle"] = self.shuffle_level
  3580. args["num_shards"] = self.num_shards
  3581. args["shard_id"] = self.shard_id
  3582. return args
  3583. def get_dataset_size(self):
  3584. """
  3585. Get the number of batches in an epoch.
  3586. Return:
  3587. Number, number of batches.
  3588. """
  3589. if self.num_samples is None:
  3590. num_samples = 0
  3591. else:
  3592. num_samples = self.num_samples
  3593. if self.class_indexing is None:
  3594. class_indexing = dict()
  3595. else:
  3596. class_indexing = self.class_indexing
  3597. num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
  3598. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  3599. rows_from_sampler = self._get_sampler_dataset_size()
  3600. if rows_from_sampler is None:
  3601. return rows_per_shard
  3602. return min(rows_from_sampler, rows_per_shard)
  3603. def get_class_indexing(self):
  3604. """
  3605. Get the class index.
  3606. Return:
  3607. Dict, A str-to-int mapping from label name to index.
  3608. """
  3609. if self.task != "Detection":
3610. raise NotImplementedError("Only 'Detection' task supports get_class_indexing.")
  3611. if self.class_indexing is None:
  3612. class_indexing = dict()
  3613. else:
  3614. class_indexing = self.class_indexing
  3615. return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing)
  3616. def is_shuffled(self):
  3617. if self.shuffle_level is None:
  3618. return True
  3619. return self.shuffle_level or self.sampler.is_shuffled()
  3620. def is_sharded(self):
  3621. if self.num_shards is not None:
  3622. return self.num_shards > 1
  3623. return self.sampler.is_sharded()
  3624. class CocoDataset(MappableDataset):
  3625. """
  3626. A source dataset for reading and parsing COCO dataset.
3627. CocoDataset supports four kinds of tasks: 2017 Train/Val/Test Detection, Keypoints, Stuff and Panoptic.
3628. The generated dataset has multiple columns:
  3629. - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  3630. ['iscrowd', dtype=uint32]].
  3631. - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd',dtype=uint32]].
  3632. - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
  3633. ['num_keypoints', dtype=uint32]].
  3634. - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  3635. ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
  3636. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. CocoDataset doesn't support
  3637. PKSampler. Table below shows what input args are allowed and their expected behavior.
  3638. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3639. :widths: 25 25 50
  3640. :header-rows: 1
  3641. * - Parameter 'sampler'
  3642. - Parameter 'shuffle'
  3643. - Expected Order Behavior
  3644. * - None
  3645. - None
  3646. - random order
  3647. * - None
  3648. - True
  3649. - random order
  3650. * - None
  3651. - False
  3652. - sequential order
  3653. * - Sampler object
  3654. - None
  3655. - order defined by sampler
  3656. * - Sampler object
  3657. - True
  3658. - not allowed
  3659. * - Sampler object
  3660. - False
  3661. - not allowed
  3662. Citation of Coco dataset.
  3663. .. code-block::
  3664. @article{DBLP:journals/corr/LinMBHPRDZ14,
  3665. author = {Tsung{-}Yi Lin and Michael Maire and Serge J. Belongie and
  3666. Lubomir D. Bourdev and Ross B. Girshick and James Hays and
  3667. Pietro Perona and Deva Ramanan and Piotr Doll{\'{a}}r and C. Lawrence Zitnick},
  3668. title = {Microsoft {COCO:} Common Objects in Context},
  3669. journal = {CoRR},
  3670. volume = {abs/1405.0312},
  3671. year = {2014},
  3672. url = {http://arxiv.org/abs/1405.0312},
  3673. archivePrefix = {arXiv},
  3674. eprint = {1405.0312},
  3675. timestamp = {Mon, 13 Aug 2018 16:48:13 +0200},
  3676. biburl = {https://dblp.org/rec/journals/corr/LinMBHPRDZ14.bib},
  3677. bibsource = {dblp computer science bibliography, https://dblp.org},
  3678. description = {COCO is a large-scale object detection, segmentation, and captioning dataset.
  3679. It contains 91 common object categories with 82 of them having more than 5,000
  3680. labeled instances. In contrast to the popular ImageNet dataset, COCO has fewer
  3681. categories but more instances per category.}
  3682. }
  3683. Args:
  3684. dataset_dir (str): Path to the root directory that contains the dataset.
  3685. annotation_file (str): Path to the annotation json.
3686. task (str): Set the task type for reading COCO data; currently 'Detection', 'Stuff', 'Panoptic' and 'Keypoint'
3687. are supported (default='Detection').
  3688. num_samples (int, optional): The number of images to be included in the dataset
  3689. (default=None, all images).
  3690. num_parallel_workers (int, optional): Number of workers to read the data
  3691. (default=None, number set in the config).
  3692. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3693. order behavior shown in the table).
  3694. decode (bool, optional): Decode the images after reading (default=False).
  3695. sampler (Sampler, optional): Object used to choose samples from the dataset
  3696. (default=None, expected order behavior shown in the table).
  3697. num_shards (int, optional): Number of shards that the dataset should be divided
  3698. into (default=None).
  3699. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3700. argument should be specified only when num_shards is also specified.
  3701. Raises:
  3702. RuntimeError: If sampler and shuffle are specified at the same time.
  3703. RuntimeError: If sampler and sharding are specified at the same time.
  3704. RuntimeError: If num_shards is specified but shard_id is None.
  3705. RuntimeError: If shard_id is specified but num_shards is None.
3706. RuntimeError: If parsing the JSON file failed.
3707. ValueError: If task is not in ['Detection', 'Stuff', 'Panoptic', 'Keypoint'].
3708. ValueError: If annotation_file does not exist.
3709. ValueError: If dataset_dir does not exist.
  3710. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3711. Examples:
  3712. >>> import mindspore.dataset as ds
  3713. >>> dataset_dir = "/path/to/coco_dataset_directory/image_folder"
  3714. >>> annotation_file = "/path/to/coco_dataset_directory/annotation_folder/annotation.json"
  3715. >>> # 1) read COCO data for Detection task
  3716. >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Detection')
  3717. >>> # 2) read COCO data for Stuff task
  3718. >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Stuff')
  3719. >>> # 3) read COCO data for Panoptic task
  3720. >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Panoptic')
  3721. >>> # 4) read COCO data for Keypoint task
  3722. >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Keypoint')
  3723. >>> # in COCO dataset, each dictionary has keys "image" and "annotation"
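>>> # 5) (illustrative sketch) fetch the label-name-to-index mapping; only 'Detection' and 'Panoptic' support it:
>>> detection_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Detection')
>>> class_index = detection_dataset.get_class_indexing()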
  3724. """
  3725. @check_cocodataset
  3726. def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None,
  3727. shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
  3728. super().__init__(num_parallel_workers)
  3729. self.dataset_dir = dataset_dir
  3730. self.annotation_file = annotation_file
  3731. self.task = task
  3732. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3733. self.num_samples = num_samples
  3734. self.decode = decode
  3735. self.shuffle_level = shuffle
  3736. self.num_shards = num_shards
  3737. self.shard_id = shard_id
  3738. def get_args(self):
  3739. args = super().get_args()
  3740. args["dataset_dir"] = self.dataset_dir
  3741. args["annotation_file"] = self.annotation_file
  3742. args["task"] = self.task
  3743. args["num_samples"] = self.num_samples
  3744. args["sampler"] = self.sampler
  3745. args["decode"] = self.decode
  3746. args["shuffle"] = self.shuffle_level
  3747. args["num_shards"] = self.num_shards
  3748. args["shard_id"] = self.shard_id
  3749. return args
  3750. def get_dataset_size(self):
  3751. """
  3752. Get the number of batches in an epoch.
  3753. Return:
  3754. Number, number of batches.
  3755. """
  3756. num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
  3757. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  3758. rows_from_sampler = self._get_sampler_dataset_size()
  3759. if rows_from_sampler is None:
  3760. return rows_per_shard
  3761. return min(rows_from_sampler, rows_per_shard)
  3762. def get_class_indexing(self):
  3763. """
  3764. Get the class index.
  3765. Return:
  3766. Dict, A str-to-int mapping from label name to index.
  3767. """
  3768. if self.task not in {"Detection", "Panoptic"}:
  3769. raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.")
  3770. class_index = CocoOp.get_class_indexing(self.dataset_dir, self.annotation_file, self.task)
  3771. return dict(class_index)
  3772. def is_shuffled(self):
  3773. if self.shuffle_level is None:
  3774. return True
  3775. return self.shuffle_level or self.sampler.is_shuffled()
  3776. def is_sharded(self):
  3777. if self.num_shards is not None:
  3778. return self.num_shards > 1
  3779. return self.sampler.is_sharded()
  3780. class CelebADataset(MappableDataset):
  3781. """
3782. A source dataset for reading and parsing the CelebA dataset. Currently only list_attr_celeba.txt is supported.
  3783. Note:
  3784. The generated dataset has two columns ['image', 'attr'].
3785. The type of the image tensor is uint8. The attr tensor is uint32 and one-hot encoded.
  3786. Citation of CelebA dataset.
  3787. .. code-block::
  3788. @article{DBLP:journals/corr/LiuLWT14,
  3789. author = {Ziwei Liu and Ping Luo and Xiaogang Wang and Xiaoou Tang},
  3790. title = {Deep Learning Face Attributes in the Wild},
  3791. journal = {CoRR},
  3792. volume = {abs/1411.7766},
  3793. year = {2014},
  3794. url = {http://arxiv.org/abs/1411.7766},
  3795. archivePrefix = {arXiv},
  3796. eprint = {1411.7766},
  3797. timestamp = {Tue, 10 Dec 2019 15:37:26 +0100},
  3798. biburl = {https://dblp.org/rec/journals/corr/LiuLWT14.bib},
  3799. bibsource = {dblp computer science bibliography, https://dblp.org},
  3800. howpublished = {http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html},
  3801. description = {CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset
  3802. with more than 200K celebrity images, each with 40 attribute annotations.
  3803. The images in this dataset cover large pose variations and background clutter.
  3804. CelebA has large diversities, large quantities, and rich annotations, including
  3805. * 10,177 number of identities,
  3806. * 202,599 number of face images, and
  3807. * 5 landmark locations, 40 binary attributes annotations per image.
  3808. The dataset can be employed as the training and test sets for the following computer
  3809. vision tasks: face attribute recognition, face detection, landmark (or facial part)
  3810. localization, and face editing & synthesis.}
  3811. }
  3812. Args:
  3813. dataset_dir (str): Path to the root directory that contains the dataset.
  3814. num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config).
  3815. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None).
3816. dataset_type (str): one of 'all', 'train', 'valid' or 'test' (default='all').
  3817. sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
  3818. decode (bool, optional): decode the images after reading (default=False).
  3819. extensions (list[str], optional): List of file extensions to be
  3820. included in the dataset (default=None).
  3821. num_samples (int, optional): The number of images to be included in the dataset.
  3822. (default=None, all images).
  3823. num_shards (int, optional): Number of shards that the dataset should be divided
  3824. into (default=None).
  3825. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3826. argument should be specified only when num_shards is also specified.
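Examples:
An illustrative sketch; "/path/to/celeba_dataset_directory" is a placeholder path.
>>> import mindspore.dataset as ds
>>> dataset_dir = "/path/to/celeba_dataset_directory"
>>> # read 5 samples from the 'train' part of CelebA
>>> celeba_dataset = ds.CelebADataset(dataset_dir, dataset_type='train', num_samples=5)
>>> # each dictionary has keys "image" and "attr"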
  3827. """
  3828. @check_celebadataset
  3829. def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, dataset_type='all',
  3830. sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None):
  3831. super().__init__(num_parallel_workers)
  3832. self.dataset_dir = dataset_dir
  3833. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3834. self.num_parallel_workers = num_parallel_workers
  3835. self.decode = decode
  3836. self.extensions = extensions
  3837. self.num_samples = num_samples
  3838. self.dataset_type = dataset_type
  3839. self.num_shards = num_shards
  3840. self.shard_id = shard_id
  3841. self.shuffle_level = shuffle
  3842. def get_args(self):
  3843. args = super().get_args()
  3844. args["dataset_dir"] = self.dataset_dir
  3845. args["sampler"] = self.sampler
  3846. args["shuffle"] = self.shuffle_level
  3847. args["decode"] = self.decode
  3848. args["extensions"] = self.extensions
  3849. args["num_samples"] = self.num_samples
  3850. args["dataset_type"] = self.dataset_type
  3851. args["num_shards"] = self.num_shards
  3852. args["shard_id"] = self.shard_id
  3853. return args
  3854. def get_dataset_size(self):
  3855. """
  3856. Get the number of batches in an epoch.
  3857. Return:
  3858. Number, number of batches.
  3859. """
3860. if self._dataset_size is None:
3861. dataset_dir = os.path.realpath(self.dataset_dir)
3862. attr_file = os.path.join(dataset_dir, "list_attr_celeba.txt")
3863. num_rows = 0
3864. try:
3865. with open(attr_file, 'r') as f:
3866. num_rows = int(f.readline())  # the first line of list_attr_celeba.txt stores the number of images
3867. except FileNotFoundError:
3868. raise RuntimeError("attr_file not found.")
3869. except BaseException:
3870. raise RuntimeError("Failed to get the dataset size from the attribute file.")
  3871. rows_per_shard = get_num_rows(num_rows, self.num_shards)
  3872. if self.num_samples is not None:
  3873. rows_per_shard = min(self.num_samples, rows_per_shard)
  3874. rows_from_sampler = self._get_sampler_dataset_size()
  3875. if rows_from_sampler is None:
  3876. return rows_per_shard
  3877. return min(rows_from_sampler, rows_per_shard)
  3878. return self._dataset_size
  3879. def is_shuffled(self):
  3880. if self.shuffle_level is None:
  3881. return True
  3882. return self.shuffle_level or self.sampler.is_shuffled()
  3883. def is_sharded(self):
  3884. if self.num_shards is not None:
  3885. return self.num_shards > 1
  3886. return self.sampler.is_sharded()
  3887. class CLUEDataset(SourceDataset):
  3888. """
  3889. A source dataset that reads and parses CLUE datasets.
3890. CLUE, the Chinese Language Understanding Evaluation Benchmark, is a collection of datasets, baselines, pre-trained
3891. models, corpora and a leaderboard. Supported here are the CLUE classification tasks: AFQMC, TNEWS, IFLYTEK,
3892. CMNLI, WSC and CSL.
  3893. Citation of CLUE dataset.
  3894. .. code-block::
  3895. @article{CLUEbenchmark,
  3896. title = {CLUE: A Chinese Language Understanding Evaluation Benchmark},
  3897. author = {Liang Xu, Xuanwei Zhang, Lu Li, Hai Hu, Chenjie Cao, Weitang Liu, Junyi Li, Yudong Li,
  3898. Kai Sun, Yechen Xu, Yiming Cui, Cong Yu, Qianqian Dong, Yin Tian, Dian Yu, Bo Shi, Jun Zeng,
  3899. Rongzhao Wang, Weijian Xie, Yanting Li, Yina Patterson, Zuoyu Tian, Yiwen Zhang, He Zhou,
  3900. Shaoweihua Liu, Qipeng Zhao, Cong Yue, Xinrui Zhang, Zhengliang Yang, Zhenzhong Lan},
  3901. journal = {arXiv preprint arXiv:2004.05986},
  3902. year = {2020},
  3903. howpublished = {https://github.com/CLUEbenchmark/CLUE},
  3904. description = {CLUE, a Chinese Language Understanding Evaluation benchmark. It contains eight different
  3905. tasks, including single-sentence classification, sentence pair classification, and machine
  3906. reading comprehension.}
  3907. }
  3908. Args:
  3909. dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
  3910. files. The list will be sorted in a lexicographical order.
3911. task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'
3912. (default='AFQMC').
3913. usage (str, optional): Which part of the dataset to read, one of 'train', 'test' or 'eval' (default='train').
  3914. num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
  3915. num_parallel_workers (int, optional): number of workers to read the data
  3916. (default=None, number set in the config).
  3917. shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
  3918. If shuffle is False, no shuffling will be performed;
  3919. If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
  3920. Otherwise, there are two levels of shuffling:
  3921. - Shuffle.GLOBAL: Shuffle both the files and samples.
  3922. - Shuffle.FILES: Shuffle files only.
  3923. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
  3924. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3925. argument should be specified only when num_shards is also specified.
  3926. Examples:
  3927. >>> import mindspore.dataset as ds
  3928. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
  3929. >>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train')
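>>> # (illustrative sketch) shuffle the file order only, keeping rows within each file in order:
>>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train', shuffle=ds.Shuffle.FILES)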
  3930. """
  3931. @check_cluedataset
  3932. def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
  3933. num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
  3934. super().__init__(num_parallel_workers)
  3935. self.dataset_files = self._find_files(dataset_files)
  3936. self.dataset_files.sort()
  3937. self.num_samples = num_samples
  3938. self.task_dict = {
  3939. 'AFQMC': {
  3940. 'train': {
  3941. 'sentence1': 'sentence1',
  3942. 'sentence2': 'sentence2',
  3943. 'label': 'label'
  3944. },
  3945. 'test': {
  3946. 'id': 'id',
  3947. 'sentence1': 'sentence1',
  3948. 'sentence2': 'sentence2'
  3949. },
  3950. 'eval': {
  3951. 'sentence1': 'sentence1',
  3952. 'sentence2': 'sentence2',
  3953. 'label': 'label'
  3954. }
  3955. },
  3956. 'CMNLI': {
  3957. 'train': {
  3958. 'sentence1': 'sentence1',
  3959. 'sentence2': 'sentence2',
  3960. 'label': 'label'
  3961. },
  3962. 'test': {
  3963. 'id': 'id',
  3964. 'sentence1': 'sentence1',
  3965. 'sentence2': 'sentence2'
  3966. },
  3967. 'eval': {
  3968. 'sentence1': 'sentence1',
  3969. 'sentence2': 'sentence2',
  3970. 'label': 'label'
  3971. }
  3972. },
  3973. 'CSL': {
  3974. 'train': {
  3975. 'id': 'id',
  3976. 'abst': 'abst',
  3977. 'keyword': 'keyword',
  3978. 'label': 'label'
  3979. },
  3980. 'test': {
  3981. 'id': 'id',
  3982. 'abst': 'abst',
  3983. 'keyword': 'keyword'
  3984. },
  3985. 'eval': {
  3986. 'id': 'id',
  3987. 'abst': 'abst',
  3988. 'keyword': 'keyword',
  3989. 'label': 'label'
  3990. }
  3991. },
  3992. 'IFLYTEK': {
  3993. 'train': {
  3994. 'label': 'label',
  3995. 'label_des': 'label_des',
  3996. 'sentence': 'sentence'
  3997. },
  3998. 'test': {
  3999. 'id': 'id',
  4000. 'sentence': 'sentence',
  4001. },
  4002. 'eval': {
  4003. 'label': 'label',
  4004. 'label_des': 'label_des',
  4005. 'sentence': 'sentence'
  4006. }
  4007. },
  4008. 'TNEWS': {
  4009. 'train': {
  4010. 'label': 'label',
  4011. 'label_desc': 'label_desc',
  4012. 'sentence': 'sentence',
  4013. 'keywords': 'keywords'
  4014. },
  4015. 'test': {
  4016. 'id': 'id',
  4017. 'sentence': 'sentence',
  4018. 'keywords': 'keywords'
  4019. },
  4020. 'eval': {
  4021. 'label': 'label',
  4022. 'label_desc': 'label_desc',
  4023. 'sentence': 'sentence',
  4024. 'keywords': 'keywords'
  4025. }
  4026. },
  4027. 'WSC': {
  4028. 'train': {
  4029. 'span1_index': 'target/span1_index',
  4030. 'span2_index': 'target/span2_index',
  4031. 'span1_text': 'target/span1_text',
  4032. 'span2_text': 'target/span2_text',
  4033. 'idx': 'idx',
  4034. 'label': 'label',
  4035. 'text': 'text'
  4036. },
  4037. 'test': {
  4038. 'span1_index': 'target/span1_index',
  4039. 'span2_index': 'target/span2_index',
  4040. 'span1_text': 'target/span1_text',
  4041. 'span2_text': 'target/span2_text',
  4042. 'idx': 'idx',
  4043. 'text': 'text'
  4044. },
  4045. 'eval': {
  4046. 'span1_index': 'target/span1_index',
  4047. 'span2_index': 'target/span2_index',
  4048. 'span1_text': 'target/span1_text',
  4049. 'span2_text': 'target/span2_text',
  4050. 'idx': 'idx',
  4051. 'label': 'label',
  4052. 'text': 'text'
  4053. }
  4054. }
  4055. }
  4056. self.cols_to_keyword = self.task_dict[task][usage]
  4057. if not isinstance(shuffle, (bool, Shuffle)):
  4058. raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
  4059. if not isinstance(shuffle, Shuffle):
  4060. if shuffle:
  4061. self.shuffle_level = Shuffle.GLOBAL
  4062. self.shuffle_files = True
  4063. else:
  4064. self.shuffle_level = None
  4065. self.shuffle_files = False
  4066. else:
  4067. self.shuffle_level = shuffle
  4068. self.shuffle_files = True
  4069. self.num_shards = num_shards
  4070. self.shard_id = shard_id
  4071. def get_args(self):
  4072. args = super().get_args()
  4073. args["dataset_files"] = self.dataset_files
  4074. args["num_samples"] = self.num_samples
  4075. if self.shuffle_files is not None:
  4076. args["shuffle_files"] = self.shuffle_files
  4077. args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
  4078. args["shuffle"] = self.shuffle_level
  4079. args["num_shards"] = self.num_shards
  4080. args["shard_id"] = self.shard_id
  4081. args["cols_to_keyword"] = self.cols_to_keyword
  4082. return args
  4083. def get_dataset_size(self):
  4084. """
  4085. Get the number of batches in an epoch.
  4086. Return:
  4087. Number, number of batches.
  4088. """
  4089. if self._dataset_size is None:
  4090. num_rows = ClueOp.get_num_rows(self.dataset_files)
  4091. num_rows = get_num_rows(num_rows, self.num_shards)
  4092. if self.num_samples is None:
  4093. return num_rows
  4094. return min(self.num_samples, num_rows)
  4095. return self._dataset_size
  4096. def is_shuffled(self):
  4097. return self.shuffle_files
  4098. def is_sharded(self):
  4099. if self.num_shards is not None:
  4100. return self.num_shards > 1
  4101. return False
  4102. class TextFileDataset(SourceDataset):
  4103. """
  4104. A source dataset that reads and parses datasets stored on disk in text format.
4105. The generated dataset has one column ['text'].
  4106. Args:
  4107. dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
  4108. files. The list will be sorted in a lexicographical order.
  4109. num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
  4110. num_parallel_workers (int, optional): number of workers to read the data
  4111. (default=None, number set in the config).
  4112. shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
  4113. If shuffle is False, no shuffling will be performed;
  4114. If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
  4115. Otherwise, there are two levels of shuffling:
  4116. - Shuffle.GLOBAL: Shuffle both the files and samples.
  4117. - Shuffle.FILES: Shuffle files only.
  4118. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
  4119. shard_id (int, optional): The shard ID within num_shards (default=None). This
  4120. argument should be specified only when num_shards is also specified.
  4121. Examples:
  4122. >>> import mindspore.dataset as ds
  4123. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
  4124. >>> dataset = ds.TextFileDataset(dataset_files=dataset_files)
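>>> # (illustrative sketch) read only the first 5 rows without any shuffling:
>>> dataset = ds.TextFileDataset(dataset_files=dataset_files, num_samples=5, shuffle=False)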
  4125. """
  4126. @check_textfiledataset
  4127. def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None,
  4128. shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
  4129. super().__init__(num_parallel_workers)
  4130. self.dataset_files = self._find_files(dataset_files)
  4131. self.dataset_files.sort()
  4132. self.num_samples = num_samples
  4133. if not isinstance(shuffle, (bool, Shuffle)):
  4134. raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
  4135. if not isinstance(shuffle, Shuffle):
  4136. if shuffle:
  4137. self.shuffle_level = Shuffle.GLOBAL
  4138. self.shuffle_files = True
  4139. else:
  4140. self.shuffle_level = None
  4141. self.shuffle_files = False
  4142. else:
  4143. self.shuffle_level = shuffle
  4144. self.shuffle_files = True
  4145. self.num_shards = num_shards
  4146. self.shard_id = shard_id
  4147. def get_args(self):
  4148. args = super().get_args()
  4149. args["dataset_files"] = self.dataset_files
  4150. args["num_samples"] = self.num_samples
  4151. if self.shuffle_files is not None:
  4152. args["shuffle_files"] = self.shuffle_files
  4153. args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
  4154. args["shuffle"] = self.shuffle_level
  4155. args["num_shards"] = self.num_shards
  4156. args["shard_id"] = self.shard_id
  4157. return args
  4158. def get_dataset_size(self):
  4159. """
  4160. Get the number of batches in an epoch.
  4161. Return:
  4162. Number, number of batches.
  4163. """
  4164. if self._dataset_size is None:
  4165. num_rows = TextFileOp.get_num_rows(self.dataset_files)
  4166. num_rows = get_num_rows(num_rows, self.num_shards)
  4167. # If the user gave a num samples in the dataset, then the sampler will limit the rows returned
  4168. # to that amount. Account for that here in the row count
  4169. if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples:
  4170. num_rows = self.num_samples
  4171. return num_rows
  4172. return self._dataset_size
  4173. def is_shuffled(self):
  4174. return self.shuffle_files
  4175. def is_sharded(self):
  4176. if self.num_shards is not None:
  4177. return self.num_shards > 1
  4178. return False
  4179. class _NumpySlicesDataset:
  4180. """
4181. Mainly for dealing with several kinds of formats of Python data, returning one row at a time.
  4182. """
  4183. def __init__(self, data, column_list=None):
  4184. self.column_list = None
  4185. # Convert dict data into tuple
  4186. if isinstance(data, dict):
  4187. data = self.process_dict(data)
  4188. if isinstance(data, tuple):
  4189. self.data = ()
  4190. data_len = len(data)
  4191. for i in range(data_len):
  4192. self.data = self.data + (np.array(data[i]),)
  4193. else:
  4194. self.data = (np.array(data),)
  4195. # check whether the data length in each column is equal
  4196. data_len = [len(data_item) for data_item in self.data]
  4197. if data_len[1:] != data_len[:-1]:
  4198. raise ValueError("Data length in each column is not equal.")
  4199. # Init column_name
  4200. if column_list is not None:
  4201. self.column_list = column_list
  4202. elif self.column_list is None:
  4203. self.column_list = []
  4204. column_num = len(self.data)
  4205. for i in range(column_num):
  4206. self.column_list.append("column_" + str(i))
  4207. def __getitem__(self, index):
  4208. data_row = [d[index, ...] for d in self.data]
  4209. data_res = tuple(data_row)
  4210. return data_res
  4211. def __len__(self):
  4212. return len(self.data[0])
  4213. def process_dict(self, input_data):
  4214. """
4215. Convert dict-like data into tuple format; when the input is a tuple of dicts, compose it into a single dict first.
  4216. """
  4217. # Convert pandas like dict(has "values" column) into General dict
  4218. data_keys = list(input_data.keys())
  4219. data_col = input_data[data_keys[0]]
  4220. if hasattr(data_col, "values"):
  4221. new_dict = {}
  4222. for key in data_keys:
  4223. item1 = input_data.pop(key)
  4224. new_dict[key] = item1.values
  4225. input_data = new_dict
  4226. # Convert the data in dict into tuple
  4227. data = ()
  4228. keys = list(input_data.keys())
  4229. self.column_list = keys
  4230. for key in keys:
  4231. value = input_data[key]
  4232. data = data + (list(value),)
  4233. return data
  4234. class NumpySlicesDataset(GeneratorDataset):
  4235. """
  4236. Create a dataset with given data slices, mainly for loading python data into dataset.
  4237. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  4238. below shows what input args are allowed and their expected behavior.
  4239. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  4240. :widths: 25 25 50
  4241. :header-rows: 1
  4242. * - Parameter 'sampler'
  4243. - Parameter 'shuffle'
  4244. - Expected Order Behavior
  4245. * - None
  4246. - None
  4247. - random order
  4248. * - None
  4249. - True
  4250. - random order
  4251. * - None
  4252. - False
  4253. - sequential order
  4254. * - Sampler object
  4255. - None
  4256. - order defined by sampler
  4257. * - Sampler object
  4258. - True
  4259. - not allowed
  4260. * - Sampler object
  4261. - False
  4262. - not allowed
  4263. Args:
4264. data (list, tuple or dict): Input data. Supported data types include list, tuple, dict and other NumPy
4265. formats. Input data will be sliced along the first dimension to generate many rows. Loading large data
4266. this way is not recommended because the data is loaded into memory.
4267. column_names (list[str], optional): List of column names of the dataset (default=None). If column_names is not
4268. provided and data is a dict, column_names will be its keys; otherwise columns are named column_0, column_1, ...
4269. num_samples (int, optional): The number of samples to be included in the dataset (default=None, all samples).
  4270. num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
  4271. shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
  4272. (default=None, expected order behavior shown in the table).
  4273. sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
  4274. required (default=None, expected order behavior shown in the table).
  4275. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
4276. When this argument is specified, 'num_samples' will have no effect. Random accessible input is required.
  4277. shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
  4278. when num_shards is also specified. Random accessible input is required.
  4279. Examples:
  4280. >>> import mindspore.dataset as ds
  4281. >>> # 1) Input data can be a list
  4282. >>> data = [1, 2, 3]
  4283. >>> dataset1 = ds.NumpySlicesDataset(data, column_names=["column_1"])
  4284. >>> # 2) Input data can be a dict, and column_names will be its key
  4285. >>> data = {"a": [1, 2], "b": [3, 4]}
  4286. >>> dataset2 = ds.NumpySlicesDataset(data)
  4287. >>> # 3) Input data can be a tuple of lists (or numpy arrays), each tuple element refers to data in each column
  4288. >>> data = ([1, 2], [3, 4], [5, 6])
  4289. >>> dataset3 = ds.NumpySlicesDataset(data, column_names=["column_1", "column_2", "column_3"])
  4290. >>> # 4) Load data from csv file
  4291. >>> import pandas as pd
  4292. >>> df = pd.read_csv("file.csv")
  4293. >>> dataset4 = ds.NumpySlicesDataset(dict(df), shuffle=False)
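>>> # 5) (illustrative sketch) iterate over the rows of dataset3 defined above:
>>> for item in dataset3.create_dict_iterator():
...     print(item["column_1"])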
  4294. """
  4295. @check_numpyslicesdataset
  4296. def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None,
  4297. sampler=None, num_shards=None, shard_id=None):
  4298. dataset = _NumpySlicesDataset(data, column_names)
  4299. super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
  4300. num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
  4301. num_shards=num_shards, shard_id=shard_id)
  4302. class BuildVocabDataset(DatasetOp):
  4303. """
4304. Build a vocab from a dataset. This collects all the unique words in the dataset and returns a vocab
4305. which contains the top_k most frequent words (if top_k is specified).
4306. This class is not meant to be called directly by the user. To build a vocab, please use the function
4307. text.Vocab.from_dataset().
  4308. Args:
4309. vocab (Vocab): text.Vocab object.
4310. columns (str or list, optional): Column names to get words from. It can be a list of column names (default is
4311. None, in which case all columns are used; an error is returned if any column is not of string type).
4312. freq_range (tuple, optional): A tuple of integers (min_frequency, max_frequency). Words within the frequency
4313. range will be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
4314. can be None, which corresponds to 0/total_words respectively (default=None, all words are included).
4315. top_k (int, optional): top_k > 0. Number of words to be built into the vocab; the top_k most frequent words are
4316. taken. top_k is applied after freq_range. If there are fewer than top_k words, all words are taken (default=None,
4317. all words are included).
4318. special_tokens (list, optional): A list of strings, each one a special token, for example
4319. special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
4320. special_first (bool, optional): Whether special_tokens will be prepended (or appended) to the vocab. If special_tokens
4321. is specified and special_first is set to None, special_tokens will be prepended (default=None).
  4322. prefetch_size (int, optional): prefetch number of records ahead of the user's request (default=None).
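Examples:
An illustrative sketch of the recommended path via text.Vocab.from_dataset; the text file path is a placeholder.
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.text as text
>>> data = ds.TextFileDataset("/path/to/1", shuffle=False)
>>> vocab = text.Vocab.from_dataset(data, columns=["text"])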
  4323. """
  4324. def __init__(self, input_dataset, vocab, columns, freq_range, top_k, special_tokens, special_first,
  4325. prefetch_size=None):
  4326. super().__init__()
  4327. self.columns = columns
  4328. self.children.append(input_dataset)
  4329. self.prefetch_size = prefetch_size
  4330. self.vocab = vocab
  4331. self.freq_range = freq_range
  4332. self.top_k = top_k
  4333. self.special_tokens = special_tokens
  4334. self.special_first = special_first
  4335. input_dataset.parent.append(self)
  4336. def get_args(self):
  4337. args = super().get_args()
  4338. args["columns"] = self.columns
  4339. args["vocab"] = self.vocab
  4340. args["freq_range"] = self.freq_range
  4341. args["prefetch_size"] = self.prefetch_size
  4342. args["top_k"] = self.top_k
  4343. args["special_tokens"] = self.special_tokens
  4344. args["special_first"] = self.special_first
  4345. return args
  4346. def __deepcopy__(self, memodict):
  4347. if id(self) in memodict:
  4348. return memodict[id(self)]
  4349. cls = self.__class__
  4350. new_op = cls.__new__(cls)
  4351. memodict[id(self)] = new_op
  4352. new_op.children = copy.deepcopy(self.children, memodict)
  4353. new_op.columns = copy.deepcopy(self.columns, memodict)
  4354. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  4355. new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
  4356. new_op.parent = copy.deepcopy(self.parent, memodict)
  4357. new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
  4358. new_op.top_k = copy.deepcopy(self.top_k, memodict)
  4359. new_op.vocab = self.vocab
  4360. new_op.special_tokens = copy.deepcopy(self.special_tokens)
  4361. new_op.special_first = copy.deepcopy(self.special_first)
  4362. return new_op