datasets.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module supports various formats of datasets, including ImageNet, TFData,
MNIST, Cifar10/100, Manifest, MindRecord, and more. This module loads data with
high performance and parses data precisely. Some of the operations that are
provided to users to preprocess data include shuffle, batch, repeat, map, and zip.
"""
import glob
import json
import math
import os
import uuid
import multiprocessing
import queue
from enum import Enum
from importlib import import_module
import threading
import copy

import numpy as np

from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
    MindRecordOp, TextFileOp, ClueOp, CsvOp, VOCOp, CocoOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
import mindspore.dataset.transforms.py_transforms as py_transforms

from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
    check_rename, check_numpyslicesdataset, check_device_send, \
    check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
    check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
    check_paddeddataset, check_iterator
from ..core.config import get_callback_timeout
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE

try:
    context = import_module("mindspore.context")
except ModuleNotFoundError:
    context = None
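
# A minimal usage sketch (commented out) of the pipeline operations described in the
# module docstring, assuming a placeholder ImageFolder path and an "image" column; the
# path, buffer size, and batch size below are illustrative values, not values defined
# in this module:
#
#   import mindspore.dataset as ds
#   data = ds.ImageFolderDataset("path/to/imagefolder_directory", num_parallel_workers=8)
#   data = data.shuffle(buffer_size=4)
#   data = data.map(operations=[(lambda x: x)], input_columns=["image"])
#   data = data.batch(batch_size=32, drop_remainder=True)
#   data = data.repeat(count=10)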


class Shuffle(str, Enum):
    GLOBAL: str = "global"
    FILES: str = "file"


@check_zip
def zip(datasets):
    """
    Zip the datasets in the input tuple of datasets.

    Args:
        datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
            The number of datasets must be more than 1.

    Returns:
        DatasetOp, ZipDataset.

    Raises:
        ValueError: If the number of datasets is 1.
        TypeError: If datasets is not a tuple.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> dataset_dir1 = "path/to/imagefolder_directory1"
        >>> dataset_dir2 = "path/to/imagefolder_directory2"
        >>> ds1 = ds.ImageFolderDataset(dataset_dir1, num_parallel_workers=8)
        >>> ds2 = ds.ImageFolderDataset(dataset_dir2, num_parallel_workers=8)
        >>>
        >>> # Create a dataset which is the combination of ds1 and ds2
        >>> data = ds.zip((ds1, ds2))
    """
    if len(datasets) <= 1:
        raise ValueError(
            "Can't zip empty or just one dataset!")
    return ZipDataset(datasets)


def get_num_rows(num_rows, num_shards):
    """
    Get the number of rows of the dataset according to the shards.

    Args:
        num_rows (int): Number of rows of the dataset. It must be more than 0.
        num_shards (int or None): Number of shards that the dataset will be divided into.
            The number of shards must be None or more than 1.

    Returns:
        Int, number of rows.

    Raises:
        ValueError: If num_rows is invalid (< 0).
        ValueError: If num_shards is invalid (<= 0).
    """
    if num_rows < 0:
        raise ValueError("num_rows is invalid (< 0)")

    if num_shards is not None:
        if num_shards <= 0:
            raise ValueError("num_shards is invalid (<= 0)")
        if num_rows % num_shards == 0:
            num_rows = num_rows // num_shards
        else:
            num_rows = num_rows // num_shards + 1
    return num_rows
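
# Worked example (assumed values): get_num_rows(10, 4) takes the branch where
# 10 % 4 != 0, so each shard gets 10 // 4 + 1 = 3 rows.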


class Dataset:
    """
    Abstract class to represent a dataset in DataEngine's data pipeline.

    This class is the base class of SourceDataset and DatasetOp, and represents
    a node in the data flow graph.

    Args:
        num_parallel_workers (int, optional): Number of workers to process the dataset in parallel
            (default=None).
    """

    def __init__(self, num_parallel_workers=None):
        # Note: children and parent are internal variables, not recommended for external use.
        self.children = []
        self.parent = []
        self.num_parallel_workers = num_parallel_workers
        self._device_iter = 0
        self._input_indexs = ()
        self._output_types = None
        self._output_shapes = None
        self.dataset_size = None
        self._batch_size = None
        self._num_classes = None
        self._repeat_count = None
        self._sync = False

    def _noop_mode(self):
        if _is_role_sched() or _is_role_pserver():
            return True
        return False

    def __add__(self, datasets):
        return self.concat(datasets)

    def get_args(self):
        """
        Return attributes (member variables) related to the current class.

        Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.

        Returns:
            Python dictionary.
        """
        args = dict()
        args["num_parallel_workers"] = self.num_parallel_workers
        return args
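
    # Example (assumed): for a node constructed with num_parallel_workers=4, get_args()
    # returns {"num_parallel_workers": 4}; subclasses are expected to extend this
    # dictionary with their own constructor arguments.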

    @check_bucket_batch_by_length
    def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes,
                               element_length_function=None, pad_info=None,
                               pad_to_bucket_boundary=False, drop_remainder=False):
        """
        Bucket elements according to their lengths. Each bucket will be padded and batched when
        it is full.

        A length function is called on each row in the dataset. The row is then
        bucketed based on its length and bucket_boundaries. When a bucket reaches its
        corresponding size specified in bucket_batch_sizes, the entire bucket will be
        padded according to batch_info, and then batched. Each batch will be full,
        except for maybe the last batch for each bucket.

        Args:
            column_names (list[str]): Columns passed to element_length_function.
            bucket_boundaries (list[int]): A list consisting of the upper boundaries
                of the buckets. Must be strictly increasing. If there are n boundaries,
                n+1 buckets are created: One bucket for [0, bucket_boundaries[0]), one
                bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
                0<i<n, and one bucket for [bucket_boundaries[n-1], inf).
            bucket_batch_sizes (list[int]): A list consisting of the batch sizes for
                each bucket. Must contain len(bucket_boundaries)+1 elements.
            element_length_function (Callable, optional): A function that takes in
                len(column_names) arguments and returns an int. If no value is
                provided, then len(column_names) must be 1, and the size of the first
                dimension of that column will be taken as the length (default=None).
            pad_info (dict, optional): Represents how to batch each column. The key
                corresponds to the column name, and the value must be a tuple of 2 elements.
                The first element corresponds to the shape to pad to, and the second
                element corresponds to the value to pad with. If a column is not
                specified, then that column will be padded to the longest in the current
                batch, and 0 will be used as the padding value. Any None dimensions will
                be padded to the longest in the current batch, unless
                pad_to_bucket_boundary is True. If no padding is wanted, set pad_info
                to None (default=None).
            pad_to_bucket_boundary (bool, optional): If True, will pad each None
                dimension in pad_info to the bucket_boundary minus 1. If there are any
                elements that fall into the last bucket, an error will occur
                (default=False).
            drop_remainder (bool, optional): If True, will drop the last batch for each
                bucket if it is not a full batch (default=False).

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> # data is an instance of Dataset object.
            >>>
            >>> # Create buckets with boundaries [5, 10) and pad to the bucket boundary.
            >>> column_names = ["col1", "col2"]
            >>> bucket_boundaries = [5, 10]
            >>> bucket_batch_sizes = [5, 1, 1]
            >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
            >>>
            >>> # Will pad col1 to shape [2, bucket_boundaries[i]] where i is the
            >>> # index of the bucket that is currently being batched.
            >>> # Will pad col2 to a shape where each dimension is the longest in all
            >>> # the elements currently being batched.
            >>> pad_info = {"col1": ([2, None], -1)}
            >>> pad_to_bucket_boundary = True
            >>>
            >>> data = data.bucket_batch_by_length(column_names, bucket_boundaries,
            >>>                                    bucket_batch_sizes,
            >>>                                    element_length_function, pad_info,
            >>>                                    pad_to_bucket_boundary)
        """
        return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes,
                                          element_length_function, pad_info,
                                          pad_to_bucket_boundary, drop_remainder)

    @check_batch
    def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
              input_columns=None, output_columns=None, column_order=None, pad_info=None):
        """
        Combine batch_size number of consecutive rows into batches.

        For any child node, a batch is treated as a single row.
        For any column, all the elements within that column must have the same shape.
        If a per_batch_map callable is provided, it will be applied to the batches of tensors.

        Note:
            The order of using repeat and batch reflects the number of batches and per_batch_map.
            It is recommended that the repeat operation be used after the batch operation.

        Args:
            batch_size (int or function): The number of rows each batch is created with. An
                int or callable which takes exactly 1 parameter, BatchInfo.
            drop_remainder (bool, optional): Determines whether or not to drop the last
                possibly incomplete batch (default=False). If True, and if there are less
                than batch_size rows available to make the last batch, then those rows will
                be dropped and not propagated to the child node.
            num_parallel_workers (int, optional): Number of workers to process the dataset in parallel (default=None).
            per_batch_map (callable, optional): Per batch map callable. A callable which takes
                (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch
                of Tensors on a given column. The number of lists should match with number of entries in input_columns.
                The last parameter of the callable should always be a BatchInfo object.
            input_columns (list[str], optional): List of names of the input columns. The size of the list should
                match with signature of per_batch_map callable.
            output_columns (list[str], optional): List of names assigned to the columns
                outputted by the last operation. This parameter is mandatory if len(input_columns) !=
                len(output_columns). The size of this list must match the number of output
                columns of the last operation. (default=None, output columns will have the same
                name as the input columns, i.e., the columns will be replaced).
            column_order (list[str], optional): List of all the desired columns to propagate to
                the child node. This list must be a subset of all the columns in the dataset after
                all operations are applied. The order of the columns in each row propagated to the
                child node follow the order they appear in this list. The parameter is mandatory
                if the len(input_columns) != len(output_columns). (default=None, all columns
                will be propagated to the child node, the order of the columns will remain the
                same).
            pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1": ([224, 224], 0)}
                would pad column with name "col1" to a tensor of size [224, 224] and fill the missing with 0.

        Returns:
            BatchDataset, dataset batched.

        Examples:
            >>> import numpy as np
            >>> from PIL import Image
            >>> import mindspore.dataset as ds
            >>>
            >>> # data is an instance of Dataset object.
            >>>
            >>> # Create a dataset where every 100 rows is combined into a batch
            >>> # and drops the last incomplete batch if there is one.
            >>> data = data.batch(100, True)
            >>>
            >>> # Resize images according to their batch number; for the 5-th batch, resize to (5^2, 5^2) = (25, 25).
            >>> def np_resize(col, batchInfo):
            >>>     output = col.copy()
            >>>     s = (batchInfo.get_batch_num() + 1) ** 2
            >>>     index = 0
            >>>     for c in col:
            >>>         img = Image.fromarray(c.astype('uint8')).convert('RGB')
            >>>         img = img.resize((s, s), Image.ANTIALIAS)
            >>>         output[index] = np.array(img)
            >>>         index += 1
            >>>     return (output,)
            >>> data = data.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
        """
        return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
                            output_columns, column_order, pad_info)

    @check_sync_wait
    def sync_wait(self, condition_name, num_batch=1, callback=None):
        """
        Add a blocking condition to the input Dataset.

        Args:
            num_batch (int): The number of batches without blocking at the start of each epoch.
            condition_name (str): The condition name that is used to toggle sending next row.
            callback (function): The callback function that will be invoked when sync_update is called.

        Raises:
            RuntimeError: If condition name already exists.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> # data is an instance of Dataset object.
            >>> data = data.sync_wait("callback1")
            >>> data = data.batch(batch_size)
            >>> for batch_data in data.create_dict_iterator():
            >>>     data = data.sync_update("callback1")
        """
        return SyncWaitDataset(self, condition_name, num_batch, callback)

    @check_shuffle
    def shuffle(self, buffer_size):
        """
        Randomly shuffles the rows of this dataset using the following algorithm:

        1. Make a shuffle buffer that contains the first buffer_size rows.
        2. Randomly select an element from the shuffle buffer to be the next row
           propagated to the child node.
        3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
        4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.

        A seed can be provided to be used on the first epoch. In every subsequent
        epoch, the seed is changed to a new one, randomly generated value.

        Args:
            buffer_size (int): The size of the buffer (must be larger than 1) for
                shuffling. Setting buffer_size equal to the number of rows in the entire
                dataset will result in a global shuffle.

        Returns:
            ShuffleDataset, dataset shuffled.

        Raises:
            RuntimeError: If a sync operator exists before the shuffle operation.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> # data is an instance of Dataset object.
            >>> # Optionally set the seed for the first epoch
            >>> ds.config.set_seed(58)
            >>>
            >>> # Create a shuffled dataset using a shuffle buffer of size 4
            >>> data = data.shuffle(4)
        """
        return ShuffleDataset(self, buffer_size)

    def flat_map(self, func):
        """
        Map `func` to each row in dataset and flatten the result.

        The specified `func` is a function that must take one 'Ndarray' as input
        and return a 'Dataset'.

        Args:
            func (function): A function that must take one 'Ndarray' as an argument and
                return a 'Dataset'.

        Returns:
            Dataset, the dataset produced by applying the function to each row and flattening the result.

        Examples:
            >>> import mindspore.dataset as ds
            >>> import mindspore.dataset.text as text
            >>>
            >>> # Declare a function which returns a Dataset object
            >>> def flat_map_func(x):
            >>>     data_dir = text.to_str(x[0])
            >>>     d = ds.ImageFolderDataset(data_dir)
            >>>     return d
            >>> # data is an instance of a Dataset object.
            >>> data = ds.TextFileDataset(DATA_FILE)
            >>> data = data.flat_map(flat_map_func)

        Raises:
            TypeError: If `func` is not a function.
            TypeError: If `func` doesn't return a Dataset.
        """
        dataset = None
        if not hasattr(func, '__call__'):
            logger.error("func must be a function.")
            raise TypeError("func must be a function.")

        for row_data in self.create_tuple_iterator(output_numpy=True):
            if dataset is None:
                dataset = func(row_data)
            else:
                dataset += func(row_data)

        if not isinstance(dataset, Dataset):
            logger.error("flat_map must return a Dataset object.")
            raise TypeError("flat_map must return a Dataset object.")
        return dataset

    @check_map
    def map(self, operations, input_columns=None, output_columns=None, column_order=None,
            num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
        """
        Apply each operation in operations to this dataset.

        The order of operations is determined by the position of each operation in the operations parameter.
        operations[0] will be applied first, then operations[1], then operations[2], etc.

        Each operation will be passed one or more columns from the dataset as input, and zero or
        more columns will be outputted. The first operation will be passed the columns specified
        in input_columns as input. If there is more than one operator in operations, the outputted
        columns of the previous operation are used as the input columns for the next operation.
        The columns outputted by the very last operation will be assigned names specified by
        output_columns.

        Only the columns specified in column_order will be propagated to the child node. These
        columns will be in the same order as specified in column_order.

        Args:
            operations (Union[list[TensorOp], list[functions]]): List of operations to be
                applied on the dataset. Operations are applied in the order they appear in this list.
            input_columns (list[str], optional): List of the names of the columns that will be passed to
                the first operation as input. The size of this list must match the number of
                input columns expected by the first operator. (default=None, the first
                operation will be passed however many columns that is required, starting from
                the first column).
            output_columns (list[str], optional): List of names assigned to the columns outputted by
                the last operation. This parameter is mandatory if len(input_columns) !=
                len(output_columns). The size of this list must match the number of output
                columns of the last operation. (default=None, output columns will have the same
                name as the input columns, i.e., the columns will be replaced).
            column_order (list[str], optional): List of all the desired columns to propagate to the
                child node. This list must be a subset of all the columns in the dataset after
                all operations are applied. The order of the columns in each row propagated to the
                child node follow the order they appear in this list. The parameter is mandatory
                if the len(input_columns) != len(output_columns). (default=None, all columns
                will be propagated to the child node, the order of the columns will remain the
                same).
            num_parallel_workers (int, optional): Number of threads used to process the dataset in
                parallel (default=None, the value from the configuration will be used).
            python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
                option could be beneficial if the Python operation is computationally heavy (default=False).
            cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
                The cache feature is under development and is not recommended.
            callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (default=None).

        Returns:
            MapDataset, dataset after mapping operation.

        Examples:
            >>> import mindspore.dataset as ds
            >>> import mindspore.dataset.vision.c_transforms as c_transforms
            >>>
            >>> # data is an instance of Dataset which has 2 columns, "image" and "label".
            >>> # ds_pyfunc is an instance of Dataset which has 3 columns, "col0", "col1", and "col2".
            >>> # Each column is a 2D array of integers.
            >>>
            >>> # Set the global configuration value for num_parallel_workers to be 2.
            >>> # Operations which use this configuration value will use 2 worker threads,
            >>> # unless otherwise specified in the operator's constructor.
            >>> # set_num_parallel_workers can be called again later if a different
            >>> # global configuration value for the number of worker threads is desired.
            >>> ds.config.set_num_parallel_workers(2)
            >>>
            >>> # Define two operations, where each operation accepts 1 input column and outputs 1 column.
            >>> decode_op = c_transforms.Decode(rgb_format=True)
            >>> random_jitter_op = c_transforms.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
            >>>
            >>> # 1) Simple map example
            >>>
            >>> operations = [decode_op]
            >>> input_columns = ["image"]
            >>>
            >>> # Apply decode_op on column "image". This column will be replaced by the outputted
            >>> # column of decode_op. Since column_order is not provided, both columns "image"
            >>> # and "label" will be propagated to the child node in their original order.
            >>> ds_decoded = data.map(operations, input_columns)
            >>>
            >>> # Rename column "image" to "decoded_image".
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(operations, input_columns, output_columns)
            >>>
            >>> # Specify the order of the columns.
            >>> column_order = ["label", "image"]
            >>> ds_decoded = data.map(operations, input_columns, None, column_order)
            >>>
            >>> # Rename column "image" to "decoded_image" and also specify the order of the columns.
            >>> column_order = ["label", "decoded_image"]
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(operations, input_columns, output_columns, column_order)
            >>>
            >>> # Rename column "image" to "decoded_image" and keep only this column.
            >>> column_order = ["decoded_image"]
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(operations, input_columns, output_columns, column_order)
            >>>
            >>> # A simple example using pyfunc: Renaming columns and specifying column order
            >>> # work in the same way as the previous examples.
            >>> input_columns = ["col0"]
            >>> operations = [(lambda x: x + 1)]
            >>> ds_mapped = ds_pyfunc.map(operations, input_columns)
            >>>
            >>> # 2) Map example with more than one operation
            >>>
            >>> # If this list of operations is used with map, decode_op will be applied
            >>> # first, then random_jitter_op will be applied.
            >>> operations = [decode_op, random_jitter_op]
            >>>
            >>> input_columns = ["image"]
            >>>
            >>> # Create a dataset where the images are decoded, then randomly color jittered.
            >>> # decode_op takes column "image" as input and outputs one column. The column
            >>> # outputted by decode_op is passed as input to random_jitter_op.
            >>> # random_jitter_op will output one column. Column "image" will be replaced by
            >>> # the column outputted by random_jitter_op (the very last operation). All other
            >>> # columns are unchanged. Since column_order is not specified, the order of the
            >>> # columns will remain the same.
            >>> ds_mapped = data.map(operations, input_columns)
            >>>
            >>> # Create a dataset that is identical to ds_mapped, except the column "image"
            >>> # that is outputted by random_jitter_op is renamed to "image_transformed".
            >>> # Specifying column order works in the same way as examples in 1).
            >>> output_columns = ["image_transformed"]
            >>> ds_mapped_and_renamed = data.map(operations, input_columns, output_columns)
            >>>
            >>> # Multiple operations using pyfunc: Renaming columns and specifying column order
            >>> # work in the same way as examples in 1).
            >>> input_columns = ["col0"]
            >>> operations = [(lambda x: x + x), (lambda x: x - 1)]
            >>> output_columns = ["col0_mapped"]
            >>> ds_mapped = ds_pyfunc.map(operations, input_columns, output_columns)
            >>>
            >>> # 3) Example where number of input columns is not equal to number of output columns
            >>>
            >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns.
            >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column.
            >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns.
            >>> #
            >>> # Note: The number of output columns of operation[i] must equal the number of
            >>> # input columns of operation[i+1]. Otherwise, this map call will also result
            >>> # in an error.
            >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
            >>>               (lambda x, y, z: x * y * z),
            >>>               (lambda x: (x % 2, x % 3, x % 5, x % 7))]
            >>>
            >>> # Note: Since the number of input columns is not the same as the number of
            >>> # output columns, the output_columns and column_order parameters must be
            >>> # specified. Otherwise, this map call will also result in an error.
            >>> input_columns = ["col2", "col0"]
            >>> output_columns = ["mod2", "mod3", "mod5", "mod7"]
            >>>
            >>> # Propagate all columns to the child node in this order:
            >>> column_order = ["col0", "col2", "mod2", "mod3", "mod5", "mod7", "col1"]
            >>> ds_mapped = ds_pyfunc.map(operations, input_columns, output_columns, column_order)
            >>>
            >>> # Propagate some columns to the child node in this order:
            >>> column_order = ["mod7", "mod3", "col1"]
            >>> ds_mapped = ds_pyfunc.map(operations, input_columns, output_columns, column_order)
        """
        return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers,
                          python_multiprocessing, cache, callbacks)

    @check_filter
    def filter(self, predicate, input_columns=None, num_parallel_workers=1):
        """
        Filter dataset by predicate.

        Note:
            If input_columns is not provided or is empty, all columns will be used.

        Args:
            predicate (callable): Python callable which returns a boolean value. If False, filter (remove) the element.
            input_columns (list[str], optional): List of names of the input columns (default=None,
                the predicate will be applied on all columns in the dataset).
            num_parallel_workers (int, optional): Number of workers to process the dataset
                in parallel (default=1).

        Returns:
            FilterDataset, dataset filtered.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> # generator data(0 ~ 63)
            >>> # filter out the data that is greater than or equal to 11
            >>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns=["data"])
        """
        return FilterDataset(self, predicate, input_columns, num_parallel_workers)

    @check_repeat
    def repeat(self, count=None):
        """
        Repeat this dataset count times. Repeat indefinitely if the count is None or -1.

        Note:
            The order of using repeat and batch reflects the number of batches. It is recommended that
            the repeat operation be used after the batch operation.
            If dataset_sink_mode is False, the repeat operation is invalid.
            If dataset_sink_mode is True, the repeat count must be equal to the number of training epochs.
            Otherwise, errors could occur since the amount of data is not the amount required by training.

        Args:
            count (int): Number of times the dataset is repeated (default=None).

        Returns:
            RepeatDataset, dataset repeated.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> # data is an instance of Dataset object.
            >>>
            >>> # Create a dataset where the dataset is repeated for 50 epochs
            >>> repeated = data.repeat(50)
            >>>
            >>> # Create a dataset where each epoch is shuffled individually
            >>> shuffled_and_repeated = data.shuffle(10)
            >>> shuffled_and_repeated = shuffled_and_repeated.repeat(50)
            >>>
            >>> # Create a dataset where the dataset is first repeated for
            >>> # 50 epochs before shuffling. The shuffle operator will treat
            >>> # the entire 50 epochs as one big dataset.
            >>> repeat_and_shuffle = data.repeat(50)
            >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
        """
        if count == 1:
            return self
        return RepeatDataset(self, count)

    @check_skip
    def skip(self, count):
        """
        Skip the first N elements of this dataset.

        Args:
            count (int): Number of elements in the dataset to be skipped.

        Returns:
            SkipDataset, dataset skipped.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> # data is an instance of Dataset object.
            >>> # Create a dataset which skips first 3 elements from data
            >>> data = data.skip(3)
        """
        return SkipDataset(self, count)

    @check_take
    def take(self, count=-1):
        """
        Take at most the given number of elements from the dataset.

        Note:
            1. If count is greater than the number of elements in the dataset or equal to -1,
               all the elements in the dataset will be taken.
            2. The order of using take and batch matters. If take is before the batch operation,
               it takes the given number of rows; otherwise it takes the given number of batches.

        Args:
            count (int, optional): Number of elements to be taken from the dataset (default=-1).

        Returns:
            TakeDataset, dataset taken.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> # data is an instance of Dataset object.
            >>> # Create a dataset where the dataset includes 50 elements.
            >>> data = data.take(50)
        """
        if count == -1:
            return self
        return TakeDataset(self, count)
  620. def _get_absolute_split_sizes(self, sizes):
  621. """
  622. Internal method called by split to calculate absolute split sizes and to
  623. do some error checking after calculating absolute split sizes.
  624. """
  625. # Call get_dataset_size here and check input here because
  626. # don't want to call this once in check_split and another time in
  627. # here again
  628. dataset_size = self.get_dataset_size()
  629. if dataset_size is None or dataset_size <= 0:
  630. raise RuntimeError("dataset_size is unknown, unable to split.")
  631. if not isinstance(sizes, list):
  632. raise RuntimeError("sizes must be a list.")
  633. all_int = all(isinstance(item, int) for item in sizes)
  634. if all_int:
  635. sizes_sum = sum(sizes)
  636. if sizes_sum != dataset_size:
  637. raise RuntimeError("Sum of split sizes {} is not equal to dataset size {}."
  638. .format(sizes_sum, dataset_size))
  639. return sizes
  640. absolute_sizes = []
  641. for item in sizes:
  642. absolute_size = int(round(item * dataset_size))
  643. if absolute_size == 0:
  644. raise RuntimeError("Split percentage {} is too small.".format(item))
  645. absolute_sizes.append(absolute_size)
  646. absolute_sizes_sum = sum(absolute_sizes)
  647. # if we still need more rows, give them to the first split.
  648. # if we have too many rows, remove the extras from the first split that has
  649. # enough rows.
  650. size_difference = int(dataset_size - absolute_sizes_sum)
  651. if size_difference > 0:
  652. absolute_sizes[0] += size_difference
  653. else:
  654. for i, _ in enumerate(absolute_sizes):
  655. if absolute_sizes[i] + size_difference > 0:
  656. absolute_sizes[i] += size_difference
  657. break
  658. if sum(absolute_sizes) != dataset_size:
  659. raise RuntimeError("Sum of calculated split sizes {} is not equal to dataset size {}."
  660. .format(absolute_sizes_sum, dataset_size))
  661. return absolute_sizes
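# Worked example of the float-size path above (illustrative numbers, not from the original
# source): with dataset_size = 10 and sizes = [0.33, 0.33, 0.34], round() gives absolute
# sizes [3, 3, 3], which sum to 9; the missing row is added to the first split, yielding
# [4, 3, 3]. If rounding instead produced one row too many, that row would be removed from
# the first split that can spare it while still keeping at least 1 row.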
  662. @check_split
  663. def split(self, sizes, randomize=True):
  664. """
  665. Split the dataset into smaller, non-overlapping datasets.
  666. This is a general purpose split function which can be called from any operator in the pipeline.
  667. There is another, optimized split function, which will be called automatically if ds.split is
  668. called where ds is a MappableDataset.
  669. Args:
  670. sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is
  671. provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
  672. respectively. If the sum of all sizes does not equal the original dataset size, an
  673. error will occur.
  674. If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
  675. and must sum to 1, otherwise an error will occur. The dataset will be split into n
  676. Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
  677. original dataset.
  678. If after rounding:
  679. - Any size equals 0, an error will occur.
  680. - The sum of split sizes < K, the difference will be added to the first split.
  681. - The sum of split sizes > K, the difference will be removed from the first large
  682. enough split such that it will have at least 1 row after removing the difference.
  683. randomize (bool, optional): Determines whether or not to split the data randomly (default=True).
  684. If True, the data will be randomly split. Otherwise, each split will be created with
  685. consecutive rows from the dataset.
  686. Note:
  687. 1. Dataset cannot be sharded if split is going to be called.
  688. 2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
  689. Shuffling the dataset may not be deterministic, which means the data in each split
  690. will be different in each epoch.
  691. Raises:
  692. RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
  693. RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
  694. equal the dataset size.
  695. RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
  696. RuntimeError: If the dataset is sharded prior to calling split.
697. ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
698. floats don't sum to 1.
699. Returns:
  700. tuple(Dataset), a tuple of datasets that have been split.
  701. Examples:
  702. >>> import mindspore.dataset as ds
  703. >>>
  704. >>> dataset_files = "/path/to/text_file/*"
  705. >>>
  706. >>> # TextFileDataset is not a mappable dataset, so this non-optimized split will be called.
  707. >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called!
  708. >>> data = ds.TextFileDataset(dataset_files, shuffle=False)
  709. >>> train, test = data.split([0.9, 0.1])
  710. """
  711. if self.is_shuffled():
  712. logger.warning("Dataset is shuffled before split.")
  713. if self.is_sharded():
  714. raise RuntimeError("Dataset should not be sharded before split.")
  715. absolute_sizes = self._get_absolute_split_sizes(sizes)
  716. splits = []
  717. rows_to_skip = 0
  718. for size in absolute_sizes:
  719. ds = copy.deepcopy(self)
  720. if randomize:
  721. # want to shuffle the same way every epoch before split
  722. # in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
  723. ds = ds.shuffle(10000)
  724. ds.reshuffle_each_epoch = False
  725. if rows_to_skip > 0:
  726. ds = ds.skip(rows_to_skip)
  727. ds = ds.take(size)
  728. splits.append(ds)
  729. rows_to_skip += size
  730. return tuple(splits)
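# Sketch of the pipeline each split built above ends up with, assuming absolute sizes of
# [8, 2] and randomize=True (illustrative only):
#   split 0: data.shuffle(10000).take(8)
#   split 1: data.shuffle(10000).skip(8).take(2)
# Because reshuffle_each_epoch is set to False on the shuffle, both copies see the same row
# order every epoch, which keeps the splits non-overlapping.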
  731. @check_zip_dataset
  732. def zip(self, datasets):
  733. """
  734. Zip the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
  735. Args:
  736. datasets (Union[tuple, class Dataset]): A tuple of datasets or a single class Dataset
  737. to be zipped together with this dataset.
  738. Returns:
  739. ZipDataset, dataset zipped.
  740. Examples:
  741. >>> import mindspore.dataset as ds
  742. >>>
  743. >>> # ds1 and ds2 are instances of Dataset object
  744. >>> # Create a dataset which is the combination of ds1 and ds2
  745. >>> data = ds1.zip(ds2)
  746. """
  747. if isinstance(datasets, tuple):
  748. datasets = (self, *datasets)
  749. elif isinstance(datasets, Dataset):
  750. datasets = (self, datasets)
  751. else:
752. raise TypeError("Invalid input datasets type: %s. zip accepts a tuple of datasets or a single Dataset." % (datasets))
  753. return ZipDataset(datasets)
  754. @check_concat
  755. def concat(self, datasets):
  756. """
  757. Concatenate the datasets in the input list of datasets. The "+" operator is also supported to concatenate.
  758. Note:
759. The column names, and the rank and type of the column data, must be the same across the input datasets.
  760. Args:
  761. datasets (Union[list, class Dataset]): A list of datasets or a single class Dataset
  762. to be concatenated together with this dataset.
  763. Returns:
  764. ConcatDataset, dataset concatenated.
  765. Examples:
  766. >>> import mindspore.dataset as ds
  767. >>>
  768. >>> # ds1 and ds2 are instances of Dataset object
  769. >>>
  770. >>> # Create a dataset by concatenating ds1 and ds2 with "+" operator
  771. >>> data1 = ds1 + ds2
  772. >>> # Create a dataset by concatenating ds1 and ds2 with concat operation
  773. >>> data1 = ds1.concat(ds2)
  774. """
  775. if isinstance(datasets, Dataset):
  776. datasets = [self] + [datasets]
  777. elif isinstance(datasets, list):
  778. datasets = [self] + datasets
  779. else:
780. raise TypeError("Invalid input datasets type: %s. concat accepts a list of datasets or a single Dataset." % (datasets))
  781. return ConcatDataset(datasets)
  782. @check_rename
  783. def rename(self, input_columns, output_columns):
  784. """
  785. Rename the columns in input datasets.
  786. Args:
  787. input_columns (list[str]): List of names of the input columns.
  788. output_columns (list[str]): List of names of the output columns.
  789. Returns:
  790. RenameDataset, dataset renamed.
  791. Examples:
  792. >>> import mindspore.dataset as ds
  793. >>>
  794. >>> # data is an instance of Dataset object.
  795. >>> input_columns = ["input_col1", "input_col2", "input_col3"]
  796. >>> output_columns = ["output_col1", "output_col2", "output_col3"]
  797. >>>
  798. >>> # Create a dataset where input_col1 is renamed to output_col1, and
  799. >>> # input_col2 is renamed to output_col2, and input_col3 is renamed
  800. >>> # to output_col3.
  801. >>> data = data.rename(input_columns=input_columns, output_columns=output_columns)
  802. """
  803. return RenameDataset(self, input_columns, output_columns)
  804. @check_project
  805. def project(self, columns):
  806. """
  807. Project certain columns in input dataset.
  808. The specified columns will be selected from the dataset and passed down
  809. the pipeline in the order specified. The other columns are discarded.
  810. Args:
  811. columns(list[str]): List of names of the columns to project.
  812. Returns:
  813. ProjectDataset, dataset projected.
  814. Examples:
  815. >>> import mindspore.dataset as ds
  816. >>>
  817. >>> # data is an instance of Dataset object
  818. >>> columns_to_project = ["column3", "column1", "column2"]
  819. >>>
  820. >>> # Create a dataset that consists of column3, column1, column2
  821. >>> # in that order, regardless of the original order of columns.
  822. >>> data = data.project(columns=columns_to_project)
  823. """
  824. return ProjectDataset(self, columns)
  825. def build_vocab(self, vocab, columns, freq_range, top_k, special_tokens, special_first):
  826. return BuildVocabDataset(self, vocab, columns, freq_range, top_k, special_tokens, special_first)
  827. def build_sentencepiece_vocab(self, vocab, col_names, vocab_size,
  828. character_coverage, model_type, params):
  829. return BuildSentencePieceVocabDataset(self, vocab, col_names, vocab_size, character_coverage,
  830. model_type, params)
  831. def apply(self, apply_func):
  832. """
  833. Apply a function in this dataset.
  834. Args:
835. apply_func (function): A function that must take one 'Dataset' as an argument and
836. return a preprocessed 'Dataset'.
  837. Returns:
  838. Dataset, applied by the function.
  839. Examples:
  840. >>> import mindspore.dataset as ds
  841. >>>
  842. >>> # data is an instance of Dataset object
  843. >>>
  844. >>> # Declare an apply_func function which returns a Dataset object
  845. >>> def apply_func(ds):
  846. >>> ds = ds.batch(2)
  847. >>> return ds
  848. >>>
  849. >>> # Use apply to call apply_func
  850. >>> data = data.apply(apply_func)
  851. Raises:
  852. TypeError: If apply_func is not a function.
  853. TypeError: If apply_func doesn't return a Dataset.
  854. """
  855. if not hasattr(apply_func, '__call__'):
  856. raise TypeError("apply_func must be a function.")
  857. dataset = apply_func(self)
  858. if not isinstance(dataset, Dataset):
  859. raise TypeError("apply_func must return a dataset.")
  860. return dataset
  861. @check_device_send
  862. def device_que(self, prefetch_size=None, send_epoch_end=True):
  863. """
864. Return a TransferDataset that sends data to a device.
  865. Args:
  866. prefetch_size (int, optional): Prefetch number of records ahead of the
  867. user's request (default=None).
  868. send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
  869. Note:
870. If the device is Ascend, data features will be transferred one by one, and each
871. transfer is limited to 256 MB.
  872. Return:
  873. TransferDataset, dataset for transferring.
  874. """
  875. return self.to_device(send_epoch_end=send_epoch_end)
  876. @check_device_send
  877. def to_device(self, send_epoch_end=True):
  878. """
  879. Transfer data through CPU, GPU or Ascend devices.
  880. Args:
  881. send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
  882. Note:
883. If the device is Ascend, data features will be transferred one by one, and each
884. transfer is limited to 256 MB.
  885. Returns:
  886. TransferDataset, dataset for transferring.
  887. Raises:
  888. TypeError: If device_type is empty.
  889. ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
  890. RuntimeError: If dataset is unknown.
  891. RuntimeError: If distribution file path is given but failed to read.
  892. """
  893. queue_name = str(uuid.uuid1())
  894. if context:
  895. device_type = context.get_context("device_target")
  896. else:
  897. device_type = "CPU"
  898. if device_type == "":
  899. raise TypeError("Please set device_type in context")
  900. if device_type not in ('Ascend', 'GPU', 'CPU'):
901. raise ValueError("Only 'CPU', 'Ascend' and 'GPU' device types are supported.")
  902. def get_distribution(output_dataset):
  903. dev_id = 0
  904. if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDataset,
  905. ManifestDataset, MnistDataset, VOCDataset, CocoDataset, CelebADataset,
  906. MindDataset)):
  907. sampler = output_dataset.sampler
  908. if isinstance(sampler, samplers.DistributedSampler):
  909. dev_id = sampler.shard_id
  910. return "", dev_id
  911. if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset, CSVDataset)):
  912. if output_dataset.shard_id is not None:
  913. dev_id = output_dataset.shard_id
  914. return "", dev_id
  915. if not output_dataset.children:
  916. raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
  917. input_dataset = output_dataset.children[0]
  918. return get_distribution(input_dataset)
  919. distribution_path, device_id = get_distribution(self)
  920. if distribution_path == "":
  921. return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
  922. try:
  923. with open(distribution_path, 'r') as distribution_f:
  924. dist = json.load(distribution_f)
  925. device_id = dist["deviceId"]
926. except json.decoder.JSONDecodeError:
927. raise RuntimeError("Failed to decode JSON while loading the distribution file.")
928. except Exception:
929. raise RuntimeError("Failed to read the distribution file.")
  930. return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
  931. @check_save
  932. def save(self, file_name, num_files=1, file_type='mindrecord'):
  933. """
  934. Save the dynamic data processed by the dataset pipeline in common dataset format.
  935. Supported dataset formats: 'mindrecord' only
  936. Implicit type casting exists when saving data as 'mindrecord'. The table below shows how to do type casting.
  937. .. list-table:: Implicit Type Casting when Saving as 'mindrecord'
  938. :widths: 25 25 50
  939. :header-rows: 1
  940. * - Type in 'dataset'
  941. - Type in 'mindrecord'
  942. - Details
  943. * - bool
  944. - None
  945. - Not supported
  946. * - int8
  947. - int32
  948. -
  949. * - uint8
  950. - bytes(1D uint8)
  951. - Drop dimension
  952. * - int16
  953. - int32
  954. -
  955. * - uint16
  956. - int32
  957. -
  958. * - int32
  959. - int32
  960. -
  961. * - uint32
  962. - int64
  963. -
  964. * - int64
  965. - int64
  966. -
  967. * - uint64
  968. - None
  969. - Not supported
  970. * - float16
  971. - float32
  972. -
  973. * - float32
  974. - float32
  975. -
  976. * - float64
  977. - float64
  978. -
  979. * - string
  980. - string
  981. - Multi-dimensional string not supported
  982. Note:
  983. 1. To save the samples in order, set dataset's shuffle to False and num_files to 1.
984. 2. Before calling the function, do not use the batch operator, the repeat operator, or data
985. augmentation operators with a random attribute in the map operator.
986. 3. MindRecord does not support DE_UINT64, multi-dimensional DE_UINT8 (the dimension is dropped),
987. or multi-dimensional DE_STRING.
  988. Args:
  989. file_name (str): Path to dataset file.
  990. num_files (int, optional): Number of dataset files (default=1).
  991. file_type (str, optional): Dataset format (default='mindrecord').
  992. """
  993. if num_files == 1:
  994. file_names = [file_name]
  995. else:
  996. suffix = len(str(num_files - 1))
  997. file_names = ["{}{}".format(file_name, str(x).rjust(suffix, '0'))
  998. for x in range(num_files)]
  999. return SaveOp(self).save(file_names, file_type)
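# Illustration of the file name expansion above (example names are hypothetical): with
# file_name="out.mindrecord" and num_files=4, the zero-padded suffix width is len("3") == 1,
# so the generated names are out.mindrecord0 ... out.mindrecord3; with num_files=12 the
# width becomes 2 and the names run from out.mindrecord00 to out.mindrecord11.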
  1000. @check_iterator
  1001. def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
  1002. """
1003. Create an iterator over the dataset. The data retrieved will be a list of ndarrays of data.
1004. To specify which columns to list and their order, use the columns parameter. If columns
1005. is not provided, the order of the columns will not be changed.
  1006. Args:
  1007. columns (list[str], optional): List of columns to be used to specify the order of columns
  1008. (default=None, means all columns).
  1009. num_epochs (int, optional): Maximum number of epochs that iterator can be iterated.
  1010. (default=-1, iterator can be iterated infinite number of epochs)
  1011. output_numpy (bool, optional): Whether or not to output NumPy datatype.
  1012. If output_numpy=False, iterator will output MSTensor (default=False).
  1013. Returns:
  1014. Iterator, list of ndarrays.
  1015. Examples:
  1016. >>> import mindspore.dataset as ds
  1017. >>>
  1018. >>> # data is an instance of Dataset object
  1019. >>>
  1020. >>> # Create an iterator
  1021. >>> # The columns in the data obtained by the iterator will not be changed.
  1022. >>> iterator = data.create_tuple_iterator()
  1023. >>> for item in iterator:
  1024. >>> # convert the returned tuple to a list and print
  1025. >>> print(list(item))
  1026. """
  1027. if output_numpy is None:
  1028. output_numpy = False
  1029. if self._noop_mode():
  1030. return DummyIterator(self, 'tuple')
  1031. return TupleIterator(self, columns, num_epochs, output_numpy)
  1032. @check_iterator
  1033. def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
  1034. """
  1035. Create an iterator over the dataset. The data retrieved will be a dictionary.
  1036. The order of the columns in the dictionary may not be the same as the original order.
  1037. Args:
  1038. num_epochs (int, optional): Maximum number of epochs that iterator can be iterated
  1039. (default=-1, iterator can be iterated infinite number of epochs).
  1040. output_numpy (bool, optional): Whether or not to output NumPy datatype,
  1041. if output_numpy=False, iterator will output MSTensor (default=False).
  1042. Returns:
  1043. Iterator, dictionary of column name-ndarray pair.
  1044. Examples:
  1045. >>> import mindspore.dataset as ds
  1046. >>>
  1047. >>> # data is an instance of Dataset object
  1048. >>>
  1049. >>> # create an iterator
  1050. >>> # The columns in the data obtained by the iterator might be changed.
  1051. >>> iterator = data.create_dict_iterator()
  1052. >>> for item in iterator:
  1053. >>> # print the data in column1
  1054. >>> print(item["column1"])
  1055. """
  1056. if output_numpy is None:
  1057. output_numpy = False
  1058. if self._noop_mode():
  1059. return DummyIterator(self, 'dict')
  1060. return DictIterator(self, num_epochs, output_numpy)
  1061. def __iter__(self):
  1062. """Create an iterator over the dataset."""
  1063. return self.create_tuple_iterator(num_epochs=1)
  1064. @property
  1065. def input_indexs(self):
  1066. return self._input_indexs
  1067. @input_indexs.setter
  1068. def input_indexs(self, value):
  1069. self._input_indexs = value
  1070. def _get_pipeline_info(self):
  1071. """
  1072. Get pipeline information.
  1073. """
  1074. device_iter = TupleIterator(self)
  1075. self._output_shapes = device_iter.get_output_shapes()
  1076. self._output_types = device_iter.get_output_types()
  1077. self._batch_size = device_iter.get_batch_size()
  1078. self._num_classes = device_iter.num_classes()
  1079. self._repeat_count = device_iter.get_repeat_count()
  1080. device_iter.stop()
  1081. def get_col_names(self):
  1082. """
1083. Get the names of the columns in the dataset.
  1084. """
  1085. return Iterator(self).get_col_names()
  1086. def output_shapes(self):
  1087. """
  1088. Get the shapes of output data.
  1089. Return:
  1090. List, list of shapes of each column.
  1091. """
  1092. if self._output_shapes is None:
  1093. self._get_pipeline_info()
  1094. return self._output_shapes
  1095. def output_types(self):
  1096. """
  1097. Get the types of output data.
  1098. Return:
  1099. List of data types.
  1100. """
  1101. if self._output_types is None:
  1102. self._get_pipeline_info()
  1103. return self._output_types
  1104. def get_dataset_size(self):
  1105. """
  1106. Get the number of batches in an epoch.
  1107. Return:
  1108. Number, number of batches.
  1109. """
  1110. if self.dataset_size is None:
  1111. if self.children:
  1112. self.dataset_size = self.children[0].get_dataset_size()
  1113. return self.dataset_size
  1114. def num_classes(self):
  1115. """
  1116. Get the number of classes in a dataset.
  1117. Return:
  1118. Number, number of classes.
  1119. """
  1120. if self.children:
  1121. return self.children[0].num_classes()
  1122. return None
  1123. def get_sync_notifiers(self):
  1124. if self.children:
  1125. return self.children[0].get_sync_notifiers()
  1126. return {}
  1127. def disable_sync(self):
  1128. if self.children:
  1129. return self.children[0].disable_sync()
  1130. return {}
  1131. def is_sync(self):
  1132. if self.children:
  1133. return self.children[0].is_sync()
  1134. return False
  1135. def sync_update(self, condition_name, num_batch=None, data=None):
  1136. """
  1137. Release a blocking condition and trigger callback with given data.
  1138. Args:
  1139. condition_name (str): The condition name that is used to toggle sending next row.
  1140. num_batch (Union[int, None]): The number of batches (rows) that are released.
  1141. When num_batch is None, it will default to the number specified by the
  1142. sync_wait operator (default=None).
  1143. data (Union[dict, None]): The data passed to the callback (default=None).
  1144. """
  1145. if isinstance(num_batch, int) and num_batch <= 0:
  1146. # throwing exception, disable all sync_wait in pipeline
  1147. self.disable_sync()
1148. raise RuntimeError("sync_update batch size can only be positive, got: {}.".format(num_batch))
  1149. notifiers_dict = self.get_sync_notifiers()
  1150. if condition_name not in notifiers_dict:
  1151. # throwing exception, disable all sync_wait in pipeline
  1152. self.disable_sync()
  1153. raise RuntimeError("Condition name not found")
  1154. if num_batch is not None:
  1155. num_batch *= self.get_batch_size()
  1156. notifiers_dict[condition_name](num_batch, data)
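# Rough usage sketch (assuming a sync_wait condition named "cb" was registered earlier in
# the pipeline and the pipeline batch size is 4): data.sync_update(condition_name="cb",
# num_batch=2) releases 2 * 4 = 8 rows through the blocking condition and forwards `data`
# to the callback registered by the sync_wait operator.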
  1157. def get_batch_size(self):
  1158. """
  1159. Get the size of a batch.
  1160. Return:
  1161. Number, the number of data in a batch.
  1162. """
  1163. if self.children:
  1164. return self.children[0].get_batch_size()
  1165. return 1
  1166. def get_repeat_count(self):
  1167. """
1168. Get the repeat count of RepeatDataset in the pipeline, or 1 if there is no repeat.
  1169. Return:
  1170. Number, the count of repeat.
  1171. """
  1172. if self.children:
  1173. return self.children[0].get_repeat_count()
  1174. return 1
  1175. def reset(self):
  1176. """Reset the dataset for next epoch."""
  1177. def is_shuffled(self):
  1178. for input_dataset in self.children:
  1179. if input_dataset.is_shuffled():
  1180. return True
  1181. return False
  1182. def is_sharded(self):
  1183. for input_dataset in self.children:
  1184. if input_dataset.is_sharded():
  1185. return True
  1186. return False
  1187. class SourceDataset(Dataset):
  1188. """
  1189. Abstract class to represent a source dataset which produces content to the data pipeline.
  1190. """
  1191. # No need for __init__ since it is the same as the super's init
  1192. @staticmethod
  1193. def _find_files(patterns):
  1194. """
  1195. Utility function to search for files with the given glob patterns.
  1196. Args:
  1197. patterns (Union[str, list[str]]): String or list of patterns to be searched.
  1198. Returns:
  1199. List, files.
  1200. """
  1201. if not isinstance(patterns, list):
  1202. patterns = [patterns]
  1203. file_list = []
  1204. unmatched_patterns = []
  1205. for pattern in patterns:
  1206. matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]
  1207. if matches:
  1208. file_list.extend(matches)
  1209. else:
  1210. unmatched_patterns.append(pattern)
  1211. if unmatched_patterns:
1212. raise ValueError("The following patterns did not match any files: {}.".format(unmatched_patterns))
  1213. if file_list: # not empty
  1214. return file_list
  1215. raise ValueError("The list of path names matching the patterns is empty.")
  1216. def is_shuffled(self):
  1217. raise NotImplementedError("SourceDataset must implement is_shuffled.")
  1218. def is_sharded(self):
  1219. raise NotImplementedError("SourceDataset must implement is_sharded.")
  1220. class MappableDataset(SourceDataset):
  1221. """
  1222. Abstract class to represent a source dataset which supports use of samplers.
  1223. """
  1224. def __init__(self, num_parallel_workers=None):
  1225. # check if all subclasses use this name
  1226. super().__init__(num_parallel_workers)
  1227. self.sampler = None
  1228. def add_sampler(self, new_sampler):
  1229. # note: By adding a sampler, the sampled IDs will flow to new_sampler
  1230. # after first passing through the current samplers attached to this dataset.
  1231. if self.dataset_size is not None:
  1232. self.dataset_size = None
  1233. new_sampler.add_child(self.sampler)
  1234. self.sampler = new_sampler
  1235. def use_sampler(self, new_sampler):
  1236. """
  1237. Will make the current dataset use the new_sampler provided.
  1238. Args:
  1239. new_sampler (Sampler): The sampler to use for the current dataset.
  1240. Returns:
  1241. Dataset, that uses new_sampler.
  1242. Examples:
  1243. >>> import mindspore.dataset as ds
  1244. >>>
  1245. >>> dataset_dir = "/path/to/imagefolder_directory"
  1246. >>> # Note: A SequentialSampler is created by default
  1247. >>> data = ds.ImageFolderDataset(dataset_dir)
  1248. >>>
  1249. >>> # Use a DistributedSampler instead of the SequentialSampler
  1250. >>> new_sampler = ds.DistributedSampler(10, 2)
  1251. >>> data.use_sampler(new_sampler)
  1252. """
  1253. if new_sampler is None:
  1254. raise TypeError("Input sampler can not be None.")
  1255. if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
  1256. raise TypeError("Input sampler is not an instance of a sampler.")
  1257. if self.dataset_size is not None:
  1258. self.dataset_size = None
  1259. self.sampler = self.sampler.child_sampler
  1260. self.add_sampler(new_sampler)
  1261. def is_shuffled(self):
  1262. raise NotImplementedError("MappableDataset must implement is_shuffled.")
  1263. def is_sharded(self):
  1264. raise NotImplementedError("MappableDataset must implement is_sharded.")
  1265. def _get_sampler_dataset_size(self):
  1266. if self.sampler is not None:
  1267. if hasattr(self.sampler, 'get_num_samples'):
  1268. return self.sampler.get_num_samples()
  1269. if hasattr(self.sampler, '__len__'):
  1270. return len(self.sampler)
  1271. return None
  1272. @check_split
  1273. def split(self, sizes, randomize=True):
  1274. """
  1275. Split the dataset into smaller, non-overlapping datasets.
  1276. Args:
  1277. sizes (Union[list[int], list[float]]): If a list of integers [s1, s2, …, sn] is
  1278. provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
  1279. respectively. If the sum of all sizes does not equal the original dataset size, an
  1280. error will occur.
  1281. If a list of floats [f1, f2, …, fn] is provided, all floats must be between 0 and 1
  1282. and must sum to 1, otherwise an error will occur. The dataset will be split into n
  1283. Datasets of size round(f1*K), round(f2*K), …, round(fn*K) where K is the size of the
  1284. original dataset.
  1285. If after rounding:
  1286. - Any size equals 0, an error will occur.
  1287. - The sum of split sizes < K, the difference will be added to the first split.
  1288. - The sum of split sizes > K, the difference will be removed from the first large
1289. enough split such that it will have at least 1 row after removing the difference.
  1290. randomize (bool, optional): Determines whether or not to split the data randomly (default=True).
  1291. If True, the data will be randomly split. Otherwise, each split will be created with
  1292. consecutive rows from the dataset.
  1293. Note:
  1294. 1. There is an optimized split function, which will be called automatically when the dataset
  1295. that calls this function is a MappableDataset.
  1296. 2. Dataset should not be sharded if split is going to be called. Instead, create a
  1297. DistributedSampler and specify a split to shard after splitting. If dataset is
  1298. sharded after a split, it is strongly recommended to set the same seed in each instance
  1299. of execution, otherwise each shard may not be part of the same split (see Examples).
  1300. 3. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
  1301. Shuffling the dataset may not be deterministic, which means the data in each split
  1302. will be different in each epoch. Furthermore, if sharding occurs after split, each
  1303. shard may not be part of the same split.
  1304. Raises:
  1305. RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
  1306. RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
  1307. equal the dataset size.
  1308. RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
  1309. RuntimeError: If the dataset is sharded prior to calling split.
1310. ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
1311. floats don't sum to 1.
1312. Returns:
  1313. tuple(Dataset), a tuple of datasets that have been split.
  1314. Examples:
  1315. >>> import mindspore.dataset as ds
  1316. >>>
  1317. >>> dataset_dir = "/path/to/imagefolder_directory"
  1318. >>>
  1319. >>> # Since many datasets have shuffle on by default, set shuffle to False if split will be called!
  1320. >>> data = ds.ImageFolderDataset(dataset_dir, shuffle=False)
  1321. >>>
  1322. >>> # Set the seed, and tell split to use this seed when randomizing.
  1323. >>> # This is needed because sharding will be done later
  1324. >>> ds.config.set_seed(58)
  1325. >>> train, test = data.split([0.9, 0.1])
  1326. >>>
  1327. >>> # To shard the train dataset, use a DistributedSampler
  1328. >>> train_sampler = ds.DistributedSampler(10, 2)
  1329. >>> train.use_sampler(train_sampler)
  1330. """
  1331. if self.is_shuffled():
  1332. logger.warning("Dataset is shuffled before split.")
  1333. if self.is_sharded():
  1334. raise RuntimeError("Dataset should not be sharded before split.")
  1335. absolute_sizes = self._get_absolute_split_sizes(sizes)
  1336. splits = []
  1337. current_split_start_index = 0
  1338. for size in absolute_sizes:
  1339. ds = copy.deepcopy(self)
  1340. ds.dataset_size = None
  1341. if randomize:
  1342. # want to shuffle the same way every epoch before split, we are assuming
  1343. # that the user will call set_seed
  1344. random_sampler = samplers.RandomSampler()
  1345. random_sampler.reshuffle_each_epoch = False
  1346. ds.add_sampler(random_sampler)
  1347. subset_sampler = samplers.SequentialSampler(current_split_start_index, size)
  1348. ds.add_sampler(subset_sampler)
  1349. # add sequential sampler, so that if user calls use_sampler, we will
  1350. # get rid of the sequential sampler instead of something we need
  1351. ds.add_sampler(samplers.SequentialSampler())
  1352. splits.append(ds)
  1353. current_split_start_index += size
  1354. return tuple(splits)
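# Illustration of the samplers added to each split copy above, for absolute sizes [8, 2]
# and randomize=True (in the order they are added):
#   RandomSampler(reshuffle_each_epoch=False) -> SequentialSampler(start, size) -> SequentialSampler()
# Split 0 uses start=0, size=8 and split 1 uses start=8, size=2, so unlike the generic
# skip/take split, this version selects rows through samplers and the result stays mappable.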
  1355. class DatasetOp(Dataset):
  1356. """
  1357. Abstract class to represent an operation on a dataset.
  1358. """
  1359. # No need for __init__ since it is the same as the super's init
  1360. def get_class_indexing(self):
  1361. """
  1362. Get the class index.
  1363. Return:
1364. Dict, a str-to-int mapping from label name to index.
  1365. """
  1366. if self.children:
  1367. return self.children[0].get_class_indexing()
1368. raise NotImplementedError("Dataset {} does not support the get_class_indexing api yet.".format(type(self)))
  1369. class BucketBatchByLengthDataset(DatasetOp):
  1370. """
  1371. The result of applying BucketBatchByLength operator to the input dataset.
  1372. """
  1373. def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes,
  1374. element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder):
  1375. super().__init__()
  1376. self.column_names = column_names
  1377. self.bucket_boundaries = bucket_boundaries
  1378. self.bucket_batch_sizes = bucket_batch_sizes
  1379. self.element_length_function = element_length_function
  1380. self.pad_info = pad_info
  1381. self.pad_to_bucket_boundary = pad_to_bucket_boundary
  1382. self.drop_remainder = drop_remainder
  1383. self.children.append(input_dataset)
  1384. input_dataset.parent.append(self)
  1385. self._input_indexs = input_dataset.input_indexs
  1386. def get_args(self):
  1387. args = super().get_args()
  1388. args["length_dependent_columns"] = self.column_names
  1389. args["bucket_boundaries"] = self.bucket_boundaries
  1390. args["bucket_batch_sizes"] = self.bucket_batch_sizes
  1391. args["element_length_function"] = self.element_length_function
  1392. args["pad_info"] = self.pad_info
  1393. args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary
  1394. args["drop_remainder"] = self.drop_remainder
  1395. return args
  1396. def get_dataset_size(self):
  1397. """
  1398. Get the number of batches in an epoch.
  1399. Return:
  1400. Number, number of batches.
  1401. """
  1402. if self.dataset_size is None:
  1403. num_rows = 0
  1404. for _ in self.create_dict_iterator(num_epochs=1, output_numpy=True):
  1405. num_rows += 1
  1406. self.dataset_size = num_rows
  1407. return self.dataset_size
  1408. class BatchDataset(DatasetOp):
  1409. """
  1410. The result of applying Batch operator to the input dataset.
  1411. Args:
  1412. input_dataset (Dataset): Input Dataset to be batched.
  1413. batch_size (Union[int, function]): The number of rows each batch is created with. An
  1414. int or callable which takes exactly 1 parameter, BatchInfo.
  1415. drop_remainder (bool, optional): Determines whether or not to drop the last
  1416. possibly incomplete batch (default=False). If True, and if there are less
  1417. than batch_size rows available to make the last batch, then those rows will
  1418. be dropped and not propagated to the child node.
  1419. num_parallel_workers (int, optional): Number of workers to process the dataset in parallel (default=None).
  1420. per_batch_map (callable, optional): Per batch map callable. A callable which takes
  1421. (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of
  1422. Tensors on a given column. The number of lists should match with number of entries in input_columns. The
  1423. last parameter of the callable must always be a BatchInfo object.
  1424. input_columns (list[str], optional): List of names of the input columns. The size of the list must
  1425. match with signature of per_batch_map callable.
  1426. output_columns (list[str], optional): List of names assigned to the columns outputted by
  1427. the last operation. This parameter is mandatory if len(input_columns) !=
  1428. len(output_columns). The size of this list must match the number of output
  1429. columns of the last operation. (default=None, output columns will have the same
  1430. name as the input columns, i.e., the columns will be replaced).
  1431. column_order (list[str], optional): List of all the desired columns to propagate to the
  1432. child node. This list must be a subset of all the columns in the dataset after
  1433. all operations are applied. The order of the columns in each row propagated to the
1434. child node follows the order in which they appear in this list. The parameter is mandatory
  1435. if the len(input_columns) != len(output_columns). (default=None, all columns
  1436. will be propagated to the child node, the order of the columns will remain the
  1437. same).
  1438. pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
  1439. will pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0.
  1440. """
  1441. def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None,
  1442. per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None):
  1443. super().__init__(num_parallel_workers)
  1444. if BatchDataset._is_ancestor_of_repeat(input_dataset):
  1445. logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
  1446. BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
  1447. self.batch_size = batch_size
  1448. self.drop_remainder = drop_remainder
  1449. self.per_batch_map = per_batch_map
  1450. self.input_columns = input_columns if not isinstance(input_columns, str) else [input_columns]
  1451. self.output_columns = output_columns if not isinstance(output_columns, str) else [output_columns]
  1452. self.column_order = column_order if not isinstance(column_order, str) else [column_order]
  1453. self.pad_info = pad_info
  1454. self.children.append(input_dataset)
  1455. input_dataset.parent.append(self)
  1456. self._input_indexs = input_dataset.input_indexs
  1457. def get_args(self):
  1458. args = super().get_args()
  1459. args["batch_size"] = self.batch_size
  1460. args["drop_remainder"] = self.drop_remainder
  1461. args["per_batch_map"] = self.per_batch_map
  1462. args["input_columns"] = self.input_columns
  1463. args["output_columns"] = self.output_columns
  1464. args["column_order"] = self.column_order
  1465. args["pad_info"] = self.pad_info
  1466. return args
  1467. def get_dataset_size(self):
  1468. """
  1469. Get the number of batches in an epoch.
  1470. Return:
  1471. Number, number of batches.
  1472. """
  1473. if self.dataset_size is None:
  1474. child_size = self.children[0].get_dataset_size()
  1475. if child_size is not None and isinstance(self.batch_size, int):
  1476. if self.drop_remainder:
  1477. self.dataset_size = math.floor(child_size / self.batch_size)
  1478. else:
  1479. self.dataset_size = math.ceil(child_size / self.batch_size)
  1480. return self.dataset_size
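# For example: a child with 10 rows and batch_size=3 yields ceil(10 / 3) = 4 batches when
# drop_remainder is False (the last batch holds a single row) and floor(10 / 3) = 3 batches
# when drop_remainder is True. If batch_size is a callable, the size is left as None.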
  1481. def get_batch_size(self):
  1482. """
  1483. Get the size of a batch.
  1484. Return:
  1485. Number, the number of data in a batch.
  1486. """
  1487. return self.batch_size
  1488. @staticmethod
  1489. def _is_ancestor_of_repeat(dataset):
  1490. """
  1491. Utility function to find the case where repeat is used before batch.
  1492. Args:
  1493. dataset (Dataset): Dataset to be checked.
  1494. Return:
  1495. True or False.
  1496. """
  1497. if isinstance(dataset, RepeatDataset):
  1498. return True
  1499. flag = False
  1500. for input_dataset in dataset.children:
  1501. flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
  1502. return flag
  1503. @staticmethod
  1504. def _update_batch_size_for_syncwait(dataset, batch_size):
  1505. """
  1506. Utility function to notify batch size to sync_wait.
  1507. Args:
  1508. dataset (Dataset): Dataset to be checked.
  1509. batch_size (int): batch size to notify.
  1510. """
  1511. if isinstance(dataset, SyncWaitDataset):
  1512. dataset.update_sync_batch_size(batch_size)
  1513. for input_dataset in dataset.children:
  1514. BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
  1515. class BatchInfo(CBatchInfo):
  1516. """
  1517. The information object associates with the current batch of tensors.
  1518. """
  1519. def get_batch_num(self):
  1520. """
  1521. Return the batch number of the current batch.
  1522. Return:
  1523. Number, number of the current batch.
  1524. """
  1525. return
  1526. def get_epoch_num(self):
  1527. """
  1528. Return the epoch number of the current batch.
  1529. Return:
  1530. Number, number of the current epoch.
  1531. """
  1532. return
  1533. class BlockReleasePair:
  1534. """
  1535. The blocking condition class used by SyncWaitDataset.
  1536. Args:
1537. init_release_rows (int): Number of rows to allow through the pipeline.
  1538. callback (function): The callback function that will be called when release is called.
  1539. """
  1540. def __init__(self, init_release_rows, callback=None):
  1541. if isinstance(init_release_rows, int) and init_release_rows <= 0:
1542. raise ValueError("init_release_rows needs to be greater than 0.")
  1543. self.row_count = -init_release_rows
  1544. self.cv = threading.Condition()
  1545. self.callback = callback
  1546. self.default_rows = init_release_rows
  1547. self.disable = False
  1548. def __deepcopy__(self, memodict):
  1549. if id(self) in memodict:
  1550. return memodict[id(self)]
  1551. memodict[id(self)] = self
  1552. # condition variable and callback are the same, but reset the counter
  1553. self.reset()
  1554. return self
  1555. def reset(self):
  1556. with self.cv:
  1557. self.row_count = -self.default_rows
  1558. self.cv.notify_all()
  1559. def update_batched_size(self, batch_size):
  1560. # sanity check
  1561. if isinstance(batch_size, int) and batch_size <= 0:
1562. raise ValueError("batch_size needs to be greater than 0.")
1563. # should only be used before the pipeline is created
  1564. self.row_count *= batch_size
  1565. self.default_rows *= batch_size
  1566. def block_func(self):
  1567. """
1568. Function for handling the blocking condition.
  1569. Return:
  1570. True
  1571. """
  1572. with self.cv:
1573. # if disable is true, the condition always evaluates to true
  1574. not_time_out = self.cv.wait_for(lambda: (self.row_count < 0 or self.disable),
  1575. timeout=get_callback_timeout())
1576. # not_time_out will be False if a timeout occurs
  1577. if not not_time_out:
  1578. logger.warning("Timeout happened in sync_wait, disabling lock")
  1579. self.disable = True
  1580. self.row_count += 1
  1581. return True
  1582. def release_func(self, pass_rows=None, data=None):
  1583. with self.cv:
  1584. if pass_rows is None:
  1585. pass_rows = self.default_rows
  1586. self.row_count -= pass_rows
  1587. if self.callback is not None:
  1588. self.callback(data)
  1589. self.cv.notify_all()
  1590. def disable_lock(self):
  1591. with self.cv:
  1592. self.disable = True
  1593. self.cv.notify_all()
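# Rough trace of the counter above (illustrative): with init_release_rows=2, row_count
# starts at -2; block_func lets two rows through (row_count climbs to 0), then the third
# row blocks until release_func(2) subtracts 2 again. If the wait times out, the lock is
# disabled and all subsequent rows pass freely.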
  1594. class SyncWaitDataset(DatasetOp):
  1595. """
  1596. The result of adding a blocking condition to the input Dataset.
  1597. Args:
  1598. input_dataset (Dataset): Input dataset to apply flow control.
  1599. num_batch (int): Number of batches without blocking at the start of each epoch.
  1600. condition_name (str): Condition name that is used to toggle sending next row.
  1601. callback (function): Callback function that will be invoked when sync_update is called.
  1602. Raises:
  1603. RuntimeError: If condition name already exists.
  1604. """
  1605. def __init__(self, input_dataset, condition_name, num_batch, callback=None):
  1606. super().__init__()
  1607. self.children.append(input_dataset)
  1608. input_dataset.parent.append(self)
  1609. # set to the default value, waiting for the batch to update it
  1610. self._condition_name = condition_name
  1611. if isinstance(num_batch, int) and num_batch <= 0:
1612. raise ValueError("num_batch needs to be greater than 0.")
  1613. self._pair = BlockReleasePair(num_batch, callback)
  1614. if self._condition_name in self.children[0].get_sync_notifiers():
  1615. raise RuntimeError("Condition name is already in use")
1616. logger.warning("Please remember to add dataset.sync_update(condition_name=%s), otherwise the pipeline will hang.",
1617. condition_name)
  1618. def get_sync_notifiers(self):
  1619. return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
  1620. def is_sync(self):
  1621. return True
  1622. def get_args(self):
  1623. args = super().get_args()
  1624. args["condition_name"] = self._condition_name
  1625. args["condition_func"] = self._pair.block_func
  1626. return args
  1627. def update_sync_batch_size(self, batch_size):
  1628. if isinstance(batch_size, int) and batch_size <= 0:
1629. raise ValueError("num_batch needs to be greater than 0.")
  1630. self._pair.update_batched_size(batch_size)
  1631. def disable_sync(self):
  1632. logger.info("Disabling Sync")
  1633. self._pair.disable_lock()
  1634. @staticmethod
  1635. def _is_ancestor_of_batch(dataset):
  1636. """
  1637. Utility function to find the case where sync_wait is used before batch.
  1638. Args:
  1639. dataset (Dataset): Dataset to be checked.
  1640. Return:
  1641. True or False.
  1642. """
  1643. if isinstance(dataset, BatchDataset):
  1644. return True
  1645. flag = False
  1646. for input_dataset in dataset.children:
  1647. flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
  1648. return flag
  1649. class ShuffleDataset(DatasetOp):
  1650. """
  1651. The result of applying Shuffle operator to the input Dataset.
  1652. Args:
  1653. input_dataset (Dataset): Input Dataset to be shuffled.
  1654. buffer_size (int): Size of the buffer.
  1655. Raises:
1656. RuntimeError: If a sync operator exists earlier in the pipeline.
  1657. """
  1658. def __init__(self, input_dataset, buffer_size):
  1659. super().__init__()
  1660. self.buffer_size = buffer_size
  1661. self.children.append(input_dataset)
  1662. self.reshuffle_each_epoch = None
  1663. input_dataset.parent.append(self)
  1664. self._input_indexs = input_dataset.input_indexs
  1665. if self.is_sync():
1666. raise RuntimeError("Shuffle cannot be applied after sync operators.")
  1667. def get_args(self):
  1668. args = super().get_args()
  1669. args["buffer_size"] = self.buffer_size
  1670. if self.reshuffle_each_epoch is not None:
  1671. args["reshuffle_each_epoch"] = self.reshuffle_each_epoch
  1672. return args
  1673. def is_shuffled(self):
  1674. return True
  1675. # Pyfunc collection for multiprocess pyfunc
  1676. # This global variable will only be used within subprocesses
  1677. _GLOBAL_PYFUNC_LIST = []
  1678. # Pyfunc worker init function
1679. # The Python multiprocessing library forbids sending lambda functions through a pipe.
1680. # This init function allows us to add all Python functions to a global collection and then fork afterwards.
  1681. def _pyfunc_worker_init(pyfunc_list):
  1682. global _GLOBAL_PYFUNC_LIST
  1683. _GLOBAL_PYFUNC_LIST = pyfunc_list
  1684. # Pyfunc worker execution function
  1685. # All exceptions will be raised to main processes
  1686. def _pyfunc_worker_exec(index, *args):
  1687. try:
  1688. return _GLOBAL_PYFUNC_LIST[index](*args)
  1689. except KeyboardInterrupt:
  1690. raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
  1691. # PythonCallable wrapper for multiprocess pyfunc
  1692. class _PythonCallable:
  1693. """
  1694. Internal Python function wrapper for multiprocessing pyfunc.
  1695. """
  1696. def __init__(self, py_callable, idx, pool=None):
  1697. # Original Python callable from user.
  1698. self.py_callable = py_callable
  1699. # Process pool created for current iterator.
  1700. self.pool = pool
  1701. # Python callable index for subprocess _GLOBAL_PYFUNC_LIST
  1702. self.idx = idx
  1703. def __call__(self, *args):
  1704. if self.pool is not None:
  1705. # This call will send the tensors along with Python callable index to the process pool.
  1706. # Block, yield GIL. Current thread will reacquire GIL once result is returned.
  1707. result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
  1708. while check_iterator_cleanup() is False:
  1709. try:
  1710. return result.get(30)
  1711. except multiprocessing.TimeoutError:
  1712. continue
  1713. except KeyboardInterrupt:
  1714. self.pool.terminate()
  1715. self.pool.join()
  1716. raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
  1717. return (None,)
  1718. # Invoke original Python callable in master process in case the pool is gone.
  1719. return self.py_callable(*args)
  1720. class MapDataset(DatasetOp):
  1721. """
  1722. The result of applying the Map operator to the input Dataset.
  1723. Args:
  1724. input_dataset (Dataset): Input Dataset to be mapped.
  1725. operations (TensorOp): A function mapping a nested structure of tensors
1726. to another nested structure of tensors (default=None).
  1727. input_columns (list[str]): List of names of the input columns
  1728. (default=None, the operations will be applied on the first columns in the dataset).
  1729. The size of the list should match the number of inputs of the first operator.
  1730. output_columns (list[str], optional): List of names of the output columns.
  1731. The size of the list should match the number of outputs of the last operator
  1732. (default=None, output columns will be the input columns, i.e., the columns will
  1733. be replaced).
  1734. column_order (list[str], optional): List of all the desired columns of the dataset (default=None).
  1735. The argument is mandatory if len(input_columns) != len(output_columns).
  1736. num_parallel_workers (int, optional): Number of workers to process the dataset
  1737. in parallel (default=None).
1738. python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
1739. option could be beneficial if the Python operation is computationally heavy (default=False).
  1740. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  1741. The cache feature is under development and is not recommended.
1742. callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (default=None).
  1743. Raises:
  1744. ValueError: If len(input_columns) != len(output_columns) and column_order is not specified.
  1745. """
  1746. def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None,
  1747. num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
  1748. super().__init__(num_parallel_workers)
  1749. self.children.append(input_dataset)
  1750. if operations is not None:
  1751. if not isinstance(operations, list):
  1752. operations = [operations]
  1753. elif isinstance(operations, list) and len(operations) > 1:
  1754. # wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations
  1755. new_ops, start_ind, end_ind = [], 0, 0
  1756. for i, op in enumerate(operations):
  1757. if not callable(op):
  1758. # reset counts
  1759. if start_ind != end_ind:
  1760. new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
  1761. new_ops.append(op)
  1762. start_ind, end_ind = i + 1, i + 1
  1763. else:
  1764. end_ind += 1
  1765. # do additional check in case the last operation is a Python operation
  1766. if start_ind != end_ind:
  1767. new_ops.append(py_transforms.Compose(operations[start_ind:end_ind]))
  1768. operations = new_ops
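# For example (assuming the C++ ops are not Python-callable, which is what the callable()
# check above relies on): [py_op1, py_op2, c_op, py_op3] becomes
# [Compose([py_op1, py_op2]), c_op, Compose([py_op3])], so adjacent Python functions run
# inside a single Compose while C++ ops are passed through untouched.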
  1769. self.operations = operations
  1770. if input_columns is not None and not isinstance(input_columns, list):
  1771. input_columns = [input_columns]
  1772. self.input_columns = input_columns
  1773. if output_columns is not None and not isinstance(output_columns, list):
  1774. output_columns = [output_columns]
  1775. self.output_columns = output_columns
  1776. self.cache = cache
  1777. self.column_order = column_order
  1778. if self.input_columns and self.output_columns \
  1779. and len(self.input_columns) != len(self.output_columns) \
  1780. and self.column_order is None:
  1781. raise ValueError("When (len(input_columns) != len(output_columns)), column_order must be specified.")
  1782. input_dataset.parent.append(self)
  1783. self._input_indexs = input_dataset.input_indexs
  1784. self.python_multiprocessing = python_multiprocessing
  1785. self.process_pool = None
  1786. if callbacks is not None and not isinstance(callbacks, list):
  1787. callbacks = [callbacks]
  1788. self.callbacks = callbacks
  1789. def get_args(self):
  1790. args = super().get_args()
  1791. args["input_columns"] = self.input_columns
  1792. args["operations"] = self.operations
  1793. args["output_columns"] = self.output_columns
  1794. args["column_order"] = self.column_order
  1795. args["cache"] = self.cache.cache_client if self.cache is not None else None
  1796. if self.callbacks is not None:
  1797. args["callbacks"] = [cb.create_runtime_obj() for cb in self.callbacks]
  1798. return args
  1799. def get_dataset_size(self):
  1800. """
  1801. Get the number of batches in an epoch.
  1802. Return:
  1803. Number, number of batches.
  1804. """
  1805. if self.dataset_size is None:
  1806. self.dataset_size = self.children[0].get_dataset_size()
  1807. return self.dataset_size
  1808. def __deepcopy__(self, memodict):
  1809. if id(self) in memodict:
  1810. return memodict[id(self)]
  1811. cls = self.__class__
  1812. new_op = cls.__new__(cls)
  1813. memodict[id(self)] = new_op
  1814. new_op.children = copy.deepcopy(self.children, memodict)
  1815. new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
  1816. new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
  1817. new_op.column_order = copy.deepcopy(self.column_order, memodict)
  1818. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  1819. new_op.parent = copy.deepcopy(self.parent, memodict)
  1820. new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
  1821. new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
  1822. new_op.cache = copy.deepcopy(self.cache, memodict)
  1823. new_op.operations = self.operations
  1824. new_op.dataset_size = self.dataset_size
  1825. new_op.callbacks = self.callbacks
  1826. return new_op
  1827. # Iterator bootstrap will be called on iterator construction.
  1828. # A deep copy of Dataset object is created prior of iterator_bootstrap.
  1829. # This method will create per iterator process pool and bind pyfunc execution to the pool.
  1830. def iterator_bootstrap(self):
  1831. """
  1832. Per iterator bootstrap callback.
  1833. """
  1834. if self.python_multiprocessing:
  1835. iter_specific_operations = []
  1836. callable_list = []
  1837. # Pass #1, look for Python callables and build list
  1838. for op in self.operations:
  1839. if callable(op):
  1840. callable_list.append(op)
  1841. if callable_list:
  1842. # Construct pool with the callable list
1843. # The callable list and _pyfunc_worker_init are used to pass lambda functions into the subprocesses
  1844. self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
  1845. initializer=_pyfunc_worker_init,
  1846. initargs=(callable_list,))
  1847. # Pass #2
  1848. idx = 0
  1849. for op in self.operations:
  1850. if callable(op):
  1851. # Wrap Python callable into _PythonCallable
  1852. iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
  1853. idx += 1
  1854. else:
  1855. # CPP ops remain the same
  1856. iter_specific_operations.append(op)
  1857. self.operations = iter_specific_operations
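# Sketch of the resulting list (illustrative, again assuming the C++ op is not
# Python-callable): operations = [py_func_a, c_op, py_func_b] becomes
# [_PythonCallable(py_func_a, 0, pool), c_op, _PythonCallable(py_func_b, 1, pool)], so each
# Python function is executed in the worker pool by its index while C++ ops run unchanged.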
  1858. def __del__(self):
  1859. if hasattr(self, 'process_pool') and self.process_pool is not None:
  1860. self.process_pool.terminate()
  1861. class FilterDataset(DatasetOp):
  1862. """
  1863. The result of applying filter predicate to the input Dataset.
  1864. Args:
  1865. input_dataset (Dataset): Input Dataset to be mapped.
1866. predicate (callable): Python callable which returns a boolean value. Rows for which the predicate returns False are filtered out.
  1867. input_columns (list[str], optional): List of names of the input columns
  1868. (default=None, the predicate will be applied to all columns in the dataset).
  1869. num_parallel_workers (int, optional): Number of workers to process the dataset
  1870. in parallel (default=None).
  1871. """
  1872. def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
  1873. super().__init__(num_parallel_workers)
  1874. self.predicate = lambda *args: bool(predicate(*args))
  1875. self.children.append(input_dataset)
  1876. input_dataset.parent.append(self)
  1877. if input_columns is not None and not isinstance(input_columns, list):
  1878. input_columns = [input_columns]
  1879. self.input_columns = input_columns
  1880. def get_args(self):
  1881. args = super().get_args()
  1882. args["predicate"] = self.predicate
  1883. args["input_columns"] = self.input_columns
  1884. return args
  1885. def get_dataset_size(self):
  1886. """
  1887. Get the number of batches in an epoch.
  1888. Return:
1889. Number, number of batches.
  1890. """
  1891. if self.dataset_size is None:
  1892. num_rows = 0
  1893. for _ in self.create_dict_iterator(num_epochs=1, output_numpy=True):
  1894. num_rows += 1
  1895. self.dataset_size = num_rows
  1896. return self.dataset_size
  1897. class RepeatDataset(DatasetOp):
  1898. """
  1899. The result of applying Repeat operator to the input Dataset.
  1900. Args:
  1901. input_dataset (Dataset): Input Dataset to be repeated.
  1902. count (int): Number of times the dataset will be repeated (default=-1, repeat indefinitely).
  1903. """
  1904. def __init__(self, input_dataset, count):
  1905. super().__init__()
  1906. if count is None:
  1907. self.count = -1
  1908. else:
  1909. self.count = count
  1910. self.children.append(input_dataset)
  1911. input_dataset.parent.append(self)
  1912. self._input_indexs = input_dataset.input_indexs
  1913. def get_args(self):
  1914. args = super().get_args()
  1915. args["count"] = self.count
  1916. return args
  1917. def get_dataset_size(self):
  1918. """
  1919. Get the number of batches in an epoch.
  1920. Return:
  1921. Number, number of batches.
  1922. """
  1923. if self.dataset_size is None:
  1924. child_size = self.children[0].get_dataset_size()
  1925. if child_size is not None:
  1926. self.dataset_size = child_size * self.count
  1927. return self.dataset_size
  1928. def get_repeat_count(self):
  1929. """
  1930. Get the replication times in RepeatDataset.
  1931. Return:
  1932. Number, the count of repeat.
  1933. """
  1934. return self.count
  1935. class SkipDataset(DatasetOp):
  1936. """
  1937. The result of applying Skip operator to the input Dataset.
  1938. Args:
  1939. input_dataset (Dataset): Input dataset to have elements skipped.
  1940. count (int): Number of elements to be skipped in the dataset.
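Examples:
>>> # Illustrative sketch only: a SkipDataset is normally produced by Dataset.skip();
>>> # `data1` is assumed to be a dataset created elsewhere.
>>> data1 = data1.skip(3)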
  1941. """
  1942. def __init__(self, input_dataset, count):
  1943. super().__init__()
  1944. self.count = count
  1945. self.children.append(input_dataset)
  1946. input_dataset.parent.append(self)
  1947. self._input_indexs = input_dataset.input_indexs
  1948. def get_args(self):
  1949. args = super().get_args()
  1950. args["count"] = self.count
  1951. return args
  1952. def get_dataset_size(self):
  1953. """
  1954. Get the number of batches in an epoch.
  1955. Return:
  1956. Number, number of batches.
  1957. """
  1958. if self.dataset_size is None:
  1959. child_size = self.children[0].get_dataset_size()
  1960. self.dataset_size = 0
  1961. if self.count >= 0 and self.count < child_size:
  1962. self.dataset_size = child_size - self.count
  1963. return self.dataset_size
  1964. class TakeDataset(DatasetOp):
  1965. """
  1966. The result of applying Take operator to the input Dataset.
  1967. Args:
  1968. input_dataset (Dataset): Input Dataset to have elements taken from.
  1969. count (int): Number of elements to be taken from the dataset.
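Examples:
>>> # Illustrative sketch only: a TakeDataset is normally produced by Dataset.take();
>>> # `data1` is assumed to be a dataset created elsewhere.
>>> data1 = data1.take(5)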
  1970. """
  1971. def __init__(self, input_dataset, count):
  1972. super().__init__()
  1973. self.count = count
  1974. self.children.append(input_dataset)
  1975. input_dataset.parent.append(self)
  1976. self._input_indexs = input_dataset.input_indexs
  1977. def get_args(self):
  1978. args = super().get_args()
  1979. args["count"] = self.count
  1980. return args
  1981. def get_dataset_size(self):
  1982. """
  1983. Get the number of batches in an epoch.
  1984. Return:
  1985. Number, number of batches.
  1986. """
  1987. if self.dataset_size is None:
  1988. child_size = self.children[0].get_dataset_size()
  1989. if child_size < self.count:
  1990. self.dataset_size = child_size
  1991. else:
  1992. self.dataset_size = self.count
  1993. return self.dataset_size
  1994. class ZipDataset(DatasetOp):
  1995. """
  1996. The result of applying Zip operator to the input Dataset.
  1997. Args:
  1998. datasets (tuple): A tuple of datasets to be zipped together.
  1999. Raises:
  2000. TypeError: If dataset is not an instance of Dataset.
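Examples:
>>> import mindspore.dataset as ds
>>>
>>> # Illustrative sketch only: a ZipDataset is normally produced by ds.zip();
>>> # data1 and data2 are assumed to be datasets with non-overlapping column names.
>>> data3 = ds.zip((data1, data2))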
  2001. """
  2002. def __init__(self, datasets):
  2003. super().__init__()
  2004. for dataset in datasets:
  2005. if not isinstance(dataset, Dataset):
  2006. raise TypeError("The parameter %s of zip must be an instance of Dataset!" % (dataset))
  2007. self.datasets = datasets
  2008. for data in datasets:
  2009. self.children.append(data)
  2010. data.parent.append(self)
  2011. def get_dataset_size(self):
  2012. """
  2013. Get the number of batches in an epoch.
  2014. Return:
  2015. Number, number of batches.
  2016. """
  2017. if self.dataset_size is None:
  2018. children_sizes = [c.get_dataset_size() for c in self.children]
  2019. if all(c is not None for c in children_sizes):
  2020. self.dataset_size = min(children_sizes)
  2021. return self.dataset_size
  2022. def num_classes(self):
  2023. """
  2024. Get the number of classes in a dataset.
  2025. Return:
  2026. Number, number of classes.
  2027. """
  2028. return None
  2029. def is_sync(self):
  2030. return any([c.is_sync() for c in self.children])
  2031. def get_args(self):
  2032. args = super().get_args()
  2033. return args
  2034. class ConcatDataset(DatasetOp):
  2035. """
  2036. The result of applying concat dataset operator to the input Dataset.
  2037. Args:
  2038. datasets (list): A list of datasets to be concatenated together.
  2039. Raises:
  2040. TypeError: If dataset is not an instance of Dataset.
  2041. ValueError: If there are no samples in one of the datasets.
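Examples:
>>> # Illustrative sketch only: a ConcatDataset is normally produced by Dataset.concat()
>>> # or the "+" operator; data1 and data2 are assumed to be datasets created elsewhere.
>>> data3 = data1.concat(data2)
>>> data3 = data1 + data2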
  2042. """
  2043. def __init__(self, datasets):
  2044. super().__init__()
  2045. for dataset in datasets:
  2046. if not isinstance(dataset, Dataset):
  2047. raise TypeError("The parameter %s of concat must be an instance of Dataset!" % (dataset))
  2048. self.datasets = datasets
  2049. self._sampler = None
  2050. for data in datasets:
  2051. self.children.append(data)
  2052. data.parent.append(self)
  2053. self.children_sizes_ = [c.get_dataset_size() for c in self.children]
  2054. child_index = 0
  2055. for item in self.children_sizes_:
  2056. if item == 0:
  2057. raise ValueError("There are no samples in the %dth dataset. Please make sure there are "
  2058. "valid samples in the dataset" % child_index)
  2059. child_index += 1
  2060. # _children_flag_and_nums: A list of (int, int) pairs. The first element of each pair is a flag indicating
  2061. # whether the child dataset is mappable. The second element is the length of that dataset.
  2062. self._children_flag_and_nums = []
  2063. # _children_start_end_index_: A list of (int, int) pairs. Each pair marks the valid start and end positions
  2064. # of the child dataset at the corresponding index when sampling.
  2065. self._children_start_end_index_ = []
  2066. for index, child in enumerate(self.children):
  2067. tem_list = [-1, -1]
  2068. self._children_start_end_index_.append(tem_list)
  2069. dataset_len = self.children_sizes_[index]
  2070. if isinstance(child, GeneratorDataset) and not hasattr(child.source, "__getitem__"):
  2071. dataset_len = 0
  2072. self.children_sizes_[index] = 0
  2073. if isinstance(child, MappableDataset):
  2074. self._children_flag_and_nums.append((0, dataset_len))
  2075. else:
  2076. self._children_flag_and_nums.append((1, dataset_len))
  2077. def get_dataset_size(self):
  2078. """
  2079. Get the number of batches in an epoch.
  2080. Return:
  2081. Number, number of batches.
  2082. """
  2083. if self.dataset_size is None:
  2084. num_rows = 0
  2085. for _ in self.create_dict_iterator(num_epochs=1, output_numpy=True):
  2086. num_rows += 1
  2087. self.dataset_size = num_rows
  2088. return self.dataset_size
  2089. def use_sampler(self, sampler):
  2090. """
  2091. Set a DistributedSampler for the concatenated dataset.
  2092. Args:
  2093. sampler (Sampler): The sampler to use for the current dataset.
  2094. Currently supported: DistributedSampler.
  2095. Raises:
  2096. TypeError: If the sampler is not an instance of DistributedSampler.
  2097. ValueError: If the parameter shuffle of the sampler is True.
  2098. ValueError: If the parameter num_samples of the sampler is not None.
  2099. ValueError: If num_shards <= 0.
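Examples:
>>> import mindspore.dataset as ds
>>>
>>> # Illustrative sketch only: shard a concatenated dataset with a non-shuffling
>>> # DistributedSampler; data1 and data2 are assumed to be mappable datasets.
>>> data3 = data1 + data2
>>> data3.use_sampler(ds.DistributedSampler(num_shards=8, shard_id=0, shuffle=False))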
  2100. """
  2101. if not isinstance(sampler, samplers.DistributedSampler):
  2102. raise TypeError("The parameter %s of concat must be DistributedSampler!" % (sampler))
  2103. if sampler.is_shuffled():
  2104. raise ValueError("The parameter shuffle of DistributedSampler must be False!")
  2105. if sampler.num_shards <= 0:
  2106. raise ValueError("The parameter num_shards of DistributedSampler must be positive int!")
  2107. if sampler.get_num_samples() is not None:
  2108. raise ValueError("The parameter num_samples of DistributedSampler must not be set!")
  2109. self._sampler = _select_sampler(None, sampler, None, None, None)
  2110. cumulative_samples_nums = 0
  2111. for index, child in enumerate(self.children):
  2112. if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
  2113. raise ValueError("The parameter num_samples of %s must not be set!" % (child))
  2114. if isinstance(child, BatchDataset):
  2115. raise TypeError("The parameter %s of concat must not be BatchDataset!" % (child))
  2116. if not self._children_flag_and_nums[index][0] and self._children_flag_and_nums[index][1]:
  2117. tem_value = cumulative_samples_nums + self._children_flag_and_nums[index][1]
  2118. if not self._children_flag_and_nums[index][1] >= sampler.num_shards:
  2119. if tem_value < sampler.num_shards:
  2120. self._children_start_end_index_[index][0] = cumulative_samples_nums
  2121. self._children_start_end_index_[index][1] = tem_value
  2122. else:
  2123. self._children_start_end_index_[index][0] = cumulative_samples_nums
  2124. self._children_start_end_index_[index][1] = tem_value % sampler.num_shards
  2125. tem_sampler = copy.deepcopy(sampler)
  2126. tem_sampler.set_offset(cumulative_samples_nums)
  2127. child.sampler = tem_sampler
  2128. cumulative_samples_nums += self.children_sizes_[index]
  2129. cumulative_samples_nums %= sampler.num_shards
  2130. def get_args(self):
  2131. args = super().get_args()
  2132. if self._sampler is not None:
  2133. args["sampler"] = self._sampler
  2134. args["children_flag_and_nums"] = self._children_flag_and_nums
  2135. args["children_start_end_index"] = self._children_start_end_index_
  2136. return args
  2137. class RenameDataset(DatasetOp):
  2138. """
  2139. The result of applying Rename operator to the input Dataset.
  2140. Args:
  2141. input_dataset (Dataset): Input Dataset to be Renamed.
  2142. input_columns (list[str]): List of names of the input columns.
  2143. output_columns (list[str]): List of names of the output columns.
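Examples:
>>> # Illustrative sketch only: a RenameDataset is normally produced by Dataset.rename();
>>> # `data1` is assumed to contain a column named "col1".
>>> data1 = data1.rename(input_columns=["col1"], output_columns=["image"])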
  2144. """
  2145. def __init__(self, input_dataset, input_columns, output_columns):
  2146. super().__init__()
  2147. if not isinstance(input_columns, list):
  2148. input_columns = [input_columns]
  2149. if not isinstance(output_columns, list):
  2150. output_columns = [output_columns]
  2151. self.input_column_names = input_columns
  2152. self.output_column_names = output_columns
  2153. self.children.append(input_dataset)
  2154. input_dataset.parent.append(self)
  2155. self._input_indexs = input_dataset.input_indexs
  2156. def get_args(self):
  2157. args = super().get_args()
  2158. args["input_columns"] = self.input_column_names
  2159. args["output_columns"] = self.output_column_names
  2160. return args
  2161. class ProjectDataset(DatasetOp):
  2162. """
  2163. The result of applying Project operator to the input Dataset.
  2164. Args:
  2165. input_dataset (Dataset): Input Dataset to be Projected.
  2166. columns (list[str]): List of names of the columns to project.
  2167. prefetch_size (int, optional): Prefetch number of records ahead of the
  2168. user's request (default=None).
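Examples:
>>> # Illustrative sketch only: a ProjectDataset is normally produced by Dataset.project();
>>> # `data1` is assumed to contain the columns "image" and "label".
>>> data1 = data1.project(columns=["image", "label"])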
  2169. """
  2170. def __init__(self, input_dataset, columns, prefetch_size=None):
  2171. super().__init__()
  2172. if not isinstance(columns, list):
  2173. columns = [columns]
  2174. self.columns = columns
  2175. self.children.append(input_dataset)
  2176. self.prefetch_size = prefetch_size
  2177. input_dataset.parent.append(self)
  2178. self._input_indexs = input_dataset.input_indexs
  2179. def get_args(self):
  2180. args = super().get_args()
  2181. args["columns"] = self.columns
  2182. args["prefetch_size"] = self.prefetch_size
  2183. return args
  2184. class TransferDataset(DatasetOp):
  2185. """
  2186. The result of applying TDT operator to the input Dataset.
  2187. Args:
  2188. input_dataset (Dataset): Input Dataset to be transferred.
  2189. queue_name (str): Name of device queue.
  2190. device_id (int): ID of device.
  2191. device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
  2192. send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
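Examples:
>>> # Illustrative sketch only: a TransferDataset is normally produced by Dataset.device_que()
>>> # rather than constructed directly; `data1` is assumed to be a dataset created elsewhere.
>>> data1 = data1.device_que()
>>> data1.send()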
  2193. """
  2194. def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True):
  2195. super().__init__()
  2196. self.children.append(input_dataset)
  2197. input_dataset.parent.append(self)
  2198. self.queue_name = queue_name
  2199. self._input_indexs = input_dataset.input_indexs
  2200. self._device_type = device_type
  2201. self._device_id = device_id
  2202. self._send_epoch_end = send_epoch_end
  2203. self.iterator = None
  2204. def get_args(self):
  2205. args = super().get_args()
  2206. args["queue_name"] = self.queue_name
  2207. args["device_type"] = self._device_type
  2208. args["device_id"] = self._device_id
  2209. args["send_epoch_end"] = self._send_epoch_end
  2210. if hasattr(self.children[0], "__total_batch__"):
  2211. args["total_batch"] = self.children[0].__total_batch__
  2212. return args
  2213. def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
  2214. raise RuntimeError("TransferDataset is not iterable.")
  2215. def create_tuple_iterator(self, columns=None, num_epochs=-1, output_numpy=False):
  2216. raise RuntimeError("TransferDataset is not iterable.")
  2217. def __iter__(self):
  2218. raise RuntimeError("TransferDataset is not iterable.")
  2219. def output_shapes(self):
  2220. raise RuntimeError("TransferDataset does not support output_shapes.")
  2221. def output_types(self):
  2222. raise RuntimeError("TransferDataset does not support output_types.")
  2223. def send(self, num_epochs=-1):
  2224. # need to keep iterator alive so the executionTree is not destroyed
  2225. if self._noop_mode():
  2226. return
  2227. if self.iterator is not None:
  2228. del self.iterator
  2229. self.iterator = TupleIterator(self, num_epochs=num_epochs)
  2230. def stop_send(self):
  2231. self.iterator.depipeline.StopSend()
  2232. def continue_send(self):
  2233. self.iterator.depipeline.ContinueSend()
  2234. class RangeDataset(MappableDataset):
  2235. """
  2236. A source dataset that generates a sequence of numbers over the range defined by start, stop and step.
  2237. Args:
  2238. start (int): Starting index.
  2239. stop (int): Ending index.
  2240. step (int): Step size in the range specified by start and stop.
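Examples:
>>> # Illustrative sketch only, constructing the op directly; the size matches
>>> # get_dataset_size() below, i.e. ceil((stop - start) / step).
>>> range_data = RangeDataset(start=0, stop=10, step=3)
>>> range_data.get_dataset_size()    # ceil(10 / 3) == 4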
  2241. """
  2242. def __init__(self, start, stop, step):
  2243. super().__init__()
  2244. self.start = start
  2245. self.stop = stop
  2246. self.step = step
  2247. def get_args(self):
  2248. args = super().get_args()
  2249. args["start"] = self.start
  2250. args["stop"] = self.stop
  2251. args["step"] = self.step
  2252. return args
  2253. def is_shuffled(self):
  2254. return False
  2255. def is_sharded(self):
  2256. return False
  2257. def get_dataset_size(self):
  2258. if self.dataset_size is None:
  2259. self.dataset_size = math.ceil((self.stop - self.start) / self.step)
  2260. return self.dataset_size
  2261. def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
  2262. """
  2263. Create sampler based on user input.
  2264. Args:
  2265. num_samples (int): Number of samples.
  2266. input_sampler (Union[Iterable, Sampler]): Sampler from user.
  2267. shuffle (bool): Shuffle.
  2268. num_shards (int): Number of shard for sharding.
  2269. shard_id (int): Shard ID.
  2270. non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
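Examples:
>>> # Illustrative of the selection rules implemented below (internal helper, not a public API):
>>> _select_sampler(None, None, None, None, None)    # RandomSampler over all samples
>>> _select_sampler(None, None, False, 4, 0)         # DistributedSampler(4, 0, shuffle=False)
>>> _select_sampler(10, None, True, None, None)      # RandomSampler(replacement=True, num_samples=10)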
  2271. """
  2272. if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
  2273. return None
  2274. if input_sampler is not None:
  2275. # If the user provided a sampler, then it doesn't matter what the other args are because
  2276. # we are being asked specifically to use the given sampler.
  2277. # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
  2278. # be None. Consider this example:
  2279. # sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
  2280. # data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
  2281. # In this case, the user has given different sample-related arguments that contradict each other.
  2282. # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
  2283. if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  2284. samplers.RandomSampler, samplers.SubsetRandomSampler,
  2285. samplers.WeightedRandomSampler, samplers.Sampler)) and
  2286. (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
  2287. raise ValueError(
  2288. 'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
  2289. ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
  2290. return input_sampler
  2291. if shuffle is None:
  2292. if num_shards is not None:
  2293. # If shuffle is not specified, sharding enabled, use distributed random sampler
  2294. shuffle = True
  2295. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
  2296. # If shuffle is not specified, sharding disabled, use random sampler
  2297. if num_samples is not None:
  2298. return samplers.RandomSampler(replacement=True, num_samples=num_samples)
  2299. return samplers.RandomSampler(num_samples=num_samples)
  2300. if shuffle is True:
  2301. if num_shards is not None:
  2302. # If shuffle enabled, sharding enabled, use distributed random sampler
  2303. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
  2304. # If shuffle enabled, sharding disabled, use random sampler
  2305. if num_samples is not None:
  2306. return samplers.RandomSampler(replacement=True, num_samples=num_samples)
  2307. return samplers.RandomSampler(num_samples=num_samples)
  2308. if num_shards is not None:
  2309. # If shuffle disabled, sharding enabled, use distributed sequential sampler
  2310. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
  2311. # If shuffle disabled, sharding disabled, use sequential sampler
  2312. return samplers.SequentialSampler(num_samples=num_samples)
  2313. class ImageFolderDataset(MappableDataset):
  2314. """
  2315. A source dataset that reads images from a tree of directories.
  2316. All images within one folder have the same label.
  2317. The generated dataset has two columns ['image', 'label'].
  2318. The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
  2319. otherwise.
  2320. The type of the image tensor is uint8. The label is a scalar int32 tensor.
  2321. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
  2322. below shows what input arguments are allowed and their expected behavior.
  2323. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2324. :widths: 25 25 50
  2325. :header-rows: 1
  2326. * - Parameter 'sampler'
  2327. - Parameter 'shuffle'
  2328. - Expected Order Behavior
  2329. * - None
  2330. - None
  2331. - random order
  2332. * - None
  2333. - True
  2334. - random order
  2335. * - None
  2336. - False
  2337. - sequential order
  2338. * - Sampler object
  2339. - None
  2340. - order defined by sampler
  2341. * - Sampler object
  2342. - True
  2343. - not allowed
  2344. * - Sampler object
  2345. - False
  2346. - not allowed
  2347. Args:
  2348. dataset_dir (str): Path to the root directory that contains the dataset.
  2349. num_samples (int, optional): The number of images to be included in the dataset
  2350. (default=None, all images).
  2351. num_parallel_workers (int, optional): Number of workers to read the data
  2352. (default=None, set in the config).
  2353. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  2354. (default=None, expected order behavior shown in the table).
  2355. sampler (Sampler, optional): Object used to choose samples from the
  2356. dataset (default=None, expected order behavior shown in the table).
  2357. extensions (list[str], optional): List of file extensions to be
  2358. included in the dataset (default=None).
  2359. class_indexing (dict, optional): A str-to-int mapping from folder name to index
  2360. (default=None, the folder names will be sorted
  2361. alphabetically and each class will be given a
  2362. unique index starting from 0).
  2363. decode (bool, optional): Decode the images after reading (default=False).
  2364. num_shards (int, optional): Number of shards that the dataset will be divided
  2365. into (default=None).
  2366. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2367. argument can only be specified when num_shards is also specified.
  2368. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  2369. The cache feature is under development and is not recommended.
  2370. Raises:
  2371. RuntimeError: If sampler and shuffle are specified at the same time.
  2372. RuntimeError: If sampler and sharding are specified at the same time.
  2373. RuntimeError: If num_shards is specified but shard_id is None.
  2374. RuntimeError: If shard_id is specified but num_shards is None.
  2375. RuntimeError: If class_indexing is not a dictionary.
  2376. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2377. Examples:
  2378. >>> import mindspore.dataset as ds
  2379. >>>
  2380. >>> # Set path to the imagefolder directory.
  2381. >>> # This directory needs to contain sub-directories which contain the images
  2382. >>> dataset_dir = "/path/to/imagefolder_directory"
  2383. >>>
  2384. >>> # 1) Read all samples (image files) in dataset_dir with 8 threads
  2385. >>> imagefolder_dataset = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8)
  2386. >>>
  2387. >>> # 2) Read all samples (image files) from folder cat and folder dog with label 0 and 1
  2388. >>> imagefolder_dataset = ds.ImageFolderDataset(dataset_dir, class_indexing={"cat":0, "dog":1})
  2389. >>>
  2390. >>> # 3) Read all samples (image files) in dataset_dir with extensions .JPEG and .png (case sensitive)
  2391. >>> imagefolder_dataset = ds.ImageFolderDataset(dataset_dir, extensions=[".JPEG", ".png"])
  2392. """
  2393. @check_imagefolderdataset
  2394. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  2395. shuffle=None, sampler=None, extensions=None, class_indexing=None,
  2396. decode=False, num_shards=None, shard_id=None, cache=None):
  2397. super().__init__(num_parallel_workers)
  2398. self.dataset_dir = dataset_dir
  2399. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2400. self.num_samples = num_samples
  2401. self.shuffle_level = shuffle
  2402. self.extensions = extensions
  2403. self.class_indexing = class_indexing
  2404. self.decode = decode
  2405. self.num_shards = num_shards
  2406. self.shard_id = shard_id
  2407. self.cache = cache
  2408. def get_args(self):
  2409. args = super().get_args()
  2410. args["dataset_dir"] = self.dataset_dir
  2411. args["num_samples"] = self.num_samples
  2412. args["sampler"] = self.sampler
  2413. args["shuffle"] = self.shuffle_level
  2414. args["extensions"] = self.extensions
  2415. args["class_indexing"] = self.class_indexing
  2416. args["decode"] = self.decode
  2417. args["num_shards"] = self.num_shards
  2418. args["shard_id"] = self.shard_id
  2419. args["cache"] = self.cache.cache_client if self.cache is not None else None
  2420. return args
  2421. def get_dataset_size(self):
  2422. """
  2423. Get the number of batches in an epoch.
  2424. Return:
  2425. Number, number of batches.
  2426. """
  2427. if self.dataset_size is None:
  2428. num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
  2429. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  2430. rows_from_sampler = self._get_sampler_dataset_size()
  2431. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  2432. self.dataset_size = rows_from_sampler
  2433. return self.dataset_size
  2434. def num_classes(self):
  2435. """
  2436. Get the number of classes in dataset.
  2437. Return:
  2438. Number, number of classes.
  2439. """
  2440. return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[1]
  2441. def is_shuffled(self):
  2442. if self.shuffle_level is None:
  2443. return True
  2444. return self.shuffle_level or self.sampler.is_shuffled()
  2445. def is_sharded(self):
  2446. if self.num_shards is not None:
  2447. return self.num_shards > 1
  2448. return self.sampler.is_sharded()
  2449. class MnistDataset(MappableDataset):
  2450. """
  2451. A source dataset for reading and parsing the MNIST dataset.
  2452. The generated dataset has two columns ['image', 'label'].
  2453. The type of the image tensor is uint8. The label is a scalar uint32 tensor.
  2454. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
  2455. below shows what input arguments are allowed and their expected behavior.
  2456. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2457. :widths: 25 25 50
  2458. :header-rows: 1
  2459. * - Parameter 'sampler'
  2460. - Parameter 'shuffle'
  2461. - Expected Order Behavior
  2462. * - None
  2463. - None
  2464. - random order
  2465. * - None
  2466. - True
  2467. - random order
  2468. * - None
  2469. - False
  2470. - sequential order
  2471. * - Sampler object
  2472. - None
  2473. - order defined by sampler
  2474. * - Sampler object
  2475. - True
  2476. - not allowed
  2477. * - Sampler object
  2478. - False
  2479. - not allowed
  2480. Citation of the MNIST dataset.
  2481. .. code-block::
  2482. @article{lecun2010mnist,
  2483. title = {MNIST handwritten digit database},
  2484. author = {LeCun, Yann and Cortes, Corinna and Burges, CJ},
  2485. journal = {ATT Labs [Online]},
  2486. volume = {2},
  2487. year = {2010},
  2488. howpublished = {http://yann.lecun.com/exdb/mnist},
  2489. description = {The MNIST database of handwritten digits has a training set of 60,000 examples,
  2490. and a test set of 10,000 examples. It is a subset of a larger set available from
  2491. NIST. The digits have been size-normalized and centered in a fixed-size image.}
  2492. }
  2493. Args:
  2494. dataset_dir (str): Path to the root directory that contains the dataset.
  2495. usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 60,000
  2496. train samples, "test" will read from 10,000 test samples, "all" will read from all 70,000 samples.
  2497. (default=None, all samples)
  2498. num_samples (int, optional): The number of images to be included in the dataset
  2499. (default=None, all images).
  2500. num_parallel_workers (int, optional): Number of workers to read the data
  2501. (default=None, set in the config).
  2502. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  2503. (default=None, expected order behavior shown in the table).
  2504. sampler (Sampler, optional): Object used to choose samples from the
  2505. dataset (default=None, expected order behavior shown in the table).
  2506. num_shards (int, optional): Number of shards that the dataset will be divided
  2507. into (default=None).
  2508. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2509. argument can only be specified when num_shards is also specified.
  2510. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  2511. The cache feature is under development and is not recommended.
  2512. Raises:
  2513. RuntimeError: If sampler and shuffle are specified at the same time.
  2514. RuntimeError: If sampler and sharding are specified at the same time.
  2515. RuntimeError: If num_shards is specified but shard_id is None.
  2516. RuntimeError: If shard_id is specified but num_shards is None.
  2517. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2518. Examples:
  2519. >>> import mindspore.dataset as ds
  2520. >>>
  2521. >>> dataset_dir = "/path/to/mnist_folder"
  2522. >>> # Read 3 samples from MNIST dataset
  2523. >>> mnist_dataset = ds.MnistDataset(dataset_dir=dataset_dir, num_samples=3)
  2524. >>> # Note: In mnist_dataset dataset, each dictionary has keys "image" and "label"
  2525. """
  2526. @check_mnist_cifar_dataset
  2527. def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
  2528. shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
  2529. super().__init__(num_parallel_workers)
  2530. self.dataset_dir = dataset_dir
  2531. self.usage = usage
  2532. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2533. self.num_samples = num_samples
  2534. self.shuffle_level = shuffle
  2535. self.num_shards = num_shards
  2536. self.shard_id = shard_id
  2537. self.cache = cache
  2538. def get_args(self):
  2539. args = super().get_args()
  2540. args["dataset_dir"] = self.dataset_dir
  2541. args["usage"] = self.usage
  2542. args["num_samples"] = self.num_samples
  2543. args["shuffle"] = self.shuffle_level
  2544. args["sampler"] = self.sampler
  2545. args["num_shards"] = self.num_shards
  2546. args["shard_id"] = self.shard_id
  2547. args["cache"] = self.cache.cache_client if self.cache is not None else None
  2548. return args
  2549. def get_dataset_size(self):
  2550. """
  2551. Get the number of batches in an epoch.
  2552. Return:
  2553. Number, number of batches.
  2554. """
  2555. if self.dataset_size is None:
  2556. num_rows = MnistOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage)
  2557. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  2558. rows_from_sampler = self._get_sampler_dataset_size()
  2559. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  2560. self.dataset_size = rows_from_sampler
  2561. return self.dataset_size
  2562. def is_shuffled(self):
  2563. if self.shuffle_level is None:
  2564. return True
  2565. return self.shuffle_level or self.sampler.is_shuffled()
  2566. def is_sharded(self):
  2567. if self.num_shards is not None:
  2568. return self.num_shards > 1
  2569. return self.sampler.is_sharded()
  2570. class MindDataset(MappableDataset):
  2571. """
  2572. A source dataset that reads MindRecord files.
  2573. Args:
  2574. dataset_file (Union[str, list[str]]): If dataset_file is a str, it represents a
  2575. file name of one component of a mindrecord source; other files with an identical source
  2576. in the same path will be found and loaded automatically. If dataset_file is a list,
  2577. it represents a list of dataset files to be read directly.
  2578. columns_list (list[str], optional): List of columns to be read (default=None).
  2579. num_parallel_workers (int, optional): The number of readers (default=None).
  2580. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  2581. (default=None, performs shuffle).
  2582. num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
  2583. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2584. argument can only be specified when num_shards is also specified.
  2585. sampler (Sampler, optional): Object used to choose samples from the
  2586. dataset (default=None, sampler is exclusive
  2587. with shuffle and block_reader). Supported samplers: SubsetRandomSampler,
  2588. PKSampler, RandomSampler, SequentialSampler, DistributedSampler.
  2589. padded_sample (dict, optional): Samples will be appended to dataset, where
  2590. keys are the same as columns_list.
  2591. num_padded (int, optional): Number of padding samples. Dataset size
  2592. plus num_padded should be divisible by num_shards.
  2593. num_samples (int, optional): The number of samples to be included in the dataset
  2594. (default=None, all samples).
  2595. Raises:
  2596. ValueError: If num_shards is specified but shard_id is None.
  2597. ValueError: If shard_id is specified but num_shards is None.
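Examples:
>>> import mindspore.dataset as ds
>>>
>>> # Illustrative sketch only; the MindRecord path and column names are placeholders.
>>> mind_dataset = ds.MindDataset(dataset_file="/path/to/data.mindrecord",
>>>                               columns_list=["data", "label"])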
  2598. """
  2599. @check_minddataset
  2600. def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
  2601. shuffle=None, num_shards=None, shard_id=None,
  2602. sampler=None, padded_sample=None,
  2603. num_padded=None, num_samples=None):
  2604. super().__init__(num_parallel_workers)
  2605. if isinstance(dataset_file, list):
  2606. self.load_dataset = False
  2607. else:
  2608. self.load_dataset = True
  2609. self.dataset_file = dataset_file
  2610. self.columns_list = columns_list
  2611. self.shuffle_option = shuffle
  2612. self.num_shards = num_shards
  2613. self.shard_id = shard_id
  2614. if shuffle is False:
  2615. logger.warning("WARN: global shuffle is not used.")
  2616. if sampler is not None:
  2617. if isinstance(sampler, (samplers.SubsetRandomSampler, samplers.PKSampler,
  2618. samplers.DistributedSampler, samplers.RandomSampler,
  2619. samplers.SequentialSampler)) is False:
  2620. raise ValueError("The sampler is not supported yet.")
  2621. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2622. self.num_samples = num_samples
  2623. if num_padded is None:
  2624. num_padded = 0
  2625. self.padded_sample = padded_sample
  2626. self.num_padded = num_padded
  2627. def get_args(self):
  2628. args = super().get_args()
  2629. padded_sample = None
  2630. if self.padded_sample:
  2631. padded_sample = {}
  2632. for k, v in self.padded_sample.items():
  2633. if isinstance(v, np.ndarray):
  2634. padded_sample[k] = v.tobytes()
  2635. else:
  2636. padded_sample[k] = v
  2637. args["dataset_file"] = self.dataset_file
  2638. args["load_dataset"] = self.load_dataset
  2639. args["columns_list"] = self.columns_list
  2640. args["shuffle_option"] = self.shuffle_option
  2641. args["num_samples"] = self.num_samples
  2642. args["num_padded"] = self.num_padded
  2643. args["padded_sample"] = padded_sample
  2644. args["sampler"] = self.sampler
  2645. return args
  2646. def get_dataset_size(self):
  2647. """
  2648. Get the number of batches in an epoch.
  2649. Return:
  2650. Number, number of batches.
  2651. """
  2652. if self.dataset_size is None:
  2653. if self.load_dataset:
  2654. dataset_file = [self.dataset_file]
  2655. else:
  2656. dataset_file = self.dataset_file
  2657. num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
  2658. self.dataset_size = num_rows
  2659. return self.dataset_size
  2660. def is_shuffled(self):
  2661. if self.shuffle_option is None:
  2662. return True
  2663. return self.shuffle_option or self.sampler.is_shuffled()
  2664. def is_sharded(self):
  2665. if self.num_shards is not None:
  2666. return self.num_shards > 1
  2667. return self.sampler.is_sharded()
  2668. def _iter_fn(dataset, num_samples):
  2669. """
  2670. Generator function wrapper for iterable dataset.
  2671. """
  2672. if num_samples is not None:
  2673. ds_iter = iter(dataset)
  2674. for _ in range(num_samples):
  2675. try:
  2676. val = next(ds_iter)
  2677. except StopIteration:
  2678. return
  2679. # convert output tensors to ndarrays
  2680. yield tuple([np.array(x, copy=False) for x in val])
  2681. else:
  2682. for val in dataset:
  2683. # convert output tensors to ndarrays
  2684. yield tuple([np.array(x, copy=False) for x in val])
  2685. def _generator_fn(generator, num_samples):
  2686. """
  2687. Generator function wrapper for generator function dataset.
  2688. """
  2689. if num_samples is not None:
  2690. gen_iter = generator()
  2691. for _ in range(num_samples):
  2692. try:
  2693. val = next(gen_iter)
  2694. except StopIteration:
  2695. return
  2696. yield val
  2697. else:
  2698. gen_iter = generator()
  2699. for val in gen_iter:
  2700. yield val
  2701. def _py_sampler_fn(sampler, num_samples, dataset):
  2702. """
  2703. Generator function wrapper for mappable dataset with Python sampler.
  2704. """
  2705. if num_samples is not None:
  2706. sampler_iter = iter(sampler)
  2707. for _ in range(num_samples):
  2708. try:
  2709. idx = next(sampler_iter)
  2710. except StopIteration:
  2711. return
  2712. val = dataset[idx]
  2713. # convert output tensors to ndarrays
  2714. yield tuple([np.array(x, copy=False) for x in val])
  2715. else:
  2716. for i in sampler:
  2717. val = dataset[i]
  2718. # convert output tensors to ndarrays
  2719. yield tuple([np.array(x, copy=False) for x in val])
  2720. def _cpp_sampler_fn(sampler, dataset):
  2721. """
  2722. Generator function wrapper for mappable dataset with cpp sampler.
  2723. """
  2724. indices = sampler.get_indices()
  2725. for i in indices:
  2726. val = dataset[i]
  2727. # convert output tensors to ndarrays
  2728. yield tuple([np.array(x, copy=False) for x in val])
  2729. def _cpp_sampler_fn_mp(sampler, sample_fn):
  2730. """
  2731. Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
  2732. """
  2733. indices = sampler.get_indices()
  2734. return sample_fn.process(indices)
  2735. def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
  2736. """
  2737. Multiprocessing generator function wrapper for mappable dataset with Python sampler.
  2738. """
  2739. indices = _fetch_py_sampler_indices(sampler, num_samples)
  2740. return sample_fn.process(indices)
  2741. def _fetch_py_sampler_indices(sampler, num_samples):
  2742. """
  2743. Index fetcher for Python sampler.
  2744. """
  2745. if num_samples is not None:
  2746. sampler_iter = iter(sampler)
  2747. ret = []
  2748. for _ in range(num_samples):
  2749. try:
  2750. val = next(sampler_iter)
  2751. ret.append(val)
  2752. except StopIteration:
  2753. break
  2754. return ret
  2755. return [i for i in sampler]
  2756. def _fill_worker_indices(workers, indices, idx):
  2757. """
  2758. Worker index queue filler; fills the worker index queues in round-robin order.
  2759. """
  2760. num_worker = len(workers)
  2761. while idx < len(indices):
  2762. try:
  2763. workers[idx % num_worker].put(indices[idx])
  2764. idx += 1
  2765. except queue.Full:
  2766. break
  2767. return idx
  2768. class SamplerFn:
  2769. """
  2770. Master process wrapper for the multiprocessing or multithreaded generator function.
  2771. """
  2772. def __init__(self, dataset, num_worker, multi_process):
  2773. self.workers = []
  2774. self.num_worker = num_worker
  2775. self.multi_process = multi_process
  2776. # Event for end of epoch
  2777. if multi_process is True:
  2778. self.eof = multiprocessing.Event()
  2779. else:
  2780. self.eof = threading.Event()
  2781. # Create workers
  2782. for _ in range(num_worker):
  2783. if multi_process is True:
  2784. worker = _GeneratorWorkerMp(dataset, self.eof)
  2785. worker.daemon = True
  2786. # When multiple processes fork a subprocess, the locks of the main process are copied to the subprocess,
  2787. # which may cause a deadlock. Therefore, the subprocess is started in the initialization phase,
  2788. # where the main process does not hold any lock.
  2789. worker.start()
  2790. else:
  2791. worker = _GeneratorWorkerMt(dataset, self.eof)
  2792. worker.daemon = True
  2793. self.workers.append(worker)
  2794. def process(self, indices):
  2795. """
  2796. Main loop: start the child processes or threads, fill the index queues,
  2797. then fetch the results and yield them.
  2798. """
  2799. for w in self.workers:
  2800. # Check whether the queue of the subprocess is empty.
  2801. if not w.queue_empty():
  2802. raise Exception("The queue of the subprocess is not empty.")
  2803. # Start all workers
  2804. if not w.is_alive():
  2805. w.start()
  2806. # Fill initial index queues
  2807. idx_cursor = 0
  2808. idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
  2809. # Fetch results
  2810. for i in range(len(indices)):
  2811. # Fetch result and put index
  2812. try:
  2813. result = self.workers[i % self.num_worker].get()
  2814. except queue.Empty:
  2815. raise Exception("Generator worker process timeout")
  2816. except KeyboardInterrupt:
  2817. self.eof.set()
  2818. for w in self.workers:
  2819. w.terminate()
  2820. w.join()
  2821. raise Exception("Generator worker receives KeyboardInterrupt")
  2822. if idx_cursor < len(indices):
  2823. idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
  2824. yield tuple([np.array(x, copy=False) for x in result])
  2825. def __del__(self):
  2826. self.eof.set()
  2827. def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
  2828. """
  2829. Multithread or multiprocess generator worker process loop.
  2830. """
  2831. while True:
  2832. # Fetch index, block
  2833. try:
  2834. idx = idx_queue.get(timeout=1)
  2835. except KeyboardInterrupt:
  2836. raise Exception("Generator worker receives KeyboardInterrupt")
  2837. except queue.Empty:
  2838. if eof.is_set():
  2839. return
  2840. # If end-of-file (eof) is not set, continue to get data from idx_queue
  2841. continue
  2842. if idx is None:
  2843. # When the queue goes out of scope in the master process, a None item can be fetched from the queue.
  2844. # Upon receiving None, worker process should check if eof is set.
  2845. assert eof.is_set(), ""
  2846. return
  2847. if eof.is_set():
  2848. return
  2849. # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
  2850. result = dataset[idx]
  2851. # Send data, block
  2852. while True:
  2853. try:
  2854. result_queue.put(result, timeout=5)
  2855. except KeyboardInterrupt:
  2856. raise Exception("Generator worker receives KeyboardInterrupt")
  2857. except queue.Full:
  2858. if eof.is_set():
  2859. return
  2860. # If eof is not set, continue to put data to result_queue
  2861. continue
  2862. break
  2863. del result, idx
  2864. class _GeneratorWorkerMt(threading.Thread):
  2865. """
  2866. Worker thread for the multithreaded Generator.
  2867. """
  2868. def __init__(self, dataset, eof):
  2869. self.idx_queue = queue.Queue(16)
  2870. self.res_queue = queue.Queue(16)
  2871. super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
  2872. def put(self, item):
  2873. """
  2874. Put function for worker index queue. Never block. Raise queue.Full on failure.
  2875. """
  2876. self.idx_queue.put_nowait(item)
  2877. def get(self):
  2878. """
  2879. Get function for worker result queue. Block with timeout.
  2880. """
  2881. return self.res_queue.get(timeout=30)
  2882. def queue_empty(self):
  2883. if not self.idx_queue.empty():
  2884. logger.error("idx_queue is not empty")
  2885. return False
  2886. if not self.res_queue.empty():
  2887. logger.error("res_queue is not empty")
  2888. return False
  2889. return True
  2890. class _GeneratorWorkerMp(multiprocessing.Process):
  2891. """
  2892. Worker process for multiprocess Generator.
  2893. """
  2894. def __init__(self, dataset, eof):
  2895. self.idx_queue = multiprocessing.Queue(16)
  2896. self.res_queue = multiprocessing.Queue(16)
  2897. super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
  2898. def put(self, item):
  2899. """
  2900. Put function for worker index queue. Never block. Raise queue.Full on failure.
  2901. """
  2902. self.idx_queue.put_nowait(item)
  2903. def get(self):
  2904. """
  2905. Get function for worker result queue. Block with timeout.
  2906. """
  2907. # Relax 10s to 30s, since it sometimes will cause "Generator worker process timeout"
  2908. # when we run too many iterators with infinite epoch(num_epoch=-1)
  2909. return self.res_queue.get(timeout=30)
  2910. def queue_empty(self):
  2911. if not self.idx_queue.empty():
  2912. logger.error("idx_queue is not empty")
  2913. return False
  2914. if not self.res_queue.empty():
  2915. logger.error("res_queue is not empty")
  2916. return False
  2917. return True
  2918. def __del__(self):
  2919. # Try to destruct here, sometimes the class itself will be destructed in advance,
  2920. # so "self" will be a NoneType
  2921. try:
  2922. self.terminate()
  2923. except AttributeError:
  2924. pass
  2925. class GeneratorDataset(MappableDataset):
  2926. """
  2927. A source dataset that generates data from Python by invoking Python data source each epoch.
  2928. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
  2929. below shows what input arguments are allowed and their expected behavior.
  2930. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2931. :widths: 25 25 50
  2932. :header-rows: 1
  2933. * - Parameter 'sampler'
  2934. - Parameter 'shuffle'
  2935. - Expected Order Behavior
  2936. * - None
  2937. - None
  2938. - random order
  2939. * - None
  2940. - True
  2941. - random order
  2942. * - None
  2943. - False
  2944. - sequential order
  2945. * - Sampler object
  2946. - None
  2947. - order defined by sampler
  2948. * - Sampler object
  2949. - True
  2950. - not allowed
  2951. * - Sampler object
  2952. - False
  2953. - not allowed
  2954. Args:
  2955. source (Union[Callable, Iterable, Random Accessible]):
  2956. A generator callable object, an iterable Python object or a random accessible Python object.
  2957. Callable source is required to return a tuple of NumPy arrays as a row of the dataset on source().next().
  2958. Iterable source is required to return a tuple of NumPy arrays as a row of the dataset on
  2959. iter(source).next().
  2960. Random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on
  2961. source[idx].
  2962. column_names (list[str], optional): List of column names of the dataset (default=None). Users are required to
  2963. provide either column_names or schema.
  2964. column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
  2965. If provided, sanity check will be performed on generator output.
  2966. schema (Union[Schema, str], optional): Path to the JSON schema file or schema object (default=None). Users are
  2967. required to provide either column_names or schema. If both are provided, schema will be used.
  2968. num_samples (int, optional): The number of samples to be included in the dataset
  2969. (default=None, all samples).
  2970. num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
  2971. shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
  2972. (default=None, expected order behavior shown in the table).
  2973. sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible
  2974. input is required (default=None, expected order behavior shown in the table).
  2975. num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
  2976. When this argument is specified, 'num_samples' will not be used. Random accessible input is required.
  2977. shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only
  2978. when num_shards is also specified. Random accessible input is required.
  2979. python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
  2980. option could be beneficial if the Python operation is computationally heavy (default=True).
  2981. Examples:
  2982. >>> import mindspore.dataset as ds
  2983. >>>
  2984. >>> # 1) Multidimensional generator function as callable input
  2985. >>> def GeneratorMD():
  2986. >>> for i in range(64):
  2987. >>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
  2988. >>> # Create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
  2989. >>> multi_dimension_generator_dataset = ds.GeneratorDataset(GeneratorMD, ["multi_dimensional_data"])
  2990. >>>
  2991. >>> # 2) Multi-column generator function as callable input
  2992. >>> def GeneratorMC(maxid = 64):
  2993. >>> for i in range(maxid):
  2994. >>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
  2995. >>> # Create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
  2996. >>> multi_column_generator_dataset = ds.GeneratorDataset(GeneratorMC, ["col1", "col2"])
  2997. >>>
  2998. >>> # 3) Iterable dataset as iterable input
  2999. >>> class MyIterable():
  3000. >>> def __iter__(self):
  3001. >>> return # User implementation
  3002. >>> # Create iterable_generator_dataset with MyIterable object
  3003. >>> iterable_generator_dataset = ds.GeneratorDataset(MyIterable(), ["col1"])
  3004. >>>
  3005. >>> # 4) Random accessible dataset as random accessible input
  3006. >>> class MyRA():
  3007. >>> def __getitem__(self, index):
  3008. >>> return # User implementation
  3009. >>> # Create ra_generator_dataset with MyRA object
  3010. >>> ra_generator_dataset = ds.GeneratorDataset(MyRA(), ["col1"])
  3011. >>> # List/Dict/Tuple is also random accessible
  3012. >>> list_generator = ds.GeneratorDataset([(np.array(0),), (np.array(1)), (np.array(2))], ["col1"])
  3013. >>>
  3014. >>> # 5) Built-in Sampler
  3015. >>> my_generator = ds.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
  3016. """
  3017. @check_generatordataset
  3018. def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
  3019. num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
  3020. python_multiprocessing=True):
  3021. super().__init__(num_parallel_workers)
  3022. self.source = source
  3023. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3024. self.num_samples = num_samples
  3025. self.num_shards = num_shards
  3026. self.python_multiprocessing = python_multiprocessing
  3027. if column_names is not None and not isinstance(column_names, list):
  3028. column_names = [column_names]
  3029. self.column_names = column_names
  3030. if column_types is not None:
  3031. self.column_types = mstypelist_to_detypelist(column_types)
  3032. else:
  3033. self.column_types = column_types
  3034. if schema is not None:
  3035. self.schema = schema
  3036. if not isinstance(schema, Schema):
  3037. self.schema = Schema(schema)
  3038. self.column_names = []
  3039. self.column_types = []
  3040. for col in self.schema.columns:
  3041. self.column_names.append(col["name"])
  3042. self.column_types.append(DataType(col["type"]))
  3043. def get_args(self):
  3044. args = super().get_args()
  3045. args["source"] = self.source
  3046. args["column_names"] = self.column_names
  3047. args["column_types"] = self.column_types
  3048. return args
  3049. def get_dataset_size(self):
  3050. """
  3051. Get the number of batches in an epoch.
  3052. Return:
  3053. Number, number of batches.
  3054. """
  3055. if self.dataset_size is None:
  3056. if hasattr(self.source, "__len__"):
  3057. if not self.num_shards:
  3058. self.dataset_size = len(self.source)
  3059. else:
  3060. self.dataset_size = math.ceil(len(self.source) / self.num_shards)
  3061. rows_from_sampler = self._get_sampler_dataset_size()
  3062. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  3063. self.dataset_size = rows_from_sampler
  3064. else:
  3065. num_rows = 0
  3066. for _ in self.create_dict_iterator(num_epochs=1, output_numpy=True):
  3067. num_rows += 1
  3068. self.dataset_size = num_rows
  3069. return self.dataset_size
  3070. def __deepcopy__(self, memodict):
  3071. if id(self) in memodict:
  3072. return memodict[id(self)]
  3073. cls = self.__class__
  3074. new_op = cls.__new__(cls)
  3075. memodict[id(self)] = new_op
  3076. new_op.children = copy.deepcopy(self.children, memodict)
  3077. new_op.parent = copy.deepcopy(self.parent, memodict)
  3078. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  3079. new_op.column_types = copy.deepcopy(self.column_types, memodict)
  3080. new_op.column_names = copy.deepcopy(self.column_names, memodict)
  3081. new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
  3082. new_op.dataset_size = self.dataset_size
  3083. new_op.sampler = copy.deepcopy(self.sampler)
  3084. if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
  3085. if isinstance(new_op.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  3086. samplers.RandomSampler, samplers.SubsetRandomSampler,
  3087. samplers.WeightedRandomSampler, samplers.Sampler)):
  3088. sampler_instance = new_op.sampler.create()
  3089. sampler_instance.set_num_rows(len(self.source))
  3090. sampler_instance.initialize()
  3091. if new_op.num_parallel_workers > 1:
  3092. sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
  3093. new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn))
  3094. else:
  3095. new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
  3096. else:
  3097. if new_op.num_parallel_workers > 1:
  3098. sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
  3099. new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn))
  3100. else:
  3101. new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
  3102. else:
  3103. try:
  3104. iter(self.source)
  3105. except TypeError:
  3106. # Use generator function if input is callable
  3107. new_op.source = (lambda: _generator_fn(self.source, new_op.num_samples))
  3108. else:
  3109. # Use iterator function if input is iterable
  3110. # Random accessible input is also iterable
  3111. new_op.source = (lambda: _iter_fn(self.source, new_op.num_samples))
  3112. return new_op
  3113. def is_shuffled(self):
  3114. return self.sampler.is_shuffled()
  3115. def is_sharded(self):
  3116. return self.sampler.is_sharded()
  3117. class TFRecordDataset(SourceDataset):
  3118. """
  3119. A source dataset that reads and parses datasets stored on disk in TFData format.
  3120. Args:
  3121. dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
  3122. pattern of files. The list will be sorted in a lexicographical order.
  3123. schema (Union[str, Schema], optional): Path to the JSON schema file or schema object (default=None).
  3124. If the schema is not provided, the meta data from the TFData file is considered the schema.
  3125. columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
  3126. num_samples (int, optional): Number of samples (rows) to read (default=None).
  3127. If num_samples is None and numRows(parsed from schema) does not exist, read the full dataset;
  3128. If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows;
  3129. If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows.
  3130. num_parallel_workers (int, optional): Number of workers to read the data
  3131. (default=None, number set in the config).
  3132. shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
  3133. (default=Shuffle.GLOBAL).
  3134. If shuffle is False, no shuffling will be performed;
  3135. If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
  3136. Otherwise, there are two levels of shuffling:
  3137. - Shuffle.GLOBAL: Shuffle both the files and samples.
  3138. - Shuffle.FILES: Shuffle files only.
  3139. num_shards (int, optional): Number of shards that the dataset will be divided
  3140. into (default=None).
  3141. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3142. argument can only be specified when num_shards is also specified.
  3143. shard_equal_rows (bool, optional): Get equal rows for all shards (default=False). If shard_equal_rows
  3144. is False, the number of rows in each shard may not be equal.
  3145. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  3146. The cache feature is under development and is not recommended.
  3147. Examples:
  3148. >>> import mindspore.dataset as ds
  3149. >>> import mindspore.common.dtype as mstype
  3150. >>>
3151. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or multiple TFRecord files
  3152. >>>
  3153. >>> # 1) Get all rows from dataset_files with no explicit schema
  3154. >>> # The meta-data in the first row will be used as a schema.
  3155. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files)
  3156. >>>
  3157. >>> # 2) Get all rows from dataset_files with user-defined schema
  3158. >>> schema = ds.Schema()
3159. >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
  3160. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema=schema)
  3161. >>>
  3162. >>> # 3) Get all rows from dataset_files with schema file "./schema.json"
  3163. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
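>>>
>>> # 4) Illustrative sketch: get samples for shard 0 in a 2-way distributed setup
>>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, num_shards=2, shard_id=0)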
  3164. """
  3165. @check_tfrecorddataset
  3166. def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
  3167. shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None):
  3168. super().__init__(num_parallel_workers)
  3169. self.dataset_files = self._find_files(dataset_files)
  3170. self.dataset_files.sort()
  3171. self.num_shards = num_shards
  3172. self.shard_id = shard_id
  3173. schema_obj = None
  3174. if (schema is not None) and (not isinstance(schema, Schema)):
  3175. schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
  3176. self.schema = schema
  3177. self.columns_list = columns_list
  3178. self.num_samples = num_samples
  3179. self.cache = cache
  3180. if schema_obj is not None and num_samples is None:
  3181. self.num_samples = schema_obj.num_rows
  3182. if not isinstance(shuffle, (bool, Shuffle)):
  3183. raise TypeError("shuffle must be of type boolean or enum 'Shuffle'.")
  3184. if not isinstance(shuffle, Shuffle):
  3185. if shuffle:
  3186. self.shuffle_level = Shuffle.GLOBAL
  3187. self.shuffle_files = True
  3188. else:
  3189. self.shuffle_level = None
  3190. self.shuffle_files = False
  3191. else:
  3192. self.shuffle_level = shuffle
  3193. self.shuffle_files = True
3194. # The TFRecord dataset does not directly support a sampler. It provides sampling arguments
3195. # (shuffle, num_samples, num_shards, shard_id), and it DOES support sampling if there is a cache
3196. # somewhere above it in the pipeline. If there is no cache above it, this sampler is not used.
  3197. sampler_shuffle = self.shuffle_files
  3198. sampler = None
  3199. self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
  3200. non_mappable=True)
  3201. self.shard_equal_rows = shard_equal_rows
  3202. def get_args(self):
  3203. args = super().get_args()
  3204. args["dataset_files"] = self.dataset_files
  3205. if self.schema is not None:
  3206. if isinstance(self.schema, Schema):
  3207. self.schema.datasetType = 'TF'
  3208. if self.num_samples is not None:
  3209. self.schema.num_rows = self.num_samples
  3210. args["schema_json_string"] = self.schema.to_json()
  3211. else:
  3212. args["schema_file_path"] = self.schema
  3213. args["schema"] = self.schema
  3214. args["columns_list"] = self.columns_list
  3215. args["num_samples"] = self.num_samples
  3216. if self.shuffle_files is not None:
  3217. args["shuffle_files"] = self.shuffle_files
  3218. args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
  3219. args["shuffle"] = self.shuffle_level
  3220. args["num_shards"] = self.num_shards
  3221. args["shard_id"] = self.shard_id
  3222. args["shard_equal_rows"] = self.shard_equal_rows
  3223. args["cache"] = self.cache.cache_client if self.cache is not None else None
  3224. args["sampler"] = self.sampler
  3225. return args
  3226. def get_dataset_size(self, estimate=False):
  3227. """
  3228. Get the number of batches in an epoch.
  3229. Note:
3230. Because the TFRecord format does not store metadata, all files need to be traversed to obtain
3231. the total amount of data. Therefore, this API can be slow.
3232. Args:
3233. estimate (bool, optional): If True, use a fast estimation of the dataset size instead of a full scan (default=False).
  3234. Return:
  3235. Number, number of batches.
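Examples:
>>> # Illustrative sketch: a fast, possibly approximate count that avoids a full scan of the files.
>>> size = tfdataset.get_dataset_size(estimate=True)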
  3236. """
  3237. if self.dataset_size is None:
  3238. num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
  3239. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  3240. if self.num_samples is not None and self.num_samples < self.dataset_size:
  3241. self.dataset_size = self.num_samples
  3242. return self.dataset_size
  3243. def is_shuffled(self):
  3244. return self.shuffle_files
  3245. def is_sharded(self):
  3246. if self.num_shards is not None:
  3247. return self.num_shards > 1
  3248. return False
  3249. class ManifestDataset(MappableDataset):
  3250. """
  3251. A source dataset that reads images from a manifest file.
  3252. The generated dataset has two columns ['image', 'label'].
  3253. The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
  3254. otherwise.
  3255. The type of the image tensor is uint8. The label is a scalar uint64 tensor.
  3256. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
  3257. below shows what input arguments are allowed and their expected behavior.
  3258. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3259. :widths: 25 25 50
  3260. :header-rows: 1
  3261. * - Parameter 'sampler'
  3262. - Parameter 'shuffle'
  3263. - Expected Order Behavior
  3264. * - None
  3265. - None
  3266. - random order
  3267. * - None
  3268. - True
  3269. - random order
  3270. * - None
  3271. - False
  3272. - sequential order
  3273. * - Sampler object
  3274. - None
  3275. - order defined by sampler
  3276. * - Sampler object
  3277. - True
  3278. - not allowed
  3279. * - Sampler object
  3280. - False
  3281. - not allowed
  3282. Args:
  3283. dataset_file (str): File to be read.
3284. usage (str, optional): Acceptable usages include "train", "eval" and "inference" (default="train").
  3285. num_samples (int, optional): The number of images to be included in the dataset.
  3286. (default=None, all images).
  3287. num_parallel_workers (int, optional): Number of workers to read the data
  3288. (default=None, number set in the config).
  3289. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3290. order behavior shown in the table).
  3291. sampler (Sampler, optional): Object used to choose samples from the
  3292. dataset (default=None, expected order behavior shown in the table).
  3293. class_indexing (dict, optional): A str-to-int mapping from label name to index
  3294. (default=None, the folder names will be sorted alphabetically and each
  3295. class will be given a unique index starting from 0).
  3296. decode (bool, optional): decode the images after reading (default=False).
  3297. num_shards (int, optional): Number of shards that the dataset will be divided
  3298. into (default=None).
  3299. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3300. argument can only be specified when num_shards is also specified.
  3301. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  3302. The cache feature is under development and is not recommended.
  3303. Raises:
  3304. RuntimeError: If sampler and shuffle are specified at the same time.
  3305. RuntimeError: If sampler and sharding are specified at the same time.
  3306. RuntimeError: If num_shards is specified but shard_id is None.
  3307. RuntimeError: If shard_id is specified but num_shards is None.
  3308. RuntimeError: If class_indexing is not a dictionary.
  3309. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3310. Examples:
  3311. >>> import mindspore.dataset as ds
  3312. >>>
  3313. >>> dataset_file = "/path/to/manifest_file.manifest"
  3314. >>>
  3315. >>> # 1) Read all samples specified in manifest_file dataset with 8 threads for training
  3316. >>> manifest_dataset = ds.ManifestDataset(dataset_file, usage="train", num_parallel_workers=8)
  3317. >>>
  3318. >>> # 2) Read samples (specified in manifest_file.manifest) for shard 0
  3319. >>> # in a 2-way distributed training setup
  3320. >>> manifest_dataset = ds.ManifestDataset(dataset_file, num_shards=2, shard_id=0)
  3321. """
  3322. @check_manifestdataset
  3323. def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None,
  3324. shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None,
  3325. cache=None):
  3326. super().__init__(num_parallel_workers)
  3327. self.dataset_file = dataset_file
  3328. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3329. if class_indexing is not None and not isinstance(class_indexing, dict):
  3330. raise RuntimeError("class_indexing must be a dictionary.")
  3331. self.num_samples = num_samples
  3332. self.class_indexing = class_indexing
  3333. self.decode = decode
  3334. self.usage = usage
  3335. self.shuffle_level = shuffle
  3336. self.num_shards = num_shards
  3337. self.shard_id = shard_id
  3338. self.cache = cache
  3339. def get_args(self):
  3340. args = super().get_args()
  3341. args["dataset_file"] = self.dataset_file
  3342. args["usage"] = self.usage
  3343. args["num_samples"] = self.num_samples
  3344. args["shuffle"] = self.shuffle_level
  3345. args["sampler"] = self.sampler
  3346. args["class_indexing"] = self.class_indexing
  3347. args["decode"] = self.decode
  3348. args["num_shards"] = self.num_shards
  3349. args["shard_id"] = self.shard_id
  3350. args["cache"] = self.cache.cache_client if self.cache is not None else None
  3351. return args
  3352. def get_dataset_size(self):
  3353. """
  3354. Get the number of batches in an epoch.
  3355. Return:
  3356. Number, number of batches.
  3357. """
  3358. if self.dataset_size is None:
  3359. if self.class_indexing is None:
  3360. class_indexing = dict()
  3361. else:
  3362. class_indexing = self.class_indexing
  3363. num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
  3364. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  3365. rows_from_sampler = self._get_sampler_dataset_size()
  3366. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  3367. self.dataset_size = rows_from_sampler
  3368. return self.dataset_size
  3369. def num_classes(self):
  3370. """
  3371. Get the number of classes in a dataset.
  3372. Return:
  3373. Number, number of classes.
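Examples:
>>> # Illustrative sketch, reusing manifest_dataset from the class examples above.
>>> num_classes = manifest_dataset.num_classes()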
  3374. """
  3375. if self.class_indexing is None:
  3376. class_indexing = dict()
  3377. else:
  3378. class_indexing = self.class_indexing
  3379. return ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[1]
  3380. def get_class_indexing(self):
  3381. """
  3382. Get the class index.
  3383. Return:
  3384. Dict, A str-to-int mapping from label name to index.
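Examples:
>>> # Illustrative sketch, reusing manifest_dataset from the class examples above.
>>> label_to_index = manifest_dataset.get_class_indexing()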
  3385. """
  3386. if self.class_indexing is None:
  3387. class_indexing = dict()
  3388. else:
  3389. class_indexing = self.class_indexing
  3390. return ManifestOp.get_class_indexing(self.dataset_file, class_indexing, self.usage)
  3391. def is_shuffled(self):
  3392. if self.shuffle_level is None:
  3393. return True
  3394. return self.shuffle_level or self.sampler.is_shuffled()
  3395. def is_sharded(self):
  3396. if self.num_shards is not None:
  3397. return self.num_shards > 1
  3398. return self.sampler.is_sharded()
  3399. class Cifar10Dataset(MappableDataset):
  3400. """
  3401. A source dataset that reads cifar10 data.
  3402. The generated dataset has two columns ['image', 'label'].
  3403. The type of the image tensor is uint8. The label is a scalar uint32 tensor.
  3404. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
  3405. below shows what input arguments are allowed and their expected behavior.
  3406. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3407. :widths: 25 25 50
  3408. :header-rows: 1
  3409. * - Parameter 'sampler'
  3410. - Parameter 'shuffle'
  3411. - Expected Order Behavior
  3412. * - None
  3413. - None
  3414. - random order
  3415. * - None
  3416. - True
  3417. - random order
  3418. * - None
  3419. - False
  3420. - sequential order
  3421. * - Sampler object
  3422. - None
  3423. - order defined by sampler
  3424. * - Sampler object
  3425. - True
  3426. - not allowed
  3427. * - Sampler object
  3428. - False
  3429. - not allowed
  3430. Citation of Cifar10 dataset.
  3431. .. code-block::
  3432. @techreport{Krizhevsky09,
  3433. author = {Alex Krizhevsky},
  3434. title = {Learning multiple layers of features from tiny images},
  3435. institution = {},
  3436. year = {2009},
  3437. howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html},
  3438. description = {The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
  3439. with 6000 images per class. There are 50000 training images and 10000 test images.}
  3440. }
  3441. Args:
  3442. dataset_dir (str): Path to the root directory that contains the dataset.
3443. usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 50,000
  3444. train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000 samples.
  3445. (default=None, all samples)
  3446. num_samples (int, optional): The number of images to be included in the dataset.
  3447. (default=None, all images).
  3448. num_parallel_workers (int, optional): Number of workers to read the data
  3449. (default=None, number set in the config).
  3450. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3451. order behavior shown in the table).
  3452. sampler (Sampler, optional): Object used to choose samples from the
  3453. dataset (default=None, expected order behavior shown in the table).
  3454. num_shards (int, optional): Number of shards that the dataset will be divided
  3455. into (default=None).
  3456. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3457. argument can only be specified when num_shards is also specified.
  3458. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  3459. The cache feature is under development and is not recommended.
  3460. Raises:
  3461. RuntimeError: If sampler and shuffle are specified at the same time.
  3462. RuntimeError: If sampler and sharding are specified at the same time.
  3463. RuntimeError: If num_shards is specified but shard_id is None.
  3464. RuntimeError: If shard_id is specified but num_shards is None.
  3465. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3466. Examples:
  3467. >>> import mindspore.dataset as ds
  3468. >>>
  3469. >>> dataset_dir = "/path/to/cifar10_dataset_directory"
  3470. >>>
  3471. >>> # 1) Get all samples from CIFAR10 dataset in sequence
  3472. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir, shuffle=False)
  3473. >>>
  3474. >>> # 2) Randomly select 350 samples from CIFAR10 dataset
  3475. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir, num_samples=350, shuffle=True)
  3476. >>>
  3477. >>> # 3) Get samples from CIFAR10 dataset for shard 0 in a 2-way distributed training
  3478. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir, num_shards=2, shard_id=0)
  3479. >>>
  3480. >>> # In CIFAR10 dataset, each dictionary has keys "image" and "label"
  3481. """
  3482. @check_mnist_cifar_dataset
  3483. def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
  3484. shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
  3485. super().__init__(num_parallel_workers)
  3486. self.dataset_dir = dataset_dir
  3487. self.usage = usage
  3488. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3489. self.num_samples = num_samples
  3490. self.num_shards = num_shards
  3491. self.shard_id = shard_id
  3492. self.shuffle_level = shuffle
  3493. self.cache = cache
  3494. def get_args(self):
  3495. args = super().get_args()
  3496. args["dataset_dir"] = self.dataset_dir
  3497. args["usage"] = self.usage
  3498. args["num_samples"] = self.num_samples
  3499. args["sampler"] = self.sampler
  3500. args["num_shards"] = self.num_shards
  3501. args["shard_id"] = self.shard_id
  3502. args["shuffle"] = self.shuffle_level
  3503. args["cache"] = self.cache.cache_client if self.cache is not None else None
  3504. return args
  3505. def get_dataset_size(self):
  3506. """
  3507. Get the number of batches in an epoch.
  3508. Return:
  3509. Number, number of batches.
  3510. """
  3511. if self.dataset_size is None:
  3512. num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, True)
  3513. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  3514. rows_from_sampler = self._get_sampler_dataset_size()
  3515. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  3516. self.dataset_size = rows_from_sampler
  3517. return self.dataset_size
  3518. def is_shuffled(self):
  3519. if self.shuffle_level is None:
  3520. return True
  3521. return self.shuffle_level or self.sampler.is_shuffled()
  3522. def is_sharded(self):
  3523. if self.num_shards is not None:
  3524. return self.num_shards > 1
  3525. return self.sampler.is_sharded()
  3526. class Cifar100Dataset(MappableDataset):
  3527. """
  3528. A source dataset that reads cifar100 data.
  3529. The generated dataset has three columns ['image', 'coarse_label', 'fine_label'].
  3530. The type of the image tensor is uint8. The coarse and fine labels are each a scalar uint32 tensor.
  3531. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
  3532. below shows what input arguments are allowed and their expected behavior.
  3533. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3534. :widths: 25 25 50
  3535. :header-rows: 1
  3536. * - Parameter 'sampler'
  3537. - Parameter 'shuffle'
  3538. - Expected Order Behavior
  3539. * - None
  3540. - None
  3541. - random order
  3542. * - None
  3543. - True
  3544. - random order
  3545. * - None
  3546. - False
  3547. - sequential order
  3548. * - Sampler object
  3549. - None
  3550. - order defined by sampler
  3551. * - Sampler object
  3552. - True
  3553. - not allowed
  3554. * - Sampler object
  3555. - False
  3556. - not allowed
  3557. Citation of Cifar100 dataset.
  3558. .. code-block::
  3559. @techreport{Krizhevsky09,
  3560. author = {Alex Krizhevsky},
  3561. title = {Learning multiple layers of features from tiny images},
  3562. institution = {},
  3563. year = {2009},
  3564. howpublished = {http://www.cs.toronto.edu/~kriz/cifar.html},
  3565. description = {This dataset is just like the CIFAR-10, except it has 100 classes containing 600 images
  3566. each. There are 500 training images and 100 testing images per class. The 100 classes in
  3567. the CIFAR-100 are grouped into 20 superclasses. Each image comes with a "fine" label (the
  3568. class to which it belongs) and a "coarse" label (the superclass to which it belongs).}
  3569. }
  3570. Args:
  3571. dataset_dir (str): Path to the root directory that contains the dataset.
3572. usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 50,000
  3573. train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000 samples.
  3574. (default=None, all samples)
  3575. num_samples (int, optional): The number of images to be included in the dataset.
  3576. (default=None, all images).
  3577. num_parallel_workers (int, optional): Number of workers to read the data
  3578. (default=None, number set in the config).
  3579. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3580. order behavior shown in the table).
  3581. sampler (Sampler, optional): Object used to choose samples from the
  3582. dataset (default=None, expected order behavior shown in the table).
  3583. num_shards (int, optional): Number of shards that the dataset will be divided
  3584. into (default=None).
  3585. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3586. argument can only be specified when num_shards is also specified.
  3587. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  3588. The cache feature is under development and is not recommended.
  3589. Raises:
  3590. RuntimeError: If sampler and shuffle are specified at the same time.
  3591. RuntimeError: If sampler and sharding are specified at the same time.
  3592. RuntimeError: If num_shards is specified but shard_id is None.
  3593. RuntimeError: If shard_id is specified but num_shards is None.
  3594. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3595. Examples:
  3596. >>> import mindspore.dataset as ds
  3597. >>>
  3598. >>> dataset_dir = "/path/to/cifar100_dataset_directory"
  3599. >>>
  3600. >>> # 1) Get all samples from CIFAR100 dataset in sequence
  3601. >>> cifar100_dataset = ds.Cifar100Dataset(dataset_dir=dataset_dir, shuffle=False)
  3602. >>>
  3603. >>> # 2) Randomly select 350 samples from CIFAR100 dataset
  3604. >>> cifar100_dataset = ds.Cifar100Dataset(dataset_dir=dataset_dir, num_samples=350, shuffle=True)
  3605. >>>
  3606. >>> # In CIFAR100 dataset, each dictionary has 3 keys: "image", "fine_label" and "coarse_label"
  3607. """
  3608. @check_mnist_cifar_dataset
  3609. def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None,
  3610. shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
  3611. super().__init__(num_parallel_workers)
  3612. self.dataset_dir = dataset_dir
  3613. self.usage = usage
  3614. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3615. self.num_samples = num_samples
  3616. self.num_shards = num_shards
  3617. self.shard_id = shard_id
  3618. self.shuffle_level = shuffle
  3619. self.cache = cache
  3620. def get_args(self):
  3621. args = super().get_args()
  3622. args["dataset_dir"] = self.dataset_dir
  3623. args["usage"] = self.usage
  3624. args["num_samples"] = self.num_samples
  3625. args["sampler"] = self.sampler
  3626. args["num_shards"] = self.num_shards
  3627. args["shard_id"] = self.shard_id
  3628. args["shuffle"] = self.shuffle_level
  3629. args["cache"] = self.cache.cache_client if self.cache is not None else None
  3630. return args
  3631. def get_dataset_size(self):
  3632. """
  3633. Get the number of batches in an epoch.
  3634. Return:
  3635. Number, number of batches.
  3636. """
  3637. if self.dataset_size is None:
  3638. num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, False)
  3639. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  3640. rows_from_sampler = self._get_sampler_dataset_size()
  3641. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  3642. self.dataset_size = rows_from_sampler
  3643. return self.dataset_size
  3644. def is_shuffled(self):
  3645. if self.shuffle_level is None:
  3646. return True
  3647. return self.shuffle_level or self.sampler.is_shuffled()
  3648. def is_sharded(self):
  3649. if self.num_shards is not None:
  3650. return self.num_shards > 1
  3651. return self.sampler.is_sharded()
  3652. class RandomDataset(SourceDataset):
  3653. """
  3654. A source dataset that generates random data.
  3655. Args:
  3656. total_rows (int): Number of rows for the dataset to generate (default=None, number of rows is random)
  3657. schema (Union[str, Schema], optional): Path to the JSON schema file or schema object (default=None).
  3658. If the schema is not provided, the random dataset generates a random schema.
  3659. columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
  3660. num_samples (int): number of samples to draw from the total. (default=None, which means all rows)
  3661. num_parallel_workers (int, optional): Number of workers to read the data
  3662. (default=None, number set in the config).
  3663. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  3664. The cache feature is under development and is not recommended.
  3665. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
3666. (default=None).
  3667. num_shards (int, optional): Number of shards that the dataset will be divided
  3668. into (default=None).
  3669. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3670. argument can only be specified when num_shards is also specified.
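Examples:
>>> import mindspore.dataset as ds
>>> import mindspore.common.dtype as mstype
>>>
>>> # A minimal sketch: generate 10 random rows that follow a user-defined schema.
>>> # The column name 'data' is an illustrative placeholder.
>>> schema = ds.Schema()
>>> schema.add_column('data', de_type=mstype.int64, shape=[2])
>>> random_dataset = ds.RandomDataset(schema=schema, total_rows=10)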
  3671. """
  3672. @check_random_dataset
  3673. def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
  3674. cache=None, shuffle=None, num_shards=None, shard_id=None):
  3675. super().__init__(num_parallel_workers)
  3676. schema_obj = None
  3677. if (schema is not None) and (not isinstance(schema, Schema)):
  3678. schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
  3679. self.schema = schema
  3680. self.columns_list = columns_list
  3681. sampler = None
  3682. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True)
  3683. self.num_samples = num_samples
  3684. self.cache = cache
  3685. if schema_obj is not None and total_rows is None:
  3686. self.total_rows = schema_obj.num_rows
  3687. elif total_rows is None:
  3688. self.total_rows = 0
  3689. else:
  3690. self.total_rows = total_rows
  3691. self.num_shards = num_shards
  3692. self.shard_id = shard_id
  3693. self.shuffle_level = shuffle
  3694. def get_args(self):
  3695. args = super().get_args()
  3696. if self.schema is not None:
  3697. if isinstance(self.schema, Schema):
  3698. self.schema.datasetType = 'Random'
  3699. if self.total_rows is not None:
  3700. self.schema.num_rows = self.total_rows
  3701. args["schema_json_string"] = self.schema.to_json()
  3702. else:
  3703. args["schema_file_path"] = self.schema
  3704. args["schema"] = self.schema
  3705. args["columns_list"] = self.columns_list
  3706. args["num_samples"] = self.num_samples
  3707. args["total_rows"] = self.total_rows
  3708. args["cache"] = self.cache.cache_client if self.cache is not None else None
  3709. args["sampler"] = self.sampler
  3710. return args
  3711. def get_dataset_size(self):
  3712. """
  3713. Get the number of batches in an epoch.
  3714. Return:
  3715. Number, number of batches.
  3716. """
  3717. if self.dataset_size is None:
3718. num_rows = self.total_rows  # rows configured in __init__ (schema numRows, the total_rows argument, or 0 if unknown)
  3719. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  3720. rows_from_sampler = self._get_sampler_dataset_size()
  3721. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  3722. self.dataset_size = rows_from_sampler
  3723. return self.dataset_size
  3724. def is_shuffled(self):
  3725. if self.shuffle_level is None:
  3726. return True
  3727. return self.shuffle_level or self.sampler.is_shuffled()
  3728. def is_sharded(self):
  3729. if self.num_shards is not None:
  3730. return self.num_shards > 1
  3731. return self.sampler.is_sharded()
  3732. class Schema:
  3733. """
  3734. Class to represent a schema of a dataset.
  3735. Args:
  3736. schema_file(str): Path of schema file (default=None).
  3737. Return:
  3738. Schema object, schema info about dataset.
  3739. Raises:
  3740. RuntimeError: If schema file failed to load.
  3741. Example:
  3742. >>> import mindspore.dataset as ds
  3743. >>> import mindspore.common.dtype as mstype
  3744. >>>
  3745. >>> # Create schema; specify column name, mindspore.dtype and shape of the column
  3746. >>> schema = ds.Schema()
3747. >>> schema.add_column('col1', de_type=mstype.int64, shape=[2])
  3748. """
  3749. def __init__(self, schema_file=None):
  3750. self.num_rows = None
  3751. if schema_file is None:
  3752. self.columns = []
  3753. self.dataset_type = ''
  3754. else:
  3755. if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
  3756. raise ValueError("The file %s does not exist or permission denied!" % schema_file)
  3757. try:
  3758. with open(schema_file, 'r') as load_f:
  3759. json_obj = json.load(load_f)
  3760. except json.decoder.JSONDecodeError:
  3761. raise RuntimeError("Schema file failed to load.")
  3762. except UnicodeDecodeError:
  3763. raise RuntimeError("Schema file failed to decode.")
  3764. except Exception:
  3765. raise RuntimeError("Schema file failed to open.")
  3766. self.from_json(json_obj)
  3767. @check_add_column
  3768. def add_column(self, name, de_type, shape=None):
  3769. """
  3770. Add new column to the schema.
  3771. Args:
  3772. name (str): Name of the column.
  3773. de_type (str): Data type of the column.
  3774. shape (list[int], optional): Shape of the column
  3775. (default=None, [-1] which is an unknown shape of rank 1).
  3776. Raises:
  3777. ValueError: If column type is unknown.
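Examples:
>>> # Illustrative sketch: omitting shape gives a rank-1 column with unknown shape, as noted above.
>>> schema.add_column('label', de_type=mstype.int32)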
  3778. """
  3779. new_column = dict()
  3780. new_column["name"] = name
  3781. if isinstance(de_type, typing.Type):
  3782. de_type = mstype_to_detype(de_type)
  3783. new_column["type"] = str(de_type)
  3784. else:
  3785. new_column["type"] = str(DataType(de_type))
  3786. if shape is not None:
  3787. new_column["shape"] = shape
  3788. new_column["rank"] = len(shape)
  3789. else:
  3790. new_column["rank"] = 1
  3791. self.columns.append(new_column)
  3792. def to_json(self):
  3793. """
  3794. Get a JSON string of the schema.
  3795. Returns:
  3796. Str, JSON string of the schema.
  3797. """
  3798. json_file = dict()
  3799. json_file["columns"] = self.columns
  3800. if self.dataset_type:
  3801. json_file["datasetType"] = self.dataset_type
  3802. if self.num_rows:
  3803. json_file["numRows"] = self.num_rows
  3804. return json.dumps(json_file, indent=2)
  3805. def parse_columns(self, columns):
  3806. """
3807. Parse the columns and add them to the schema.
  3808. Args:
  3809. columns (Union[dict, list[dict]]): Dataset attribute information, decoded from schema file.
3810. - list[dict]: each dict must contain the keys 'name' and 'type'; 'shape' is optional.
3811. - dict: columns.keys() are the column names, columns.values() are dicts that contain 'type' and, optionally, 'shape'.
  3812. Raises:
  3813. RuntimeError: If failed to parse columns.
  3814. RuntimeError: If unknown items in columns.
  3815. RuntimeError: If column's name field is missing.
  3816. RuntimeError: If column's type field is missing.
  3817. Example:
  3818. >>> schema = Schema()
  3819. >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
  3820. >>> {'name': 'label', 'type': 'int8', 'shape': [1]}]
  3821. >>> schema.parse_columns(columns1)
  3822. >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
  3823. >>> schema.parse_columns(columns2)
  3824. """
  3825. self.columns = []
  3826. if isinstance(columns, list):
  3827. for column in columns:
  3828. try:
  3829. name = column.pop("name")
  3830. except KeyError:
  3831. raise RuntimeError("Column's name is missing")
  3832. try:
  3833. de_type = column.pop("type")
  3834. except KeyError:
  3835. raise RuntimeError("Column' type is missing")
  3836. shape = column.pop("shape", None)
  3837. column.pop("t_impl", None)
  3838. column.pop("rank", None)
  3839. if column:
  3840. raise RuntimeError("Unknown field {}".format(",".join(column.keys())))
  3841. self.add_column(name, de_type, shape)
  3842. elif isinstance(columns, dict):
  3843. for key, value in columns.items():
  3844. name = key
  3845. try:
  3846. de_type = value.pop("type")
  3847. except KeyError:
  3848. raise RuntimeError("Column' type is missing")
  3849. shape = value.pop("shape", None)
  3850. value.pop("t_impl", None)
  3851. value.pop("rank", None)
  3852. if value:
  3853. raise RuntimeError("Unknown field {}".format(",".join(value.keys())))
  3854. self.add_column(name, de_type, shape)
  3855. else:
  3856. raise RuntimeError("columns must be dict or list, columns contain name, type, shape(optional).")
  3857. def from_json(self, json_obj):
  3858. """
3859. Load the schema from a parsed JSON object.
3860. Args:
3861. json_obj (dict): Parsed JSON object decoded from the schema file.
3862. Raises:
3863. RuntimeError: If there is an unknown item in the object.
3864. RuntimeError: If the dataset type is missing in the object.
3865. RuntimeError: If columns are missing in the object.
  3866. """
  3867. if not isinstance(json_obj, dict) or json_obj is None:
  3868. raise ValueError("Expected non-empty dict.")
  3869. for k, v in json_obj.items():
  3870. if k == "datasetType":
  3871. self.dataset_type = v
  3872. elif k == "numRows":
  3873. self.num_rows = v
  3874. elif k == "columns":
  3875. self.parse_columns(v)
  3876. else:
  3877. raise RuntimeError("Unknown field %s" % k)
  3878. if self.columns is None:
  3879. raise RuntimeError("Columns are missing.")
  3880. if self.num_rows is not None:
  3881. if not isinstance(self.num_rows, int) or self.num_rows <= 0:
  3882. raise ValueError("numRows must be greater than 0")
  3883. def __str__(self):
  3884. return self.to_json()
  3885. class VOCDataset(MappableDataset):
  3886. """
  3887. A source dataset for reading and parsing VOC dataset.
  3888. The generated dataset has multiple columns :
  3889. - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
  3890. ['difficult', dtype=uint32], ['truncate', dtype=uint32]].
  3891. - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
  3892. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
  3893. below shows what input arguments are allowed and their expected behavior.
  3894. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3895. :widths: 25 25 50
  3896. :header-rows: 1
  3897. * - Parameter 'sampler'
  3898. - Parameter 'shuffle'
  3899. - Expected Order Behavior
  3900. * - None
  3901. - None
  3902. - random order
  3903. * - None
  3904. - True
  3905. - random order
  3906. * - None
  3907. - False
  3908. - sequential order
  3909. * - Sampler object
  3910. - None
  3911. - order defined by sampler
  3912. * - Sampler object
  3913. - True
  3914. - not allowed
  3915. * - Sampler object
  3916. - False
  3917. - not allowed
  3918. Citation of VOC dataset.
  3919. .. code-block::
  3920. @article{Everingham10,
  3921. author = {Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.},
  3922. title = {The Pascal Visual Object Classes (VOC) Challenge},
  3923. journal = {International Journal of Computer Vision},
  3924. volume = {88},
  3925. year = {2010},
  3926. number = {2},
  3927. month = {jun},
  3928. pages = {303--338},
  3929. biburl = {http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.html#bibtex},
  3930. howpublished = {http://host.robots.ox.ac.uk/pascal/VOC/voc{year}/index.html},
  3931. description = {The PASCAL Visual Object Classes (VOC) challenge is a benchmark in visual
  3932. object category recognition and detection, providing the vision and machine
  3933. learning communities with a standard dataset of images and annotation, and
  3934. standard evaluation procedures.}
  3935. }
  3936. Args:
  3937. dataset_dir (str): Path to the root directory that contains the dataset.
3938. task (str): Set the task type for reading VOC data; only "Segmentation" and "Detection" are supported
3939. (default="Segmentation").
  3940. usage (str): The type of data list text file to be read (default="train").
  3941. class_indexing (dict, optional): A str-to-int mapping from label name to index, only valid in
  3942. "Detection" task (default=None, the folder names will be sorted alphabetically and each
  3943. class will be given a unique index starting from 0).
  3944. num_samples (int, optional): The number of images to be included in the dataset
  3945. (default=None, all images).
  3946. num_parallel_workers (int, optional): Number of workers to read the data
  3947. (default=None, number set in the config).
  3948. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3949. order behavior shown in the table).
  3950. decode (bool, optional): Decode the images after reading (default=False).
  3951. sampler (Sampler, optional): Object used to choose samples from the dataset
  3952. (default=None, expected order behavior shown in the table).
  3953. num_shards (int, optional): Number of shards that the dataset will be divided
  3954. into (default=None).
  3955. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3956. argument can only be specified when num_shards is also specified.
  3957. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  3958. The cache feature is under development and is not recommended.
  3959. Raises:
3960. RuntimeError: If the XML of an annotation file is in an invalid format.
3961. RuntimeError: If the XML of an annotation file is missing the attribute "object".
3962. RuntimeError: If the XML of an annotation file is missing the attribute "bndbox".
  3963. RuntimeError: If sampler and shuffle are specified at the same time.
  3964. RuntimeError: If sampler and sharding are specified at the same time.
  3965. RuntimeError: If num_shards is specified but shard_id is None.
  3966. RuntimeError: If shard_id is specified but num_shards is None.
3967. ValueError: If task is neither 'Segmentation' nor 'Detection'.
3968. ValueError: If task is 'Segmentation' but class_indexing is not None.
3969. ValueError: If the txt file related to usage does not exist.
  3970. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3971. Examples:
  3972. >>> import mindspore.dataset as ds
  3973. >>>
  3974. >>> dataset_dir = "/path/to/voc_dataset_directory"
  3975. >>>
3976. >>> # 1) Read VOC data for segmentation training
  3977. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", usage="train")
  3978. >>>
  3979. >>> # 2) Read VOC data for detection training
  3980. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train")
  3981. >>>
  3982. >>> # 3) Read all VOC dataset samples in dataset_dir with 8 threads in random order
  3983. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", num_parallel_workers=8)
  3984. >>>
  3985. >>> # 4) Read then decode all VOC dataset samples in dataset_dir in sequence
  3986. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", decode=True, shuffle=False)
  3987. >>>
  3988. >>> # In VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
  3989. >>> # In VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
  3990. """
  3991. @check_vocdataset
  3992. def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None,
  3993. num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None,
  3994. cache=None):
  3995. super().__init__(num_parallel_workers)
  3996. self.dataset_dir = dataset_dir
  3997. self.task = task
  3998. self.usage = usage
  3999. self.class_indexing = class_indexing
  4000. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  4001. self.num_samples = num_samples
  4002. self.decode = decode
  4003. self.shuffle_level = shuffle
  4004. self.num_shards = num_shards
  4005. self.shard_id = shard_id
  4006. self.cache = cache
  4007. def get_args(self):
  4008. args = super().get_args()
  4009. args["dataset_dir"] = self.dataset_dir
  4010. args["task"] = self.task
  4011. args["usage"] = self.usage
  4012. args["class_indexing"] = self.class_indexing
  4013. args["num_samples"] = self.num_samples
  4014. args["sampler"] = self.sampler
  4015. args["decode"] = self.decode
  4016. args["shuffle"] = self.shuffle_level
  4017. args["num_shards"] = self.num_shards
  4018. args["shard_id"] = self.shard_id
  4019. args["cache"] = self.cache.cache_client if self.cache is not None else None
  4020. return args
  4021. def get_dataset_size(self):
  4022. """
  4023. Get the number of batches in an epoch.
  4024. Return:
  4025. Number, number of batches.
  4026. """
  4027. if self.dataset_size is None:
  4028. if self.num_samples is None:
  4029. num_samples = 0
  4030. else:
  4031. num_samples = self.num_samples
  4032. if self.class_indexing is None:
  4033. class_indexing = dict()
  4034. else:
  4035. class_indexing = self.class_indexing
  4036. num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.usage, class_indexing, num_samples)
  4037. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  4038. rows_from_sampler = self._get_sampler_dataset_size()
  4039. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  4040. self.dataset_size = rows_from_sampler
  4041. return self.dataset_size
  4042. def get_class_indexing(self):
  4043. """
  4044. Get the class index.
  4045. Return:
  4046. Dict, A str-to-int mapping from label name to index.
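Examples:
>>> # Illustrative sketch (valid for the "Detection" task only), reusing voc_dataset from the examples above.
>>> label_to_index = voc_dataset.get_class_indexing()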
  4047. """
  4048. if self.task != "Detection":
4049. raise NotImplementedError("Only 'Detection' task supports get_class_indexing.")
  4050. if self.class_indexing is None:
  4051. class_indexing = dict()
  4052. else:
  4053. class_indexing = self.class_indexing
  4054. return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.usage, class_indexing)
  4055. def is_shuffled(self):
  4056. if self.shuffle_level is None:
  4057. return True
  4058. return self.shuffle_level or self.sampler.is_shuffled()
  4059. def is_sharded(self):
  4060. if self.num_shards is not None:
  4061. return self.num_shards > 1
  4062. return self.sampler.is_sharded()
  4063. class CocoDataset(MappableDataset):
  4064. """
  4065. A source dataset for reading and parsing COCO dataset.
4066. CocoDataset supports four kinds of tasks: 2017 Train/Val/Test Detection, Keypoints, Stuff and Panoptic.
4067. The generated dataset has multiple columns:
  4068. - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  4069. ['iscrowd', dtype=uint32]].
  4070. - task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd',dtype=uint32]].
  4071. - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
  4072. ['num_keypoints', dtype=uint32]].
  4073. - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
  4074. ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
  4075. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. CocoDataset doesn't support
  4076. PKSampler. The table below shows what input arguments are allowed and their expected behavior.
  4077. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  4078. :widths: 25 25 50
  4079. :header-rows: 1
  4080. * - Parameter 'sampler'
  4081. - Parameter 'shuffle'
  4082. - Expected Order Behavior
  4083. * - None
  4084. - None
  4085. - random order
  4086. * - None
  4087. - True
  4088. - random order
  4089. * - None
  4090. - False
  4091. - sequential order
  4092. * - Sampler object
  4093. - None
  4094. - order defined by sampler
  4095. * - Sampler object
  4096. - True
  4097. - not allowed
  4098. * - Sampler object
  4099. - False
  4100. - not allowed
  4101. Citation of Coco dataset.
  4102. .. code-block::
  4103. @article{DBLP:journals/corr/LinMBHPRDZ14,
  4104. author = {Tsung{-}Yi Lin and Michael Maire and Serge J. Belongie and
  4105. Lubomir D. Bourdev and Ross B. Girshick and James Hays and
  4106. Pietro Perona and Deva Ramanan and Piotr Doll{\'{a}}r and C. Lawrence Zitnick},
  4107. title = {Microsoft {COCO:} Common Objects in Context},
  4108. journal = {CoRR},
  4109. volume = {abs/1405.0312},
  4110. year = {2014},
  4111. url = {http://arxiv.org/abs/1405.0312},
  4112. archivePrefix = {arXiv},
  4113. eprint = {1405.0312},
  4114. timestamp = {Mon, 13 Aug 2018 16:48:13 +0200},
  4115. biburl = {https://dblp.org/rec/journals/corr/LinMBHPRDZ14.bib},
  4116. bibsource = {dblp computer science bibliography, https://dblp.org},
  4117. description = {COCO is a large-scale object detection, segmentation, and captioning dataset.
  4118. It contains 91 common object categories with 82 of them having more than 5,000
  4119. labeled instances. In contrast to the popular ImageNet dataset, COCO has fewer
  4120. categories but more instances per category.}
  4121. }
  4122. Args:
  4123. dataset_dir (str): Path to the root directory that contains the dataset.
  4124. annotation_file (str): Path to the annotation JSON.
  4125. task (str): Set the task type for reading COCO data. Supported task types:
  4126. 'Detection', 'Stuff', 'Panoptic' and 'Keypoint' (default='Detection').
  4127. num_samples (int, optional): The number of images to be included in the dataset
  4128. (default=None, all images).
  4129. num_parallel_workers (int, optional): Number of workers to read the data
  4130. (default=None, number set in the configuration file).
  4131. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  4132. order behavior shown in the table).
  4133. decode (bool, optional): Decode the images after reading (default=False).
  4134. sampler (Sampler, optional): Object used to choose samples from the dataset
  4135. (default=None, expected order behavior shown in the table).
  4136. num_shards (int, optional): Number of shards that the dataset will be divided
  4137. into (default=None).
  4138. shard_id (int, optional): The shard ID within num_shards (default=None). This
  4139. argument can only be specified when num_shards is also specified.
  4140. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  4141. The cache feature is under development and is not recommended.
  4142. Raises:
  4143. RuntimeError: If sampler and shuffle are specified at the same time.
  4144. RuntimeError: If sampler and sharding are specified at the same time.
  4145. RuntimeError: If num_shards is specified but shard_id is None.
  4146. RuntimeError: If shard_id is specified but num_shards is None.
4147. RuntimeError: If parsing the JSON file failed.
4148. ValueError: If task is not in ['Detection', 'Stuff', 'Panoptic', 'Keypoint'].
4149. ValueError: If annotation_file does not exist.
4150. ValueError: If dataset_dir does not exist.
  4151. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  4152. Examples:
  4153. >>> import mindspore.dataset as ds
  4154. >>>
  4155. >>> dataset_dir = "/path/to/coco_dataset_directory/image_folder"
  4156. >>> annotation_file = "/path/to/coco_dataset_directory/annotation_folder/annotation.json"
  4157. >>>
  4158. >>> # 1) Read COCO data for Detection task
  4159. >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Detection')
  4160. >>>
  4161. >>> # 2) Read COCO data for Stuff task
  4162. >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Stuff')
  4163. >>>
  4164. >>> # 3) Read COCO data for Panoptic task
  4165. >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Panoptic')
  4166. >>>
  4167. >>> # 4) Read COCO data for Keypoint task
  4168. >>> coco_dataset = ds.CocoDataset(dataset_dir, annotation_file=annotation_file, task='Keypoint')
  4169. >>>
  4170. >>> # In COCO dataset, each dictionary has keys "image" and "annotation"
  4171. """
  4172. @check_cocodataset
  4173. def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None,
  4174. shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
  4175. super().__init__(num_parallel_workers)
  4176. self.dataset_dir = dataset_dir
  4177. self.annotation_file = annotation_file
  4178. self.task = task
  4179. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  4180. self.num_samples = num_samples
  4181. self.decode = decode
  4182. self.shuffle_level = shuffle
  4183. self.num_shards = num_shards
  4184. self.shard_id = shard_id
  4185. self.cache = cache
  4186. def get_args(self):
  4187. args = super().get_args()
  4188. args["dataset_dir"] = self.dataset_dir
  4189. args["annotation_file"] = self.annotation_file
  4190. args["task"] = self.task
  4191. args["num_samples"] = self.num_samples
  4192. args["sampler"] = self.sampler
  4193. args["decode"] = self.decode
  4194. args["shuffle"] = self.shuffle_level
  4195. args["num_shards"] = self.num_shards
  4196. args["shard_id"] = self.shard_id
  4197. args["cache"] = self.cache.cache_client if self.cache is not None else None
  4198. return args
  4199. def get_dataset_size(self):
  4200. """
  4201. Get the number of batches in an epoch.
  4202. Return:
  4203. Number, number of batches.
  4204. """
  4205. if self.dataset_size is None:
  4206. num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
  4207. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  4208. rows_from_sampler = self._get_sampler_dataset_size()
  4209. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  4210. self.dataset_size = rows_from_sampler
  4211. return self.dataset_size
  4212. def get_class_indexing(self):
  4213. """
  4214. Get the class index.
  4215. Return:
  4216. Dict, A str-to-int mapping from label name to index.
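Examples:
>>> # Illustrative sketch (valid for the 'Detection' and 'Panoptic' tasks only), reusing coco_dataset from the examples above.
>>> label_to_index = coco_dataset.get_class_indexing()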
  4217. """
  4218. if self.task not in {"Detection", "Panoptic"}:
  4219. raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.")
  4220. class_index = CocoOp.get_class_indexing(self.dataset_dir, self.annotation_file, self.task)
  4221. return dict(class_index)
  4222. def is_shuffled(self):
  4223. if self.shuffle_level is None:
  4224. return True
  4225. return self.shuffle_level or self.sampler.is_shuffled()
  4226. def is_sharded(self):
  4227. if self.num_shards is not None:
  4228. return self.num_shards > 1
  4229. return self.sampler.is_sharded()
  4230. class CelebADataset(MappableDataset):
  4231. """
  4232. A source dataset for reading and parsing CelebA dataset. Currently supported: list_attr_celeba.txt only.
  4233. Note:
  4234. The generated dataset has two columns ['image', 'attr'].
  4235. The type of the image tensor is uint8. The attribute tensor is uint32 and one hot type.
  4236. Citation of CelebA dataset.
  4237. .. code-block::
  4238. @article{DBLP:journals/corr/LiuLWT14,
  4239. author = {Ziwei Liu and Ping Luo and Xiaogang Wang and Xiaoou Tang},
  4240. title = {Deep Learning Face Attributes in the Wild},
  4241. journal = {CoRR},
  4242. volume = {abs/1411.7766},
  4243. year = {2014},
  4244. url = {http://arxiv.org/abs/1411.7766},
  4245. archivePrefix = {arXiv},
  4246. eprint = {1411.7766},
  4247. timestamp = {Tue, 10 Dec 2019 15:37:26 +0100},
  4248. biburl = {https://dblp.org/rec/journals/corr/LiuLWT14.bib},
  4249. bibsource = {dblp computer science bibliography, https://dblp.org},
  4250. howpublished = {http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html},
  4251. description = {CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset
  4252. with more than 200K celebrity images, each with 40 attribute annotations.
  4253. The images in this dataset cover large pose variations and background clutter.
  4254. CelebA has large diversities, large quantities, and rich annotations, including
  4255. * 10,177 number of identities,
  4256. * 202,599 number of face images, and
  4257. * 5 landmark locations, 40 binary attributes annotations per image.
  4258. The dataset can be employed as the training and test sets for the following computer
  4259. vision tasks: face attribute recognition, face detection, landmark (or facial part)
  4260. localization, and face editing & synthesis.}
  4261. }
  4262. Args:
  4263. dataset_dir (str): Path to the root directory that contains the dataset.
  4264. num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config).
  4265. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None).
4266. usage (str): One of 'all', 'train', 'valid' or 'test' (default='all').
  4267. sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
  4268. decode (bool, optional): decode the images after reading (default=False).
  4269. extensions (list[str], optional): List of file extensions to be
  4270. included in the dataset (default=None).
  4271. num_samples (int, optional): The number of images to be included in the dataset.
  4272. (default=None, all images).
  4273. num_shards (int, optional): Number of shards that the dataset will be divided
  4274. into (default=None).
  4275. shard_id (int, optional): The shard ID within num_shards (default=None). This
  4276. argument can only be specified when num_shards is also specified.
  4277. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  4278. The cache feature is under development and is not recommended.
  4279. Examples:
  4280. >>> import mindspore.dataset as ds
  4281. >>>
  4282. >>> dataset_dir = "/path/to/celeba_directory"
  4283. >>> dataset = ds.CelebADataset(dataset_dir=dataset_dir, usage='train')
  4284. """
  4285. @check_celebadataset
  4286. def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False,
  4287. extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None):
  4288. super().__init__(num_parallel_workers)
  4289. self.dataset_dir = dataset_dir
  4290. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  4291. self.num_parallel_workers = num_parallel_workers
  4292. self.decode = decode
  4293. self.extensions = extensions
  4294. self.num_samples = num_samples
  4295. self.usage = usage
  4296. self.num_shards = num_shards
  4297. self.shard_id = shard_id
  4298. self.shuffle_level = shuffle
  4299. self.cache = cache
  4300. if usage != "all":
  4301. dir = os.path.realpath(self.dataset_dir)
  4302. partition_file = os.path.join(dir, "list_eval_partition.txt")
4303. if not os.path.exists(partition_file):
  4304. raise RuntimeError("Partition file can not be found when usage is not 'all'.")
  4305. def get_args(self):
  4306. args = super().get_args()
  4307. args["dataset_dir"] = self.dataset_dir
  4308. args["sampler"] = self.sampler
  4309. args["shuffle"] = self.shuffle_level
  4310. args["decode"] = self.decode
  4311. args["extensions"] = self.extensions
  4312. args["num_samples"] = self.num_samples
  4313. args["usage"] = self.usage
  4314. args["num_shards"] = self.num_shards
  4315. args["shard_id"] = self.shard_id
  4316. args["cache"] = self.cache.cache_client if self.cache is not None else None
  4317. return args
  4318. def get_dataset_size(self):
  4319. """
  4320. Get the number of batches in an epoch.
  4321. Return:
  4322. Number, number of batches.
  4323. """
  4324. if self.dataset_size is None:
  4325. dir = os.path.realpath(self.dataset_dir)
  4326. attr_file = os.path.join(dir, "list_attr_celeba.txt")
  4327. num_rows = ''
  4328. try:
  4329. with open(attr_file, 'r') as f:
  4330. num_rows = int(f.readline())
  4331. except FileNotFoundError:
  4332. raise RuntimeError("attr file can not be found.")
  4333. except BaseException:
  4334. raise RuntimeError("Get dataset size failed from attribution file.")
  4335. if self.usage != 'all':
  4336. partition_file = os.path.join(dir, "list_eval_partition.txt")
  4337. usage_type = 0
  4338. partition_num = 0
  4339. if self.usage == "train":
  4340. usage_type = 0
  4341. elif self.usage == "valid":
  4342. usage_type = 1
  4343. elif self.usage == "test":
  4344. usage_type = 2
  4345. try:
  4346. with open(partition_file, 'r') as f:
  4347. for line in f.readlines():
  4348. split_line = line.split(' ')
  4349. if int(split_line[1]) == usage_type:
  4350. partition_num += 1
  4351. except FileNotFoundError:
  4352. raise RuntimeError("Partition file can not be found")
  4353. if partition_num < num_rows:
  4354. num_rows = partition_num
  4355. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  4356. if self.num_samples is not None and self.num_samples < self.dataset_size:
  4357. self.dataset_size = self.num_samples
  4358. rows_from_sampler = self._get_sampler_dataset_size()
  4359. if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
  4360. self.dataset_size = rows_from_sampler
  4361. return self.dataset_size
  4362. def is_shuffled(self):
  4363. if self.shuffle_level is None:
  4364. return True
  4365. return self.shuffle_level or self.sampler.is_shuffled()
  4366. def is_sharded(self):
  4367. if self.num_shards is not None:
  4368. return self.num_shards > 1
  4369. return self.sampler.is_sharded()
  4370. class CLUEDataset(SourceDataset):
  4371. """
  4372. A source dataset that reads and parses CLUE datasets.
  4373. CLUE, the Chinese Language Understanding Evaluation Benchmark, is a collection of datasets, baselines,
  4374. pre-trained models, corpus and leaderboard. Supported CLUE classification tasks: 'AFQMC', 'TNEWS', 'IFLYTEK',
  4375. 'CMNLI', 'WSC' and 'CSL'.
  4376. Citation of CLUE dataset.
  4377. .. code-block::
  4378. @article{CLUEbenchmark,
  4379. title = {CLUE: A Chinese Language Understanding Evaluation Benchmark},
  4380. author = {Liang Xu, Xuanwei Zhang, Lu Li, Hai Hu, Chenjie Cao, Weitang Liu, Junyi Li, Yudong Li,
  4381. Kai Sun, Yechen Xu, Yiming Cui, Cong Yu, Qianqian Dong, Yin Tian, Dian Yu, Bo Shi, Jun Zeng,
  4382. Rongzhao Wang, Weijian Xie, Yanting Li, Yina Patterson, Zuoyu Tian, Yiwen Zhang, He Zhou,
  4383. Shaoweihua Liu, Qipeng Zhao, Cong Yue, Xinrui Zhang, Zhengliang Yang, Zhenzhong Lan},
  4384. journal = {arXiv preprint arXiv:2004.05986},
  4385. year = {2020},
  4386. howpublished = {https://github.com/CLUEbenchmark/CLUE},
  4387. description = {CLUE, a Chinese Language Understanding Evaluation benchmark. It contains eight different
  4388. tasks, including single-sentence classification, sentence pair classification, and machine
  4389. reading comprehension.}
  4390. }
  4391. Args:
  4392. dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for
  4393. a pattern of files. The list will be sorted in a lexicographical order.
4394. task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'
4395. (default='AFQMC').
4396. usage (str, optional): Specify the 'train', 'test' or 'eval' part of the dataset (default='train').
  4397. num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
  4398. num_parallel_workers (int, optional): Number of workers to read the data
  4399. (default=None, number set in the config).
  4400. shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
  4401. (default=Shuffle.GLOBAL).
  4402. If shuffle is False, no shuffling will be performed;
4403. If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
  4404. Otherwise, there are two levels of shuffling:
  4405. - Shuffle.GLOBAL: Shuffle both the files and samples.
  4406. - Shuffle.FILES: Shuffle files only.
  4407. num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
  4408. shard_id (int, optional): The shard ID within num_shards (default=None). This
  4409. argument can only be specified when num_shards is also specified.
  4410. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  4411. The cache feature is under development and is not recommended.
  4412. Examples:
  4413. >>> import mindspore.dataset as ds
  4414. >>>
4415. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more CLUE data files
  4416. >>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train')
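>>> # A further, hedged sketch (not in the original example): shuffle at the file level only
>>> # and read the 'test' part of the AFQMC task; the shuffle enum is assumed to be exposed
>>> # as ds.Shuffle.FILES.
>>> dataset_test = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='test',
...                               shuffle=ds.Shuffle.FILES)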
  4417. """
  4418. @check_cluedataset
  4419. def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
  4420. num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
  4421. super().__init__(num_parallel_workers)
  4422. self.dataset_files = self._find_files(dataset_files)
  4423. self.dataset_files.sort()
  4424. self.num_samples = num_samples
  4425. self.task_dict = {
  4426. 'AFQMC': {
  4427. 'train': {
  4428. 'sentence1': 'sentence1',
  4429. 'sentence2': 'sentence2',
  4430. 'label': 'label'
  4431. },
  4432. 'test': {
  4433. 'id': 'id',
  4434. 'sentence1': 'sentence1',
  4435. 'sentence2': 'sentence2'
  4436. },
  4437. 'eval': {
  4438. 'sentence1': 'sentence1',
  4439. 'sentence2': 'sentence2',
  4440. 'label': 'label'
  4441. }
  4442. },
  4443. 'CMNLI': {
  4444. 'train': {
  4445. 'sentence1': 'sentence1',
  4446. 'sentence2': 'sentence2',
  4447. 'label': 'label'
  4448. },
  4449. 'test': {
  4450. 'id': 'id',
  4451. 'sentence1': 'sentence1',
  4452. 'sentence2': 'sentence2'
  4453. },
  4454. 'eval': {
  4455. 'sentence1': 'sentence1',
  4456. 'sentence2': 'sentence2',
  4457. 'label': 'label'
  4458. }
  4459. },
  4460. 'CSL': {
  4461. 'train': {
  4462. 'id': 'id',
  4463. 'abst': 'abst',
  4464. 'keyword': 'keyword',
  4465. 'label': 'label'
  4466. },
  4467. 'test': {
  4468. 'id': 'id',
  4469. 'abst': 'abst',
  4470. 'keyword': 'keyword'
  4471. },
  4472. 'eval': {
  4473. 'id': 'id',
  4474. 'abst': 'abst',
  4475. 'keyword': 'keyword',
  4476. 'label': 'label'
  4477. }
  4478. },
  4479. 'IFLYTEK': {
  4480. 'train': {
  4481. 'label': 'label',
  4482. 'label_des': 'label_des',
  4483. 'sentence': 'sentence'
  4484. },
  4485. 'test': {
  4486. 'id': 'id',
  4487. 'sentence': 'sentence',
  4488. },
  4489. 'eval': {
  4490. 'label': 'label',
  4491. 'label_des': 'label_des',
  4492. 'sentence': 'sentence'
  4493. }
  4494. },
  4495. 'TNEWS': {
  4496. 'train': {
  4497. 'label': 'label',
  4498. 'label_desc': 'label_desc',
  4499. 'sentence': 'sentence',
  4500. 'keywords': 'keywords'
  4501. },
  4502. 'test': {
  4503. 'id': 'id',
  4504. 'sentence': 'sentence',
  4505. 'keywords': 'keywords'
  4506. },
  4507. 'eval': {
  4508. 'label': 'label',
  4509. 'label_desc': 'label_desc',
  4510. 'sentence': 'sentence',
  4511. 'keywords': 'keywords'
  4512. }
  4513. },
  4514. 'WSC': {
  4515. 'train': {
  4516. 'span1_index': 'target/span1_index',
  4517. 'span2_index': 'target/span2_index',
  4518. 'span1_text': 'target/span1_text',
  4519. 'span2_text': 'target/span2_text',
  4520. 'idx': 'idx',
  4521. 'label': 'label',
  4522. 'text': 'text'
  4523. },
  4524. 'test': {
  4525. 'span1_index': 'target/span1_index',
  4526. 'span2_index': 'target/span2_index',
  4527. 'span1_text': 'target/span1_text',
  4528. 'span2_text': 'target/span2_text',
  4529. 'idx': 'idx',
  4530. 'text': 'text'
  4531. },
  4532. 'eval': {
  4533. 'span1_index': 'target/span1_index',
  4534. 'span2_index': 'target/span2_index',
  4535. 'span1_text': 'target/span1_text',
  4536. 'span2_text': 'target/span2_text',
  4537. 'idx': 'idx',
  4538. 'label': 'label',
  4539. 'text': 'text'
  4540. }
  4541. }
  4542. }
  4543. self.cols_to_keyword = self.task_dict[task][usage]
  4544. if not isinstance(shuffle, (bool, Shuffle)):
  4545. raise TypeError("shuffle must be of type boolean or enum 'Shuffle'.")
  4546. if not isinstance(shuffle, Shuffle):
  4547. if shuffle:
  4548. self.shuffle_level = Shuffle.GLOBAL
  4549. self.shuffle_files = True
  4550. else:
  4551. self.shuffle_level = None
  4552. self.shuffle_files = False
  4553. else:
  4554. self.shuffle_level = shuffle
  4555. self.shuffle_files = True
  4556. self.num_shards = num_shards
  4557. self.shard_id = shard_id
  4558. # The clue dataset does not directly support a sampler. It has provided sampling arguments
  4559. # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
  4560. # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
  4561. sampler_shuffle = self.shuffle_files
  4562. sampler = None
  4563. self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
  4564. non_mappable=True)
  4565. self.cache = cache
  4566. def get_args(self):
  4567. args = super().get_args()
  4568. args["dataset_files"] = self.dataset_files
  4569. args["num_samples"] = self.num_samples
  4570. if self.shuffle_files is not None:
  4571. args["shuffle_files"] = self.shuffle_files
  4572. args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
  4573. args["shuffle"] = self.shuffle_level
  4574. args["num_shards"] = self.num_shards
  4575. args["shard_id"] = self.shard_id
  4576. args["cols_to_keyword"] = self.cols_to_keyword
  4577. args["sampler"] = self.sampler
  4578. args["cache"] = self.cache.cache_client if self.cache is not None else None
  4579. return args
  4580. def get_dataset_size(self):
  4581. """
  4582. Get the number of batches in an epoch.
  4583. Return:
  4584. Number, number of batches.
  4585. """
  4586. if self.dataset_size is None:
  4587. num_rows = ClueOp.get_num_rows(self.dataset_files)
  4588. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  4589. if self.num_samples is not None and self.num_samples < self.dataset_size:
  4590. self.dataset_size = self.num_samples
  4591. return self.dataset_size
  4592. def is_shuffled(self):
  4593. return self.shuffle_files
  4594. def is_sharded(self):
  4595. if self.num_shards is not None:
  4596. return self.num_shards > 1
  4597. return False
  4598. class CSVDataset(SourceDataset):
  4599. """
  4600. A source dataset that reads and parses comma-separated values (CSV) datasets.
  4601. Args:
  4602. dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search
  4603. for a pattern of files. The list will be sorted in a lexicographical order.
  4604. field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
4605. column_defaults (list, optional): List of default values for the CSV fields (default=None). Each item
4606. in the list must be of a valid type (float, int, or string). If this is not provided, all columns are
4607. treated as string type.
  4608. column_names (list[str], optional): List of column names of the dataset (default=None). If this
  4609. is not provided, infers the column_names from the first row of CSV file.
  4610. num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
  4611. num_parallel_workers (int, optional): Number of workers to read the data
  4612. (default=None, number set in the config).
  4613. shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
  4614. (default=Shuffle.GLOBAL).
  4615. If shuffle is False, no shuffling will be performed;
4616. If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
  4617. Otherwise, there are two levels of shuffling:
  4618. - Shuffle.GLOBAL: Shuffle both the files and samples.
  4619. - Shuffle.FILES: Shuffle files only.
  4620. num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
  4621. shard_id (int, optional): The shard ID within num_shards (default=None). This
  4622. argument can only be specified when num_shards is also specified.
  4623. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  4624. The cache feature is under development and is not recommended.
  4625. Examples:
  4626. >>> import mindspore.dataset as ds
  4627. >>>
4628. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more CSV files
  4629. >>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4'])
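>>> # A further, hedged sketch (not in the original example): tab-delimited files with typed
>>> # column defaults, so the four columns are parsed as float, int, string and string.
>>> dataset_tsv = ds.CSVDataset(dataset_files=dataset_files, field_delim='\t',
...                             column_defaults=[0.0, 0, "", ""],
...                             column_names=['col1', 'col2', 'col3', 'col4'])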
  4630. """
  4631. @check_csvdataset
  4632. def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
  4633. num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
  4634. super().__init__(num_parallel_workers)
  4635. self.dataset_files = self._find_files(dataset_files)
  4636. self.dataset_files.sort()
  4637. self.field_delim = field_delim
  4638. self.column_defaults = column_defaults
  4639. self.column_names = column_names
  4640. self.num_samples = num_samples
  4641. if not isinstance(shuffle, (bool, Shuffle)):
  4642. raise TypeError("shuffle must be of type boolean or enum 'Shuffle'.")
  4643. if not isinstance(shuffle, Shuffle):
  4644. if shuffle:
  4645. self.shuffle_level = Shuffle.GLOBAL
  4646. self.shuffle_files = True
  4647. else:
  4648. self.shuffle_level = None
  4649. self.shuffle_files = False
  4650. else:
  4651. self.shuffle_level = shuffle
  4652. self.shuffle_files = True
  4653. self.num_shards = num_shards
  4654. self.shard_id = shard_id
  4655. self.cache = cache
  4656. # The CSV dataset does not directly support a sampler. It has provided sampling arguments
  4657. # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
  4658. # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
  4659. sampler_shuffle = self.shuffle_files
  4660. sampler = None
  4661. self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
  4662. non_mappable=True)
  4663. def get_args(self):
  4664. args = super().get_args()
  4665. args["dataset_files"] = self.dataset_files
  4666. args['field_delim'] = self.field_delim
  4667. args['column_defaults'] = self.column_defaults
  4668. args['column_names'] = self.column_names
  4669. args["num_samples"] = self.num_samples
  4670. if self.shuffle_files is not None:
  4671. args["shuffle_files"] = self.shuffle_files
  4672. args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
  4673. args["shuffle"] = self.shuffle_level
  4674. args["num_shards"] = self.num_shards
  4675. args["shard_id"] = self.shard_id
  4676. args["sampler"] = self.sampler
  4677. args["cache"] = self.cache.cache_client if self.cache is not None else None
  4678. return args
  4679. def get_dataset_size(self):
  4680. """
  4681. Get the number of batches in an epoch.
  4682. Return:
  4683. Number, number of batches.
  4684. """
  4685. if self.dataset_size is None:
  4686. num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
  4687. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  4688. if self.num_samples is not None and self.num_samples < self.dataset_size:
4689. self.dataset_size = self.num_samples
  4690. return self.dataset_size
  4691. def is_shuffled(self):
  4692. return self.shuffle_files
  4693. def is_sharded(self):
  4694. if self.num_shards is not None:
  4695. return self.num_shards > 1
  4696. return False
  4697. class TextFileDataset(SourceDataset):
  4698. """
  4699. A source dataset that reads and parses datasets stored on disk in text format.
  4700. The generated dataset has one column ['text'].
  4701. Args:
  4702. dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
  4703. pattern of files. The list will be sorted in a lexicographical order.
  4704. num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
  4705. num_parallel_workers (int, optional): Number of workers to read the data
  4706. (default=None, number set in the config).
  4707. shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch
  4708. (default=Shuffle.GLOBAL).
  4709. If shuffle is False, no shuffling will be performed;
4710. If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
  4711. Otherwise, there are two levels of shuffling:
  4712. - Shuffle.GLOBAL: Shuffle both the files and samples.
  4713. - Shuffle.FILES: Shuffle files only.
  4714. num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
  4715. shard_id (int, optional): The shard ID within num_shards (default=None). This
  4716. argument can only be specified when num_shards is also specified.
  4717. cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
  4718. The cache feature is under development and is not recommended.
  4719. Examples:
  4720. >>> import mindspore.dataset as ds
  4721. >>>
4722. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more text files
  4723. >>> dataset = ds.TextFileDataset(dataset_files=dataset_files)
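>>> # A further, hedged sketch (not in the original example): read at most 100 lines and keep
>>> # the file order deterministic by disabling shuffling.
>>> dataset_head = ds.TextFileDataset(dataset_files=dataset_files, num_samples=100, shuffle=False)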
  4724. """
  4725. @check_textfiledataset
  4726. def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None,
  4727. shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, cache=None):
  4728. super().__init__(num_parallel_workers)
  4729. self.dataset_files = self._find_files(dataset_files)
  4730. self.dataset_files.sort()
  4731. self.num_samples = num_samples
  4732. if not isinstance(shuffle, (bool, Shuffle)):
  4733. raise TypeError("shuffle must be of type boolean or enum 'Shuffle'.")
  4734. if not isinstance(shuffle, Shuffle):
  4735. if shuffle:
  4736. self.shuffle_level = Shuffle.GLOBAL
  4737. self.shuffle_files = True
  4738. else:
  4739. self.shuffle_level = None
  4740. self.shuffle_files = False
  4741. else:
  4742. self.shuffle_level = shuffle
  4743. self.shuffle_files = True
  4744. self.num_shards = num_shards
  4745. self.shard_id = shard_id
  4746. self.cache = cache
  4747. # The text file dataset does not directly support a sampler. It has provided sampling arguments
  4748. # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
  4749. # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
  4750. sampler_shuffle = self.shuffle_files
  4751. sampler = None
  4752. self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
  4753. non_mappable=True)
  4754. def get_args(self):
  4755. args = super().get_args()
  4756. args["dataset_files"] = self.dataset_files
  4757. args["num_samples"] = self.num_samples
  4758. if self.shuffle_files is not None:
  4759. args["shuffle_files"] = self.shuffle_files
  4760. args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
  4761. args["shuffle"] = self.shuffle_level
  4762. args["num_shards"] = self.num_shards
  4763. args["shard_id"] = self.shard_id
  4764. args["sampler"] = self.sampler
  4765. args["cache"] = self.cache.cache_client if self.cache is not None else None
  4766. return args
  4767. def get_dataset_size(self):
  4768. """
  4769. Get the number of batches in an epoch.
  4770. Return:
  4771. Number, number of batches.
  4772. """
  4773. if self.dataset_size is None:
  4774. num_rows = TextFileOp.get_num_rows(self.dataset_files)
  4775. self.dataset_size = get_num_rows(num_rows, self.num_shards)
  4776. # If the user gave a num samples in the dataset, then the sampler will limit the rows returned
  4777. # to that amount. Account for that here in the row count
  4778. if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples:
  4779. self.dataset_size = self.num_samples
  4780. return self.dataset_size
  4781. def is_shuffled(self):
  4782. return self.shuffle_files
  4783. def is_sharded(self):
  4784. if self.num_shards is not None:
  4785. return self.num_shards > 1
  4786. return False
  4787. class _NumpySlicesDataset:
  4788. """
4789. Mainly for dealing with several kinds of formats of Python data, returning one row at a time.
  4790. """
  4791. def __init__(self, data, column_list=None):
  4792. self.column_list = None
  4793. # Convert dict data into tuple
  4794. if isinstance(data, dict):
  4795. data = self.process_dict(data)
  4796. if isinstance(data, tuple):
  4797. self.data = ()
  4798. data_len = len(data)
  4799. for i in range(data_len):
  4800. self.data = self.data + (np.array(data[i]),)
  4801. else:
  4802. self.data = (np.array(data),)
  4803. # check whether the data length in each column is equal
  4804. data_len = [len(data_item) for data_item in self.data]
  4805. if data_len[1:] != data_len[:-1]:
  4806. raise ValueError("Data length in each column is not equal.")
  4807. # Init column_name
  4808. if column_list is not None:
  4809. self.column_list = column_list
  4810. elif self.column_list is None:
  4811. self.column_list = []
  4812. column_num = len(self.data)
  4813. for i in range(column_num):
  4814. self.column_list.append("column_" + str(i))
  4815. def __getitem__(self, index):
  4816. data_row = [d[index, ...] for d in self.data]
  4817. data_res = tuple(data_row)
  4818. return data_res
  4819. def __len__(self):
  4820. return len(self.data[0])
  4821. def process_dict(self, input_data):
  4822. """
4823. Convert dict-like data into tuple format. When the input is a tuple of dicts, compose it into a dict first.
4824. """
4825. # Convert a pandas-like dict (columns with a "values" attribute) into a general dict
  4826. data_keys = list(input_data.keys())
  4827. data_col = input_data[data_keys[0]]
  4828. if hasattr(data_col, "values"):
  4829. new_dict = {}
  4830. for key in data_keys:
  4831. item1 = input_data.pop(key)
  4832. new_dict[key] = item1.values
  4833. input_data = new_dict
  4834. # Convert the data in dict into tuple
  4835. data = ()
  4836. keys = list(input_data.keys())
  4837. self.column_list = keys
  4838. for key in keys:
  4839. value = input_data[key]
  4840. data = data + (list(value),)
  4841. return data
  4842. class NumpySlicesDataset(GeneratorDataset):
  4843. """
  4844. Create a dataset with given data slices, mainly for loading Python data into dataset.
  4845. This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. The table
  4846. below shows what input arguments are allowed and their expected behavior.
  4847. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  4848. :widths: 25 25 50
  4849. :header-rows: 1
  4850. * - Parameter 'sampler'
  4851. - Parameter 'shuffle'
  4852. - Expected Order Behavior
  4853. * - None
  4854. - None
  4855. - random order
  4856. * - None
  4857. - True
  4858. - random order
  4859. * - None
  4860. - False
  4861. - sequential order
  4862. * - Sampler object
  4863. - None
  4864. - order defined by sampler
  4865. * - Sampler object
  4866. - True
  4867. - not allowed
  4868. * - Sampler object
  4869. - False
  4870. - not allowed
  4871. Args:
4872. data (Union[list, tuple, dict]): Input of given data. Supported data types include: list, tuple, dict and other
4873. NumPy formats. Input data will be sliced along the first dimension and generate additional rows.
4874. Loading large data in this way is not recommended, since all of the data will be loaded into memory.
4875. column_names (list[str], optional): List of column names of the dataset (default=None). If column_names is not
4876. provided and data is a dict, column_names will be its keys; otherwise columns will be named column_0, column_1 ...
4877. num_samples (int, optional): The number of samples to be included in the dataset (default=None, all samples).
  4878. num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
  4879. shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
  4880. (default=None, expected order behavior shown in the table).
  4881. sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible
  4882. input is required (default=None, expected order behavior shown in the table).
  4883. num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
4884. When this argument is specified, 'num_samples' will not be used. Random accessible input is required.
  4885. shard_id (int, optional): The shard ID within num_shards (default=None). This argument must be specified only
  4886. when num_shards is also specified. Random accessible input is required.
  4887. Examples:
  4888. >>> import mindspore.dataset as ds
  4889. >>>
  4890. >>> # 1) Input data can be a list
  4891. >>> data = [1, 2, 3]
  4892. >>> dataset1 = ds.NumpySlicesDataset(data, column_names=["column_1"])
  4893. >>>
  4894. >>> # 2) Input data can be a dictionary, and column_names will be its keys
  4895. >>> data = {"a": [1, 2], "b": [3, 4]}
  4896. >>> dataset2 = ds.NumpySlicesDataset(data)
  4897. >>>
  4898. >>> # 3) Input data can be a tuple of lists (or NumPy arrays), each tuple element refers to data in each column
  4899. >>> data = ([1, 2], [3, 4], [5, 6])
  4900. >>> dataset3 = ds.NumpySlicesDataset(data, column_names=["column_1", "column_2", "column_3"])
  4901. >>>
  4902. >>> # 4) Load data from CSV file
  4903. >>> import pandas as pd
  4904. >>> df = pd.read_csv("file.csv")
  4905. >>> dataset4 = ds.NumpySlicesDataset(dict(df), shuffle=False)
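>>> # 5) A further, hedged sketch (not in the original examples): split the data into two shards
>>> # and read shard 0; as noted above, 'num_samples' is not used once 'num_shards' is specified.
>>> data = [1, 2, 3, 4]
>>> dataset5 = ds.NumpySlicesDataset(data, column_names=["column_1"], num_shards=2, shard_id=0)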
  4906. """
  4907. @check_numpyslicesdataset
  4908. def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None,
  4909. sampler=None, num_shards=None, shard_id=None):
  4910. dataset = _NumpySlicesDataset(data, column_names)
  4911. super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
  4912. num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
  4913. num_shards=num_shards, shard_id=shard_id)
  4914. class _PaddedDataset:
  4915. """
4916. Mainly for combining fake (padded) samples provided by users into a dataset.
  4917. Args:
  4918. padded_samples (list(dict)): Data provided by user to be added to the initial Dataset.
  4919. """
  4920. def __init__(self, padded_samples):
  4921. self.column_names = list(padded_samples[0].keys())
  4922. self.padded_samples = padded_samples
  4923. def __getitem__(self, item):
4924. return tuple(self.padded_samples[item][key] for key in self.column_names)
  4925. def __len__(self):
  4926. return len(self.padded_samples)
  4927. class PaddedDataset(GeneratorDataset):
  4928. """
4929. Create a dataset with fake data provided by the user. Mainly used to append extra samples to the original
4930. dataset so that the data can be assigned to the corresponding shard.
  4931. Args:
  4932. padded_samples (list(dict)): Samples provided by user.
  4933. Raises:
  4934. TypeError: If padded_samples is not an instance of list.
  4935. TypeError: If the element of padded_samples is not an instance of dict.
  4936. ValueError: If the padded_samples is empty.
  4937. Examples:
4938. >>> import mindspore.dataset as ds
>>> import numpy as np
4939. >>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
  4940. >>> ds1 = ds.PaddedDataset(data1)
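>>> # A further, hedged sketch (not in the original example): padded samples are typically
>>> # concatenated onto an existing dataset before sharding; the '+' concat operator on
>>> # datasets is assumed to be available here.
>>> data2 = [{'image': np.zeros(3, np.uint8)}]
>>> ds2 = ds1 + ds.PaddedDataset(data2)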
  4941. """
  4942. @check_paddeddataset
  4943. def __init__(self, padded_samples):
  4944. dataset = _PaddedDataset(padded_samples)
  4945. super().__init__(dataset, column_names=dataset.column_names,
  4946. num_shards=None,
  4947. shard_id=None, shuffle=False)
  4948. self._dataset_size = len(dataset.padded_samples)
  4949. self.padded_samples = padded_samples
  4950. def get_dataset_size(self):
  4951. return self._dataset_size
  4952. class BuildVocabDataset(DatasetOp):
  4953. """
  4954. Build a vocab from a dataset. This will collect all the unique words in a dataset and return a vocab
  4955. which contains top_k most frequent words (if top_k is specified).
4956. This function is not meant to be called directly by the user. To build a vocab, use the function
4957. text.Vocab.from_dataset().
  4958. Args:
  4959. vocab (Vocab): text.vocab object.
4960. columns (Union[str, list], optional): Column names to get words from. It can be a list of column names
4961. (default=None, all columns are used; an error is raised if any column is not of string type).
4962. freq_range (tuple, optional): Tuple of integers (min_frequency, max_frequency). Words within the frequency
4963. range will be kept. 0 <= min_frequency <= max_frequency <= total_words. min_frequency/max_frequency
4964. can be None, which corresponds to 0/total_words respectively (default=None, all words are included).
4965. top_k (int, optional): Number of words to be built into the vocab; top_k must be greater than 0. The top_k
4966. most frequent words are taken after freq_range is applied. If there are fewer than top_k words, all words
4967. will be taken (default=None, all words are included).
  4968. special_tokens (list, optional): List of strings, each one is a special token, for example
  4969. special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
4970. special_first (bool, optional): Whether special_tokens will be prepended or appended to the vocab. If
4971. special_tokens is specified and special_first is set to None, special_tokens will be prepended (default=None).
  4972. prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
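Examples:
>>> # A hedged sketch (not in the original docstring): the public entry point mentioned above
>>> # is text.Vocab.from_dataset(); 'dataset' is assumed to be an existing dataset whose
>>> # "text" column contains strings, and the argument names mirror those listed above.
>>> import mindspore.dataset.text as text
>>> vocab = text.Vocab.from_dataset(dataset, columns=["text"], freq_range=None, top_k=None,
...                                 special_tokens=["<pad>", "<unk>"], special_first=True)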
  4973. """
  4974. def __init__(self, input_dataset, vocab, columns, freq_range, top_k, special_tokens, special_first,
  4975. prefetch_size=None):
  4976. super().__init__()
  4977. self.columns = columns
  4978. self.children.append(input_dataset)
  4979. self.prefetch_size = prefetch_size
  4980. self.vocab = vocab
  4981. self.freq_range = freq_range
  4982. self.top_k = top_k
  4983. self.special_tokens = special_tokens
  4984. self.special_first = special_first
  4985. input_dataset.parent.append(self)
  4986. def get_args(self):
  4987. args = super().get_args()
  4988. args["columns"] = self.columns
  4989. args["vocab"] = self.vocab
  4990. args["freq_range"] = self.freq_range
  4991. args["prefetch_size"] = self.prefetch_size
  4992. args["top_k"] = self.top_k
  4993. args["special_tokens"] = self.special_tokens
  4994. args["special_first"] = self.special_first
  4995. return args
  4996. def __deepcopy__(self, memodict):
  4997. if id(self) in memodict:
  4998. return memodict[id(self)]
  4999. cls = self.__class__
  5000. new_op = cls.__new__(cls)
  5001. memodict[id(self)] = new_op
  5002. new_op.children = copy.deepcopy(self.children, memodict)
  5003. new_op.columns = copy.deepcopy(self.columns, memodict)
  5004. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  5005. new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
  5006. new_op.parent = copy.deepcopy(self.parent, memodict)
  5007. new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
  5008. new_op.top_k = copy.deepcopy(self.top_k, memodict)
  5009. new_op.vocab = self.vocab
  5010. new_op.special_tokens = copy.deepcopy(self.special_tokens)
  5011. new_op.special_first = copy.deepcopy(self.special_first)
  5012. new_op.dataset_size = self.dataset_size
  5013. return new_op
  5014. class BuildSentencePieceVocabDataset(DatasetOp):
  5015. """
  5016. Build a SentencePieceVocab from a dataset.
5017. This function is not meant to be called directly by the user. To build a vocab, use the function
5018. text.SentencePieceVocab.from_dataset().
  5019. Args:
  5020. vocab (SentencePieceVocab): text.SentencePieceVocab object.
  5021. col_names (list): List of column names.
  5022. vocab_size (int): Vocabulary size. The type is uint32.
  5023. character_coverage (float): Percentage of characters covered by the model. Good defaults are:
  5024. 0.9995 for languages with rich character sets like Japanese or Chinese character sets,
  5025. and 1.0 for other languages with small character sets.
  5026. model_type (SentencePieceModel): Model type. Choose from unigram (default), bpe, char, or word.
  5027. The input sentence must be pretokenized when using word type.
  5028. params (dict): A dictionary with no incoming parameters.
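Examples:
>>> # A hedged sketch (not in the original docstring): the public entry point mentioned above
>>> # is text.SentencePieceVocab.from_dataset(); 'dataset' is assumed to be an existing text
>>> # dataset, and the argument order follows the parameters listed above.
>>> import mindspore.dataset.text as text
>>> from mindspore.dataset.text import SentencePieceModel
>>> vocab = text.SentencePieceVocab.from_dataset(dataset, ["text"], 5000, 0.9995,
...                                              SentencePieceModel.UNIGRAM, {})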
  5029. """
  5030. def __init__(self, input_dataset, vocab, col_names, vocab_size, character_coverage, model_type, params):
  5031. super().__init__()
  5032. self.vocab = vocab
  5033. self.col_names = col_names
  5034. self.vocab_size = vocab_size
  5035. self.children.append(input_dataset)
  5036. self.character_coverage = character_coverage
  5037. self.model_type = DE_C_INTER_SENTENCEPIECE_MODE[model_type]
  5038. self.params = params
  5039. input_dataset.parent.append(self)
  5040. def get_args(self):
  5041. args = super().get_args()
  5042. args["vocab"] = self.vocab
  5043. args["col_names"] = self.col_names
  5044. args["vocab_size"] = self.vocab_size
  5045. args["character_coverage"] = self.character_coverage
  5046. args["model_type"] = self.model_type
  5047. args["params"] = self.params
  5048. return args
  5049. def __deepcopy__(self, memodict):
  5050. if id(self) in memodict:
  5051. return memodict[id(self)]
  5052. cls = self.__class__
  5053. new_op = cls.__new__(cls)
  5054. memodict[id(self)] = new_op
  5055. new_op.children = copy.deepcopy(self.children, memodict)
  5056. new_op.col_names = copy.deepcopy(self.col_names, memodict)
  5057. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  5058. new_op.vocab_size = copy.deepcopy(self.vocab_size, memodict)
  5059. new_op.parent = copy.deepcopy(self.parent, memodict)
  5060. new_op.character_coverage = copy.deepcopy(self.character_coverage, memodict)
  5061. new_op.params = copy.deepcopy(self.params, memodict)
  5062. new_op.vocab = self.vocab
  5063. new_op.model_type = copy.deepcopy(self.model_type)
  5064. new_op.dataset_size = self.dataset_size
  5065. return new_op