# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
datasets.py supports various formats of datasets, including ImageNet, TFData,
MNIST, Cifar10/100, Manifest, MindRecord, etc. This module loads data with
high performance and parses data precisely. It also provides the following
operations for users to preprocess data: shuffle, batch, repeat, map, and zip.
"""
import glob
import json
import math
import os
import uuid
import multiprocessing
import queue
from enum import Enum
from importlib import import_module
import threading
import copy
import numpy as np

from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
    MindRecordOp, TextFileOp, VOCOp, CBatchInfo
from mindspore._c_expression import typing
from mindspore import log as logger
from . import samplers
from .iterators import DictIterator, TupleIterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
    check_rename, \
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, check_split
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:
    context = import_module("mindspore.context")
except ModuleNotFoundError:
    context = None


class Shuffle(str, Enum):
    GLOBAL: str = "global"
    FILES: str = "file"
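# Shuffle.GLOBAL shuffles individual rows across all files, while Shuffle.FILES
# shuffles at file granularity. A hedged usage sketch (the file names are
# placeholders, and passing the enum as the shuffle argument is an assumption
# based on how file-based datasets in this module take a shuffle parameter):
#
#   data = TFRecordDataset(["a.tfrecord", "b.tfrecord"], shuffle=Shuffle.FILES)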
@check_zip
def zip(datasets):
    """
    Zips the datasets in the input tuple of datasets.

    Args:
        datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
            The number of datasets should be more than 1.

    Returns:
        DatasetOp, ZipDataset.

    Raises:
        ValueError: If the number of datasets is 1.
        TypeError: If datasets is not a tuple.

    Examples:
        >>> import mindspore.dataset as ds
        >>>
        >>> dataset_dir1 = "path/to/imagefolder_directory1"
        >>> dataset_dir2 = "path/to/imagefolder_directory2"
        >>> ds1 = ds.ImageFolderDatasetV2(dataset_dir1, num_parallel_workers=8)
        >>> ds2 = ds.ImageFolderDatasetV2(dataset_dir2, num_parallel_workers=8)
        >>>
        >>> # creates a dataset which is the combination of ds1 and ds2
        >>> data = ds.zip((ds1, ds2))
    """
    if len(datasets) <= 1:
        raise ValueError(
            "Can't zip empty or just one dataset!")
    return ZipDataset(datasets)
def get_num_rows(num_rows, num_shards):
    """
    Get the number of rows of the dataset according to the shards.

    Args:
        num_rows (int): The number of rows of the dataset. Should be more than 0.
        num_shards (int or None): Number of shards that the dataset is divided into.
            The number of shards should be None or more than 1.

    Returns:
        Int, number of rows.

    Raises:
        ValueError: If num_rows is invalid (< 0).
        ValueError: If num_shards is invalid (<= 0).
    """
    if num_rows < 0:
        raise ValueError("num_rows is invalid (< 0)")
    if num_shards is not None:
        if num_shards <= 0:
            raise ValueError("num_shards is invalid (<= 0)")
        if num_rows % num_shards == 0:
            num_rows = num_rows // num_shards
        else:
            num_rows = num_rows // num_shards + 1
    return num_rows
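# Worked example of the per-shard row count (the values are chosen for
# illustration only). 10 rows over 3 shards does not divide evenly, so each
# shard reports ceil(10 / 3) rows:
#
#   get_num_rows(10, 3)     # -> 4
#   get_num_rows(12, 3)     # -> 4 (even split)
#   get_num_rows(12, None)  # -> 12 (no sharding)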
class Dataset:
    """
    Abstract class to represent a dataset in DataEngine's data pipeline.

    This class is the base class of SourceDataset and DatasetOp, and represents
    a node in the data flow graph.

    Args:
        num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
            (default=None).
    """

    def __init__(self, num_parallel_workers=None):
        self.input = []
        self.output = []
        self.num_parallel_workers = num_parallel_workers
        self._device_iter = 0
        self._input_indexs = ()
        self._output_types = None
        self._output_shapes = None
        self._dataset_size = None
        self._batch_size = None
        self._num_classes = None
        self._repeat_count = None
        self._sync = False

    def __add__(self, datasets):
        return self.concat(datasets)

    def get_args(self):
        """
        Returns attributes (member variables) related to the current class.

        Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.

        Returns:
            Python dictionary.
        """
        args = dict()
        args["num_parallel_workers"] = self.num_parallel_workers
        return args
    @check_batch
    def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
              input_columns=None, pad_info=None):
        """
        Combines batch_size number of consecutive rows into batches.

        For any child node, a batch is treated as a single row.
        For any column, all the elements within that column must have the same shape.
        If a per_batch_map callable is provided, it will be applied to the batches of tensors.

        Note:
            The order in which repeat and batch are used affects the number of batches. It is
            recommended that the repeat operation be used after the batch operation.

        Args:
            batch_size (int or function): The number of rows each batch is created with. An
                int or callable which takes exactly 1 parameter, BatchInfo.
            drop_remainder (bool, optional): Determines whether or not to drop the last
                possibly incomplete batch (default=False). If True, and if there are fewer
                than batch_size rows available to make the last batch, then those rows will
                be dropped and not propagated to the child node.
            num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None).
            per_batch_map (callable, optional): Per batch map callable. A callable which takes
                (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch
                of Tensors on a given column. The number of lists should match the number of entries in input_columns.
                The last parameter of the callable should always be a BatchInfo object.
            input_columns (list of string, optional): List of names of the input columns. The size of the list should
                match the signature of the per_batch_map callable.
            pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
                would pad the column with name "col1" to a tensor of size [224,224] and fill the missing values with 0.

        Returns:
            BatchDataset, dataset batched.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> # creates a dataset where every 100 rows is combined into a batch
            >>> # and drops the last incomplete batch if there is one.
            >>> data = data.batch(100, True)
        """
        return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
                            pad_info)
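    # A hedged per_batch_map sketch. The column name "col1" and the callable
    # are illustrative only, not part of this file:
    #
    #   def add_one(col1_batch, batch_info):
    #       # col1_batch is the list of "col1" values for the rows in the batch;
    #       # batch_info is the BatchInfo object passed as the last argument.
    #       return ([x + 1 for x in col1_batch],)
    #
    #   data = data.batch(8, per_batch_map=add_one, input_columns=["col1"])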
    @check_sync_wait
    def sync_wait(self, condition_name, num_batch=1, callback=None):
        """
        Add a blocking condition to the input Dataset.

        Args:
            num_batch (int): the number of batches without blocking at the start of each epoch.
            condition_name (str): The condition name that is used to toggle sending the next row.
            callback (function): The callback function that will be invoked when sync_update is called.

        Raises:
            RuntimeError: If the condition name already exists.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> data = data.sync_wait("callback1")
            >>> data = data.batch(batch_size)
            >>> for batch_data in data.create_dict_iterator():
            >>>     data = data.sync_update("callback1")
        """
        return SyncWaitDataset(self, condition_name, num_batch, callback)
    @check_shuffle
    def shuffle(self, buffer_size):
        """
        Randomly shuffles the rows of this dataset using the following algorithm:

        1. Make a shuffle buffer that contains the first buffer_size rows.
        2. Randomly select an element from the shuffle buffer to be the next row
           propagated to the child node.
        3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
        4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.

        A seed can be provided to be used on the first epoch. In every subsequent
        epoch, the seed is changed to a new, randomly generated value.

        Args:
            buffer_size (int): The size of the buffer (must be larger than 1) for
                shuffling. Setting buffer_size equal to the number of rows in the entire
                dataset will result in a global shuffle.

        Returns:
            ShuffleDataset, dataset shuffled.

        Raises:
            RuntimeError: If a sync operator exists in the pipeline before shuffle.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object
            >>> # optionally set the seed for the first epoch
            >>> ds.config.set_seed(58)
            >>>
            >>> # creates a shuffled dataset using a shuffle buffer of size 4
            >>> data = data.shuffle(4)
        """
        return ShuffleDataset(self, buffer_size)
    def flat_map(self, func):
        """
        Maps `func` to each row in the dataset and flattens the result.

        The specified `func` is a function that must take one 'Ndarray' as input
        and return a 'Dataset'.

        Args:
            func (function): A function that must take one 'Ndarray' as an argument and
                return a 'Dataset'.

        Returns:
            Dataset, applied by the function.

        Examples:
            >>> import mindspore.dataset as ds
            >>> import mindspore.dataset.text as text
            >>> # declare a function which returns a Dataset object
            >>> def flat_map_func(x):
            >>>     data_dir = text.to_str(x[0])
            >>>     d = ds.ImageFolderDatasetV2(data_dir)
            >>>     return d
            >>> # data is a Dataset object
            >>> data = ds.TextFileDataset(DATA_FILE)
            >>> data = data.flat_map(flat_map_func)

        Raises:
            TypeError: If `func` is not a function.
            TypeError: If `func` doesn't return a Dataset.
        """
        dataset = None
        if not hasattr(func, '__call__'):
            raise TypeError("func must be a function.")
        for row_data in self:
            if dataset is None:
                dataset = func(row_data)
            else:
                dataset += func(row_data)
        if not isinstance(dataset, Dataset):
            raise TypeError("flat_map must return a Dataset object.")
        return dataset
    @check_map
    def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
            num_parallel_workers=None, python_multiprocessing=False):
        """
        Applies each operation in operations to this dataset.

        The order of operations is determined by the position of each operation in operations.
        operations[0] will be applied first, then operations[1], then operations[2], etc.

        Each operation will be passed one or more columns from the dataset as input, and zero or
        more columns will be outputted. The first operation will be passed the columns specified
        in input_columns as input. If there is more than one operator in operations, the outputted
        columns of the previous operation are used as the input columns for the next operation.
        The columns outputted by the very last operation will be assigned names specified by
        output_columns.

        Only the columns specified in columns_order will be propagated to the child node. These
        columns will be in the same order as specified in columns_order.

        Args:
            input_columns (list[str]): List of the names of the columns that will be passed to
                the first operation as input. The size of this list must match the number of
                input columns expected by the first operator. (default=None, the first
                operation will be passed as many columns as it requires, starting from
                the first column).
            operations (list[TensorOp] or Python list[functions]): List of operations to be
                applied on the dataset. Operations are applied in the order they appear in this list.
            output_columns (list[str], optional): List of names assigned to the columns outputted by
                the last operation. This parameter is mandatory if len(input_columns) !=
                len(output_columns). The size of this list must match the number of output
                columns of the last operation. (default=None, output columns will have the same
                name as the input columns, i.e., the columns will be replaced).
            columns_order (list[str], optional): list of all the desired columns to propagate to the
                child node. This list must be a subset of all the columns in the dataset after
                all operations are applied. The order of the columns in each row propagated to the
                child node follows the order they appear in this list. The parameter is mandatory
                if len(input_columns) != len(output_columns). (default=None, all columns
                will be propagated to the child node, the order of the columns will remain the
                same).
            num_parallel_workers (int, optional): Number of threads used to process the dataset in
                parallel (default=None, the value from the config will be used).
            python_multiprocessing (bool, optional): Parallelize python operations with multiple worker
                processes. This option could be beneficial if the python operation is computationally
                heavy (default=False).

        Returns:
            MapDataset, dataset after mapping operation.

        Examples:
            >>> import mindspore.dataset as ds
            >>> import mindspore.dataset.transforms.vision.c_transforms as c_transforms
            >>>
            >>> # data is an instance of Dataset which has 2 columns, "image" and "label".
            >>> # ds_pyfunc is an instance of Dataset which has 3 columns, "col0", "col1", and "col2". Each column is
            >>> # a 2d array of integers.
            >>>
            >>> # This config is a global setting, meaning that all future operations which
            >>> # use this config value will use 2 worker threads, unless otherwise specified
            >>> # in their constructor. set_num_parallel_workers can be called
            >>> # again later if a different number of worker threads are needed.
            >>> ds.config.set_num_parallel_workers(2)
            >>>
            >>> # Two operations, each of which takes 1 column as input and outputs 1 column.
            >>> decode_op = c_transforms.Decode(rgb_format=True)
            >>> random_jitter_op = c_transforms.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
            >>>
            >>> # 1) Simple map example
            >>>
            >>> operations = [decode_op]
            >>> input_columns = ["image"]
            >>>
            >>> # Applies decode_op on column "image". This column will be replaced by the outputted
            >>> # column of decode_op. Since columns_order is not provided, both columns "image"
            >>> # and "label" will be propagated to the child node in their original order.
            >>> ds_decoded = data.map(input_columns, operations)
            >>>
            >>> # Rename column "image" to "decoded_image"
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(input_columns, operations, output_columns)
            >>>
            >>> # Specify the order of the columns.
            >>> columns_order = ["label", "image"]
            >>> ds_decoded = data.map(input_columns, operations, None, columns_order)
            >>>
            >>> # Rename column "image" to "decoded_image" and also specify the order of the columns.
            >>> columns_order = ["label", "decoded_image"]
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(input_columns, operations, output_columns, columns_order)
            >>>
            >>> # Rename column "image" to "decoded_image" and keep only this column.
            >>> columns_order = ["decoded_image"]
            >>> output_columns = ["decoded_image"]
            >>> ds_decoded = data.map(input_columns, operations, output_columns, columns_order)
            >>>
            >>> # Simple example using pyfunc. Renaming columns and specifying column order
            >>> # work in the same way as the previous examples.
            >>> input_columns = ["col0"]
            >>> operations = [(lambda x: x + 1)]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations)
            >>>
            >>> # 2) Map example with more than one operation
            >>>
            >>> # If this list of operations is used with map, decode_op will be applied
            >>> # first, then random_jitter_op will be applied.
            >>> operations = [decode_op, random_jitter_op]
            >>>
            >>> input_columns = ["image"]
            >>>
            >>> # Creates a dataset where the images are decoded, then randomly color jittered.
            >>> # decode_op takes column "image" as input and outputs one column. The column
            >>> # outputted by decode_op is passed as input to random_jitter_op.
            >>> # random_jitter_op will output one column. Column "image" will be replaced by
            >>> # the column outputted by random_jitter_op (the very last operation). All other
            >>> # columns are unchanged. Since columns_order is not specified, the order of the
            >>> # columns will remain the same.
            >>> ds_mapped = data.map(input_columns, operations)
            >>>
            >>> # Creates a dataset that is identical to ds_mapped, except the column "image"
            >>> # that is outputted by random_jitter_op is renamed to "image_transformed".
            >>> # Specifying column order works in the same way as examples in 1).
            >>> output_columns = ["image_transformed"]
            >>> ds_mapped_and_renamed = data.map(input_columns, operations, output_columns)
            >>>
            >>> # Multiple operations using pyfunc. Renaming columns and specifying column order
            >>> # work in the same way as examples in 1).
            >>> input_columns = ["col0"]
            >>> operations = [(lambda x: x + x), (lambda x: x - 1)]
            >>> output_columns = ["col0_mapped"]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns)
            >>>
            >>> # 3) Example where number of input columns is not equal to number of output columns
            >>>
            >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns.
            >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column.
            >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns.
            >>> #
            >>> # Note: the number of output columns of operations[i] must equal the number of
            >>> # input columns of operations[i+1]. Otherwise, this map call will also result
            >>> # in an error.
            >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
            >>>               (lambda x, y, z: x * y * z),
            >>>               (lambda x: (x % 2, x % 3, x % 5, x % 7))]
            >>>
            >>> # Note: because the number of input columns is not the same as the number of
            >>> # output columns, the output_columns and columns_order parameters must be
            >>> # specified. Otherwise, this map call will also result in an error.
            >>> input_columns = ["col2", "col0"]
            >>> output_columns = ["mod2", "mod3", "mod5", "mod7"]
            >>>
            >>> # Propagate all columns to the child node in this order:
            >>> columns_order = ["col0", "col2", "mod2", "mod3", "mod5", "mod7", "col1"]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
            >>>
            >>> # Propagate some columns to the child node in this order:
            >>> columns_order = ["mod7", "mod3", "col1"]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
        """
        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
                          python_multiprocessing)
    @check_filter
    def filter(self, predicate, input_columns=None, num_parallel_workers=1):
        """
        Filter dataset by predicate.

        Note:
            If input_columns is not provided or empty, all columns will be used.

        Args:
            predicate (callable): python callable which returns a boolean value; rows for
                which it returns False are filtered out.
            input_columns (list[str], optional): List of names of the input columns. When
                default=None, the predicate will be applied to all columns in the dataset.
            num_parallel_workers (int, optional): Number of workers to process the Dataset
                in parallel (default=1).

        Returns:
            FilterDataset, dataset filtered.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # generator data(0 ~ 63)
            >>> # filter out the data that is greater than or equal to 11
            >>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns=["data"])
        """
        return FilterDataset(self, predicate, input_columns, num_parallel_workers)
    @check_repeat
    def repeat(self, count=None):
        """
        Repeats this dataset count times. Repeats indefinitely if the count is None or -1.

        Note:
            The order in which repeat and batch are used affects the number of batches. It is
            recommended that the repeat operation be used after the batch operation.
            If dataset_sink_mode is False, the repeat operation has no effect here.
            If dataset_sink_mode is True, the repeat count should equal the number of training epochs.
            Otherwise, errors could occur since the amount of data is not what training requires.

        Args:
            count (int): Number of times the dataset should be repeated (default=None).

        Returns:
            RepeatDataset, dataset repeated.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> # creates a dataset where the dataset is repeated for 50 epochs
            >>> repeated = data.repeat(50)
            >>>
            >>> # creates a dataset where each epoch is shuffled individually
            >>> shuffled_and_repeated = data.shuffle(10)
            >>> shuffled_and_repeated = shuffled_and_repeated.repeat(50)
            >>>
            >>> # creates a dataset where the dataset is first repeated for
            >>> # 50 epochs before shuffling. The shuffle operator will treat
            >>> # the entire 50 epochs as one big dataset.
            >>> repeat_and_shuffle = data.repeat(50)
            >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
        """
        if count == 1:
            return self
        return RepeatDataset(self, count)
    @check_skip
    def skip(self, count):
        """
        Skip the first N elements of this dataset.

        Args:
            count (int): Number of elements in the dataset to be skipped.

        Returns:
            SkipDataset, dataset skipped.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> # creates a dataset which skips the first 3 elements from data
            >>> data = data.skip(3)
        """
        return SkipDataset(self, count)

    @check_take
    def take(self, count=-1):
        """
        Takes at most the given number of elements from the dataset.

        Note:
            1. If count is greater than the number of elements in the dataset or equal to -1,
               all the elements in the dataset will be taken.
            2. The order of take and batch matters. If take comes before the batch operation,
               the given number of rows is taken; otherwise, the given number of batches is taken.

        Args:
            count (int, optional): Number of elements to be taken from the dataset (default=-1).

        Returns:
            TakeDataset, dataset taken.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> # creates a dataset containing 50 elements.
            >>> data = data.take(50)
        """
        if count == -1:
            return self
        return TakeDataset(self, count)
    def _get_absolute_split_sizes(self, sizes):
        """
        Internal method called by split to calculate absolute split sizes and to
        do some error checking after calculating them.
        """
        # Call get_dataset_size and validate the input here, because we don't
        # want to call it once in check_split and then a second time here.
        dataset_size = self.get_dataset_size()
        if dataset_size is None or dataset_size <= 0:
            raise RuntimeError("dataset size unknown, unable to split.")
        all_int = all(isinstance(item, int) for item in sizes)
        if all_int:
            sizes_sum = sum(sizes)
            if sizes_sum != dataset_size:
                raise RuntimeError("sum of split sizes {} is not equal to dataset size {}."
                                   .format(sizes_sum, dataset_size))
            return sizes
        absolute_sizes = []
        for item in sizes:
            absolute_size = int(round(item * dataset_size))
            if absolute_size == 0:
                raise RuntimeError("split percentage {} is too small.".format(item))
            absolute_sizes.append(absolute_size)
        absolute_sizes_sum = sum(absolute_sizes)
        if absolute_sizes_sum != dataset_size:
            raise RuntimeError("sum of calculated split sizes {} is not equal to dataset size {}."
                               .format(absolute_sizes_sum, dataset_size))
        return absolute_sizes
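    # Worked example of the float branch (a 10-row dataset is an assumption):
    #   sizes = [0.9, 0.1] -> round(0.9 * 10) = 9 and round(0.1 * 10) = 1;
    #   9 + 1 == 10, so the split succeeds with absolute sizes [9, 1].
    # With sizes = [0.05, 0.95], the first share is int(round(0.5)) = 0,
    # which raises "split percentage 0.05 is too small.".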
    @check_split
    def split(self, sizes, randomize=True):
        """
        Splits the dataset into smaller, non-overlapping datasets.

        This is a general purpose split function which can be called from any operator in the pipeline.
        There is another, optimized split function, which will be called automatically if ds.split is
        called where ds is a MappableDataset.

        Args:
            sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is
                provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
                respectively. If the sum of all sizes does not equal the original dataset size, an
                error will occur.
                If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
                Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
                of the original dataset. If after rounding, any size equals 0, an error will occur.
                All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
            randomize (bool): determines whether or not to split the data randomly. If true, the data
                will be randomly split. Otherwise, each split will be created with consecutive rows
                from the dataset.

        Note:
            1. Dataset cannot be sharded if split is going to be called.
            2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
               Shuffling the dataset may not be deterministic, which means the data in each split
               will be different in each epoch.

        Raises:
            RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
            RuntimeError: If sizes is a list of integers and the sum of all elements in sizes does not
                equal the dataset size.
            RuntimeError: If sizes is a list of floats and there is a split with size 0 after calculations.
            RuntimeError: If the dataset is sharded prior to calling split.
            ValueError: If sizes is a list of floats and not all floats are between 0 and 1, or if the
                floats don't sum to 1.

        Returns:
            tuple(Dataset), a tuple of datasets that have been split.

        Examples:
            >>> import mindspore.dataset as ds
            >>>
            >>> dataset_dir = "/path/to/text_file.txt"
            >>>
            >>> # TextFileDataset is not a mappable dataset, so this non-optimized split will be called.
            >>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
            >>> data = ds.TextFileDataset(dataset_dir, shuffle=False)
            >>> train, test = data.split([0.9, 0.1])
        """
        if self.is_shuffled():
            logger.warning("dataset is shuffled before split.")
        if self.is_sharded():
            raise RuntimeError("dataset should not be sharded before split.")
        absolute_sizes = self._get_absolute_split_sizes(sizes)
        splits = []
        rows_to_skip = 0
        for size in absolute_sizes:
            ds = copy.deepcopy(self)
            if randomize:
                # want to shuffle the same way every epoch before split
                ds = ds.shuffle()
                ds.reshuffle_each_epoch = False
            if rows_to_skip > 0:
                ds = ds.skip(rows_to_skip)
            ds = ds.take(size)
            splits.append(ds)
            rows_to_skip += size
        return tuple(splits)
    @check_zip_dataset
    def zip(self, datasets):
        """
        Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.

        Args:
            datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset
                to be zipped together with this dataset.

        Returns:
            ZipDataset, dataset zipped.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # ds1 and ds2 are instances of Dataset object
            >>> # creates a dataset which is the combination of ds1 and ds2
            >>> data = ds1.zip(ds2)
        """
        if isinstance(datasets, tuple):
            datasets = (self, *datasets)
        elif isinstance(datasets, Dataset):
            datasets = (self, datasets)
        else:
            raise TypeError("The zip function %s type error!" % (datasets))
        return ZipDataset(datasets)

    @check_concat
    def concat(self, datasets):
        """
        Concatenates the datasets in the input list of datasets; the "+" operator is also
        overloaded to perform the concat operation.

        Note:
            The column name, column data type and rank of column data should be the same in the input datasets.

        Args:
            datasets (list or class Dataset): A list of datasets or a single class Dataset
                to be concatenated together with this dataset.

        Returns:
            ConcatDataset, dataset concatenated.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # ds1 and ds2 are instances of Dataset object
            >>> # creates a dataset by concatenating ds1 and ds2 with the "+" operator
            >>> data1 = ds1 + ds2
            >>> # creates a dataset by concatenating ds1 and ds2 with the concat operation
            >>> data1 = ds1.concat(ds2)
        """
        if isinstance(datasets, Dataset):
            datasets = [self] + [datasets]
        elif isinstance(datasets, list):
            datasets = [self] + datasets
        else:
            raise TypeError("The concat_dataset function %s type error!" % (datasets))
        return ConcatDataset(datasets)
    @check_rename
    def rename(self, input_columns, output_columns):
        """
        Renames the columns in input datasets.

        Args:
            input_columns (list[str]): list of names of the input columns.
            output_columns (list[str]): list of names of the output columns.

        Returns:
            RenameDataset, dataset renamed.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object.
            >>> input_columns = ["input_col1", "input_col2", "input_col3"]
            >>> output_columns = ["output_col1", "output_col2", "output_col3"]
            >>>
            >>> # creates a dataset where input_col1 is renamed to output_col1,
            >>> # input_col2 is renamed to output_col2, and input_col3 is renamed
            >>> # to output_col3.
            >>> data = data.rename(input_columns=input_columns, output_columns=output_columns)
        """
        return RenameDataset(self, input_columns, output_columns)

    @check_project
    def project(self, columns):
        """
        Projects certain columns in input datasets.

        The specified columns will be selected from the dataset and passed down
        the pipeline in the order specified. The other columns are discarded.

        Args:
            columns (list[str]): list of names of the columns to project.

        Returns:
            ProjectDataset, dataset projected.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object
            >>> columns_to_project = ["column3", "column1", "column2"]
            >>>
            >>> # creates a dataset that consists of column3, column1, column2
            >>> # in that order, regardless of the original order of columns.
            >>> data = data.project(columns=columns_to_project)
        """
        return ProjectDataset(self, columns)
    def apply(self, apply_func):
        """
        Apply a function to this dataset.

        The specified apply_func is a function that must take one 'Dataset' as an argument
        and return a preprocessed 'Dataset'.

        Args:
            apply_func (function): A function that must take one 'Dataset' as an argument and
                return a preprocessed 'Dataset'.

        Returns:
            Dataset, applied by the function.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object
            >>> # declare an apply_func function which returns a Dataset object
            >>> def apply_func(ds):
            >>>     ds = ds.batch(2)
            >>>     return ds
            >>> # use apply to call apply_func
            >>> data = data.apply(apply_func)

        Raises:
            TypeError: If apply_func is not a function.
            TypeError: If apply_func doesn't return a Dataset.
        """
        if not hasattr(apply_func, '__call__'):
            raise TypeError("apply_func must be a function.")
        dataset = apply_func(self)
        if not isinstance(dataset, Dataset):
            raise TypeError("apply_func must return a dataset.")
        return dataset
    def device_que(self, prefetch_size=None):
        """
        Returns a TransferDataset that transfers data through the device.

        Args:
            prefetch_size (int, optional): prefetch number of records ahead of the
                user's request (default=None).

        Note:
            If the device is Ascend, features of the data will be transferred one by one.
            Each transfer is limited to 256 MB of data.

        Returns:
            TransferDataset, dataset for transferring.
        """
        return self.to_device()

    def to_device(self, num_batch=None):
        """
        Transfers data through CPU, GPU or Ascend devices.

        Args:
            num_batch (int, optional): limit on the number of batches to be sent to the device (default=None).

        Note:
            If the device is Ascend, features of the data will be transferred one by one.
            Each transfer is limited to 256 MB of data.

        Returns:
            TransferDataset, dataset for transferring.

        Raises:
            TypeError: If device_type is empty.
            ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
            ValueError: If num_batch is None or 0 or larger than int_max.
            RuntimeError: If dataset is unknown.
            RuntimeError: If a distribution file path is given but fails to be read.
        """
        if num_batch is None:
            num_batch = self.get_dataset_size()
            repeat_count = self.get_repeat_count()
            num_batch = num_batch * repeat_count
        queue_name = str(uuid.uuid1())
        if context:
            device_type = context.get_context("device_target")
        else:
            device_type = "CPU"
        if device_type == "":
            raise TypeError("Please set device_type in context")
        if device_type not in ('Ascend', 'GPU', 'CPU'):
            raise ValueError("only support CPU, Ascend, GPU")
        if num_batch is None or num_batch == 0:
            raise ValueError("num_batch is None or 0.")

        def get_distribution(output_dataset):
            dev_id = 0
            if isinstance(output_dataset, (MindDataset)):
                return output_dataset.distribution, dev_id
            if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
                                           ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
                sampler = output_dataset.sampler
                if isinstance(sampler, samplers.DistributedSampler):
                    dev_id = sampler.shard_id
                return "", dev_id
            if isinstance(output_dataset, TFRecordDataset):
                if output_dataset.shard_id is not None:
                    dev_id = output_dataset.shard_id
                return "", dev_id
            if not output_dataset.input:
                raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
            input_dataset = output_dataset.input[0]
            return get_distribution(input_dataset)

        distribution_path, device_id = get_distribution(self)
        if distribution_path == "":
            return TransferDataset(self, queue_name, device_id, device_type, num_batch)
        try:
            with open(distribution_path, 'r') as distribution_f:
                dist = json.load(distribution_f)
                device_id = dist["deviceId"]
        except json.decoder.JSONDecodeError:
            raise RuntimeError("Json decode error when loading distribution file")
        except Exception:
            raise RuntimeError("Distribution file failed to read")
        return TransferDataset(self, queue_name, device_id, device_type, num_batch)
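    # Hedged usage sketch. The device target value is an assumption; it is
    # normally configured through mindspore.context before building the pipeline:
    #
    #   data = data.batch(32).repeat(10)
    #   transfer = data.to_device()  # num_batch defaults to dataset size * repeat count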
    def create_tuple_iterator(self, columns=None):
        """
        Create an Iterator over the dataset. The data retrieved will be a list of ndarrays of data.

        To specify which columns to list and the order needed, use the columns parameter.
        If columns is not provided, the order of the columns will not be changed.

        Args:
            columns (list[str], optional): List of columns to be used to specify the order of columns
                (default=None, means all columns).

        Returns:
            Iterator, list of ndarray.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object
            >>> # creates an iterator. The columns in the data obtained by the
            >>> # iterator will not be changed.
            >>> iterator = data.create_tuple_iterator()
            >>> for item in iterator:
            >>>     # convert the returned tuple to a list and print
            >>>     print(list(item))
        """
        return TupleIterator(self, columns)

    def create_dict_iterator(self):
        """
        Create an Iterator over the dataset.

        The data retrieved will be a dictionary. The order
        of the columns in the dictionary may not be the same as the original order.

        Returns:
            Iterator, dictionary of column_name-ndarray pairs.

        Examples:
            >>> import mindspore.dataset as ds
            >>> # data is an instance of Dataset object
            >>> # creates an iterator. The columns in the data obtained by the
            >>> # iterator might be changed.
            >>> iterator = data.create_dict_iterator()
            >>> for item in iterator:
            >>>     # print the data in column1
            >>>     print(item["column1"])
        """
        return DictIterator(self)
    def __iter__(self):
        """Create an Iterator over the dataset."""
        return self.create_tuple_iterator()

    @property
    def input_indexs(self):
        return self._input_indexs

    @input_indexs.setter
    def input_indexs(self, value):
        self._input_indexs = value

    def _get_pipeline_info(self):
        """
        Gets pipeline information.
        """
        device_iter = TupleIterator(self)
        self._output_shapes = device_iter.get_output_shapes()
        self._output_types = device_iter.get_output_types()
        if self._dataset_size is None:
            self._dataset_size = device_iter.get_dataset_size()
        self._batch_size = device_iter.get_batch_size()
        self._num_classes = device_iter.num_classes()
        self._repeat_count = device_iter.get_repeat_count()
        device_iter.release()
    def output_shapes(self):
        """
        Get the shapes of output data.

        Returns:
            List, list of shapes of each column.
        """
        if self._output_shapes is None:
            self._get_pipeline_info()
        return self._output_shapes

    def output_types(self):
        """
        Get the types of output data.

        Returns:
            List of data types.
        """
        if self._output_types is None:
            self._get_pipeline_info()
        return self._output_types
  867. def get_dataset_size(self):
  868. """
  869. Get the number of batches in an epoch.
  870. Return:
  871. Number, number of batches.
  872. """
  873. if self.input:
  874. return self.input[0].get_dataset_size()
  875. return None
  876. def num_classes(self):
  877. """
  878. Get the number of classes in a dataset.
  879. Return:
  880. Number, number of classes.
  881. """
  882. if self.input:
  883. return self.input[0].num_classes()
  884. return None
  885. def get_sync_notifiers(self):
  886. if self.input:
  887. return self.input[0].get_sync_notifiers()
  888. return {}
  889. def is_sync(self):
  890. if self.input:
  891. return self.input[0].is_sync()
  892. return False
  893. def sync_update(self, condition_name, num_batch=None, data=None):
  894. """
895. Release a blocking condition and trigger the callback with the given data.
  896. Args:
  897. condition_name (str): The condition name that is used to toggle sending next row.
898. num_batch (int or None): The number of batches (rows) that are released.
  899. When num_batch is None, it will default to the number specified by the sync_wait operator.
  900. data (dict or None): The data passed to the callback.
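Examples:
>>> # a minimal sketch: assumes data is a Dataset built with a sync_wait
>>> # operator registered under condition_name="cond"
>>> for item in data.create_dict_iterator():
>>> # after consuming a row, release the next batch through the barrier
>>> data.sync_update(condition_name="cond")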
  901. """
  902. notifiers_dict = self.get_sync_notifiers()
  903. if condition_name not in notifiers_dict:
  904. raise RuntimeError("Condition name not found")
  905. if num_batch is not None:
  906. num_batch *= self.get_batch_size()
  907. notifiers_dict[condition_name](num_batch, data)
  908. def get_batch_size(self):
  909. """
  910. Get the size of a batch.
  911. Return:
  912. Number, the number of data in a batch.
  913. """
  914. if self.input:
  915. return self.input[0].get_batch_size()
  916. return 1
  917. def get_repeat_count(self):
  918. """
919. Get the repeat count from RepeatDataset, or 1 if the pipeline contains no repeat.
  920. Return:
  921. Number, the count of repeat.
  922. """
  923. if self.input:
  924. return self.input[0].get_repeat_count()
  925. return 1
  926. def get_class_indexing(self):
  927. """
  928. Get the class index.
  929. Return:
930. Dict, a str-to-int mapping from label name to index.
  931. """
  932. if self.input:
  933. return self.input[0].get_class_indexing()
  934. raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))
  935. def reset(self):
  936. """Reset the dataset for next epoch."""
  937. def is_shuffled(self):
  938. for input_dataset in self.input:
  939. if input_dataset.is_shuffled():
  940. return True
  941. return False
  942. def is_sharded(self):
  943. for input_dataset in self.input:
  944. if input_dataset.is_sharded():
  945. return True
  946. return False
  947. class SourceDataset(Dataset):
  948. """
  949. Abstract class to represent a source dataset which produces content to the data pipeline.
  950. """
  951. # No need for __init__ since it is the same as the super's init
  952. @staticmethod
  953. def _find_files(patterns):
  954. """
  955. Utility function to search for files with the given glob patterns.
  956. Args:
  957. patterns (str or list[str]): string or list of patterns to be searched.
  958. Returns:
  959. List, files.
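Examples:
>>> # a minimal sketch: "/path/to/data" is a placeholder; recursive "**" patterns are supported
>>> file_list = SourceDataset._find_files(["/path/to/data/*.txt", "/path/to/data/**/*.json"])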
  960. """
  961. if not isinstance(patterns, list):
  962. patterns = [patterns]
  963. file_list = []
  964. unmatched_patterns = []
  965. for pattern in patterns:
  966. matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]
  967. if matches:
  968. file_list.extend(matches)
  969. else:
  970. unmatched_patterns.append(pattern)
  971. if unmatched_patterns:
  972. raise ValueError("The following patterns did not match any files: ", unmatched_patterns)
  973. if file_list: # not empty
  974. return file_list
  975. raise ValueError("The list of path names matching the patterns is empty.")
  976. def is_shuffled(self):
  977. raise NotImplementedError("SourceDataset must implement is_shuffled.")
  978. def is_sharded(self):
  979. raise NotImplementedError("SourceDataset must implement is_sharded.")
  980. class MappableDataset(SourceDataset):
  981. """
  982. Abstract class to represent a source dataset which supports use of samplers.
  983. """
  984. def __init__(self, num_parallel_workers=None):
  985. # check if all subclasses use this name
  986. super().__init__(num_parallel_workers)
  987. self.sampler = None
  988. def add_sampler(self, new_sampler):
  989. # note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
  990. # after first passing through the current samplers attached to this dataset.
  991. new_sampler.add_child(self.sampler)
  992. self.sampler = new_sampler
  993. def use_sampler(self, new_sampler):
  994. """
995. Make the current dataset use the new_sampler provided.
  996. Args:
  997. new_sampler (Sampler): the sampler to use for the current dataset.
  998. Returns:
  999. Dataset, that uses new_sampler.
  1000. Examples:
  1001. >>> import mindspore.dataset as ds
  1002. >>>
  1003. >>> dataset_dir = "/path/to/imagefolder_directory"
  1004. >>> # a SequentialSampler is created by default
  1005. >>> data = ds.ImageFolderDatasetV2(dataset_dir)
  1006. >>>
  1007. >>> # use a DistributedSampler instead of the SequentialSampler
  1008. >>> new_sampler = ds.DistributedSampler(10, 2)
  1009. >>> data.use_sampler(new_sampler)
  1010. """
  1011. self.sampler = self.sampler.child_sampler
  1012. self.add_sampler(new_sampler)
  1013. def is_shuffled(self):
  1014. raise NotImplementedError("MappableDataset must implement is_shuffled.")
  1015. def is_sharded(self):
  1016. raise NotImplementedError("MappableDataset must implement is_sharded.")
  1017. @check_split
  1018. def split(self, sizes, randomize=True):
  1019. """
1020. Split the dataset into smaller, non-overlapping datasets.
1021. This is the optimized split function, which is called automatically when the dataset
1022. on which split is invoked is a MappableDataset.
  1023. Args:
  1024. sizes (list of int or list of float): If a list of integers [s1, s2, …, sn] is
  1025. provided, the dataset will be split into n datasets of size s1, size s2, …, size sn
1026. respectively. If the sum of all sizes does not equal the original dataset size,
1027. an error will occur.
  1028. If a list of floats [f1, f2, …, fn] is provided, the dataset will be split into n
  1029. Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
  1030. of the original dataset. If after rounding, any size equals 0, an error will occur.
  1031. All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
  1032. randomize (bool): determines whether or not to split the data randomly. If true, the data
  1033. will be randomly split. Otherwise, each split will be created with consecutive rows
  1034. from the dataset.
  1035. Note:
  1036. 1. Dataset should not be sharded if split is going to be called. Instead, create a
  1037. DistributedSampler and specify a split to shard after splitting. If dataset is
  1038. sharded after a split, it is strongly recommended to set the same seed in each instance
  1039. of execution, otherwise each shard may not be part of the same split (see Examples)
  1040. 2. It is strongly recommended to not shuffle the dataset, but use randomize=True instead.
  1041. Shuffling the dataset may not be deterministic, which means the data in each split
  1042. will be different in each epoch. Furthermore, if sharding occurs after split, each
  1043. shard may not be part of the same split.
  1044. Raises:
  1045. RuntimeError: If get_dataset_size returns None or is not supported for this dataset.
  1046. RuntimeError: If sizes is list of integers and sum of all elements in sizes does not
  1047. equal the dataset size.
  1048. RuntimeError: If sizes is list of float and there is a split with size 0 after calculations.
  1049. RuntimeError: If the dataset is sharded prior to calling split.
  1050. ValueError: If sizes is list of float and not all floats are between 0 and 1, or if the
  1051. floats don’t sum to 1.
1052. Returns:
  1053. tuple(Dataset), a tuple of datasets that have been split.
  1054. Examples:
  1055. >>> import mindspore.dataset as ds
  1056. >>>
  1057. >>> dataset_dir = "/path/to/imagefolder_directory"
  1058. >>>
  1059. >>> # many datasets have shuffle on by default, set shuffle to False if split will be called!
  1060. >>> data = ds.ImageFolderDatasetV2(dataset_dir, shuffle=False)
  1061. >>>
  1062. >>> # sets the seed, and tells split to use this seed when randomizing. This
  1063. >>> # is needed because we are sharding later
  1064. >>> ds.config.set_seed(58)
  1065. >>> train, test = data.split([0.9, 0.1])
  1066. >>>
  1067. >>> # if we want to shard the train dataset, we can use a DistributedSampler
  1068. >>> train_sampler = ds.DistributedSampler(10, 2)
  1069. >>> train.use_sampler(train_sampler)
  1070. """
  1071. if self.is_shuffled():
  1072. logger.warning("dataset is shuffled before split.")
  1073. if self.is_sharded():
  1074. raise RuntimeError("dataset should not be sharded before split.")
  1075. absolute_sizes = self._get_absolute_split_sizes(sizes)
  1076. splits = []
  1077. current_split_start_index = 0
  1078. for size in absolute_sizes:
  1079. ds = copy.deepcopy(self)
  1080. if randomize:
  1081. # want to shuffle the same way every epoch before split, we are assuming
  1082. # that the user will call set_seed
  1083. random_sampler = samplers.RandomSampler()
  1084. random_sampler.reshuffle_each_epoch = False
  1085. ds.add_sampler(random_sampler)
  1086. subset_sampler = samplers.SubsetSampler(current_split_start_index, size)
  1087. ds.add_sampler(subset_sampler)
  1088. # add sequential sampler, so that if user calls use_sampler, we will
  1089. # get rid of the sequential sampler instead of something we need
  1090. ds.add_sampler(samplers.SequentialSampler())
  1091. splits.append(ds)
  1092. current_split_start_index += size
  1093. return tuple(splits)
  1094. class DatasetOp(Dataset):
  1095. """
1096. Abstract class to represent an operation on a dataset.
  1097. """
  1098. # No need for __init__ since it is the same as the super's init
  1099. class BatchDataset(DatasetOp):
  1100. """
  1101. The result of applying Batch operator to the input dataset.
  1102. Args:
  1103. input_dataset (Dataset): Input Dataset to be batched.
  1104. batch_size (int or function): The number of rows each batch is created with. An
  1105. int or callable which takes exactly 1 parameter, BatchInfo.
  1106. drop_remainder (bool, optional): Determines whether or not to drop the last
  1107. possibly incomplete batch (default=False). If True, and if there are less
  1108. than batch_size rows available to make the last batch, then those rows will
  1109. be dropped and not propagated to the child node.
  1110. num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None).
  1111. per_batch_map (callable, optional): Per batch map callable. A callable which takes
  1112. (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represent a batch of
  1113. Tensors on a given column. The number of lists should match with number of entries in input_columns. The
  1114. last parameter of the callable should always be a BatchInfo object.
  1115. input_columns (list of string, optional): List of names of the input columns. The size of the list should
  1116. match with signature of per_batch_map callable.
  1117. pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
1118. would pad column with name "col1" to a tensor of size [224,224] and fill the missing values with 0.
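Examples:
>>> # a minimal sketch of per_batch_map: assumes data is a Dataset with a
>>> # numeric column "col1"; the callable gets one list per input column
>>> # plus a BatchInfo object, and returns a tuple of output batches
>>> def add_batch_num(col1, batch_info):
>>> return ([x + batch_info.get_batch_num() for x in col1],)
>>> data = data.batch(batch_size=4, input_columns=["col1"], per_batch_map=add_batch_num)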
  1119. """
  1120. def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None,
  1121. per_batch_map=None, input_columns=None, pad_info=None):
  1122. super().__init__(num_parallel_workers)
  1123. if BatchDataset._is_ancestor_of_repeat(input_dataset):
  1124. logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
  1125. BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
  1126. self.batch_size = batch_size
  1127. self.drop_remainder = drop_remainder
  1128. self.per_batch_map = per_batch_map
  1129. self.input_columns = input_columns
  1130. self.pad_info = pad_info
  1131. self.input.append(input_dataset)
  1132. input_dataset.output.append(self)
  1133. self._input_indexs = input_dataset.input_indexs
  1134. def get_args(self):
  1135. args = super().get_args()
  1136. args["batch_size"] = self.batch_size
  1137. args["drop_remainder"] = self.drop_remainder
  1138. args["per_batch_map"] = self.per_batch_map
  1139. args["input_columns"] = self.input_columns
  1140. args["pad_info"] = self.pad_info
  1141. return args
  1142. def get_dataset_size(self):
  1143. """
  1144. Get the number of batches in an epoch.
  1145. Return:
  1146. Number, number of batches.
  1147. """
  1148. child_size = self.input[0].get_dataset_size()
  1149. if child_size is not None:
  1150. if self.drop_remainder:
  1151. return math.floor(child_size / self.batch_size)
  1152. return math.ceil(child_size / self.batch_size)
  1153. return None
  1154. def get_batch_size(self):
  1155. """
  1156. Get the size of a batch.
  1157. Return:
  1158. Number, the number of data in a batch.
  1159. """
  1160. return self.batch_size
  1161. @staticmethod
  1162. def _is_ancestor_of_repeat(dataset):
  1163. """
  1164. Utility function to find the case where repeat is used before batch.
  1165. Args:
  1166. dataset (Dataset): dataset to be checked.
  1167. Return:
  1168. True or False.
  1169. """
  1170. if isinstance(dataset, RepeatDataset):
  1171. return True
  1172. flag = False
  1173. for input_dataset in dataset.input:
  1174. flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
  1175. return flag
  1176. @staticmethod
  1177. def _update_batch_size_for_syncwait(dataset, batch_size):
  1178. """
  1179. Utility function to notify batch size to sync_wait.
  1180. Args:
  1181. dataset (Dataset): dataset to be checked.
1182. batch_size (int): batch size to notify.
  1183. """
  1184. if isinstance(dataset, SyncWaitDataset):
  1185. dataset.update_sync_batch_size(batch_size)
  1186. for input_dataset in dataset.input:
  1187. BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
  1188. class BatchInfo(CBatchInfo):
  1189. """
  1190. The information object associates with the current batch of tensors.
  1191. """
  1192. def get_batch_num(self):
  1193. """
  1194. Return the batch number of the current batch.
  1195. Return:
  1196. Number, number of the current batch.
  1197. """
  1198. return
  1199. def get_epoch_num(self):
  1200. """
  1201. Return the epoch number of the current batch.
  1202. Return:
  1203. Number, number of the current epoch.
  1204. """
  1205. return
  1206. class BlockReleasePair:
  1207. """
  1208. The blocking condition class used by SyncWaitDataset.
  1209. Args:
  1210. init_release_rows (int): Number of lines to allow through the pipeline.
1211. callback (function): The callback function that will be called when release is called.
  1212. """
  1213. def __init__(self, init_release_rows, callback=None):
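# row_count stays negative while rows may pass: block_func waits until
# row_count < 0 and increments it; release_func subtracts to let more rows through.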
  1214. self.row_count = -init_release_rows
  1215. self.cv = threading.Condition()
  1216. self.callback = callback
  1217. self.default_rows = init_release_rows
  1218. def __deepcopy__(self, memodict):
  1219. if id(self) in memodict:
  1220. return memodict[id(self)]
  1221. memodict[id(self)] = self
  1222. # condition variable and callback are the same, but reset the counter
  1223. self.reset()
  1224. return self
  1225. def reset(self):
  1226. with self.cv:
  1227. self.row_count = -self.default_rows
  1228. self.cv.notify_all()
  1229. def update_batched_size(self, batch_size):
1230. # should only be used before the pipeline is created
  1231. self.row_count *= batch_size
  1232. self.default_rows *= batch_size
  1233. def block_func(self):
  1234. with self.cv:
  1235. self.cv.wait_for(lambda: self.row_count < 0)
  1236. self.row_count += 1
  1237. return True
  1238. def release_func(self, pass_rows=None, data=None):
  1239. with self.cv:
  1240. if pass_rows is None:
  1241. pass_rows = self.default_rows
  1242. self.row_count -= pass_rows
  1243. if self.callback is not None:
  1244. self.callback(data)
  1245. self.cv.notify_all()
  1246. class SyncWaitDataset(DatasetOp):
  1247. """
  1248. The result of adding a blocking condition to the input Dataset.
  1249. Args:
  1250. input_dataset (Dataset): Input dataset to apply flow control.
1251. num_batch (int): The number of batches that may pass without blocking at the start of each epoch.
  1252. condition_name (str): The condition name that is used to toggle sending next row.
1253. callback (function): The callback function that will be invoked when sync_update is called.
  1254. Raises:
  1255. RuntimeError: If condition name already exists.
  1256. """
  1257. def __init__(self, input_dataset, condition_name, num_batch, callback=None):
  1258. super().__init__()
  1259. self.input.append(input_dataset)
  1260. input_dataset.output.append(self)
  1261. # set to the default value, waiting for the batch to update it
  1262. self._condition_name = condition_name
  1263. self._pair = BlockReleasePair(num_batch, callback)
  1264. if self._condition_name in self.input[0].get_sync_notifiers():
  1265. raise RuntimeError("Condition name is already in use")
  1266. logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging",
  1267. condition_name)
  1268. def get_sync_notifiers(self):
  1269. return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
  1270. def is_sync(self):
  1271. return True
  1272. def get_args(self):
  1273. args = super().get_args()
  1274. args["condition_name"] = self._condition_name
  1275. args["condition_func"] = self._pair.block_func
  1276. return args
  1277. def update_sync_batch_size(self, batch_size):
  1278. self._pair.update_batched_size(batch_size)
  1279. @staticmethod
  1280. def _is_ancestor_of_batch(dataset):
  1281. """
  1282. Utility function to find the case where sync_wait is used before batch.
  1283. Args:
  1284. dataset (Dataset): dataset to be checked.
  1285. Return:
  1286. True or False.
  1287. """
  1288. if isinstance(dataset, BatchDataset):
  1289. return True
  1290. flag = False
  1291. for input_dataset in dataset.input:
  1292. flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
  1293. return flag
  1294. class ShuffleDataset(DatasetOp):
  1295. """
  1296. The result of applying Shuffle operator to the input Dataset.
  1297. Args:
  1298. input_dataset (Dataset): Input Dataset to be shuffled.
  1299. buffer_size (int): The size of the buffer.
  1300. Raises:
1301. RuntimeError: If a sync operator exists before the shuffle.
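Examples:
>>> # a minimal sketch: assumes data is an existing Dataset instance
>>> data = data.shuffle(buffer_size=1000)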
  1302. """
  1303. def __init__(self, input_dataset, buffer_size):
  1304. super().__init__()
  1305. self.buffer_size = buffer_size
  1306. self.input.append(input_dataset)
  1307. self.reshuffle_each_epoch = None
  1308. input_dataset.output.append(self)
  1309. self._input_indexs = input_dataset.input_indexs
  1310. if self.is_sync():
  1311. raise RuntimeError("No shuffle after sync operators")
  1312. def get_args(self):
  1313. args = super().get_args()
  1314. args["buffer_size"] = self.buffer_size
  1315. if self.reshuffle_each_epoch is not None:
  1316. args["reshuffle_each_epoch"] = self.reshuffle_each_epoch
  1317. return args
  1318. def is_shuffled(self):
  1319. return True
  1320. # Pyfunc collection for multiprocess pyfunc
  1321. # This global variable will only be used within subprocesses
  1322. _GLOBAL_PYFUNC_LIST = []
  1323. # Pyfunc worker init function
1324. # The Python multiprocessing library forbids sending lambda functions through pipes.
1325. # This init function allows us to register all python functions in a global collection before the workers fork.
  1326. def _pyfunc_worker_init(pyfunc_list):
  1327. global _GLOBAL_PYFUNC_LIST
  1328. _GLOBAL_PYFUNC_LIST = pyfunc_list
  1329. # Pyfunc worker execution function
  1330. # All exceptions will be raised to main processes
  1331. def _pyfunc_worker_exec(index, *args):
  1332. try:
  1333. return _GLOBAL_PYFUNC_LIST[index](*args)
  1334. except KeyboardInterrupt:
  1335. raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
  1336. # PythonCallable wrapper for multiprocess pyfunc
  1337. class _PythonCallable:
  1338. """
  1339. Internal python function wrapper for multiprocessing pyfunc.
  1340. """
  1341. def __init__(self, py_callable, idx, pool=None):
  1342. # Original python callable from user.
  1343. self.py_callable = py_callable
  1344. # Process pool created for current iterator.
  1345. self.pool = pool
  1346. # Python callable index for subprocess _GLOBAL_PYFUNC_LIST
  1347. self.idx = idx
  1348. def __call__(self, *args):
  1349. if self.pool is not None:
  1350. try:
  1351. # This call will send the tensors along with Python callable index to the process pool.
  1352. # Block, yield GIL. Current thread will reacquire GIL once result is returned.
  1353. return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args])
  1354. except KeyboardInterrupt:
  1355. self.pool.terminate()
  1356. self.pool.join()
  1357. raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
  1358. # Invoke original python callable in master process in case the pool is gone.
  1359. return self.py_callable(*args)
  1360. class MapDataset(DatasetOp):
  1361. """
  1362. The result of applying Map operator to the input Dataset.
  1363. Args:
  1364. input_dataset (Dataset): Input Dataset to be mapped.
  1365. input_columns (list[str]): List of names of the input columns
  1366. (default=None, the operations will be applied on the first columns in the dataset).
  1367. The size of the list should match the number of inputs of the first operator.
  1368. operations (TensorOp): A function mapping a nested structure of tensors
  1369. to another nested structure of tensor (default=None).
  1370. output_columns (list[str], optional): list of names of the output columns.
  1371. The size of the list should match the number of outputs of the last operator
  1372. (default=None, output columns will be the input columns, i.e., the columns will
  1373. be replaced).
  1374. columns_order (list[str], optional): list of all the desired columns of the dataset (default=None).
  1375. The argument is mandatory if len(input_columns) != len(output_columns).
  1376. num_parallel_workers (int, optional): Number of workers to process the Dataset
  1377. in parallel (default=None).
1378. python_multiprocessing (bool, optional): Parallelize python operations with multiple worker processes. This
1379. option can be beneficial if the python operations are computationally heavy (default=False).
  1380. Raises:
  1381. ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
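Examples:
>>> # a minimal sketch: assumes data is a Dataset whose "image" column holds ndarrays
>>> data = data.map(input_columns=["image"], operations=[lambda x: x / 255.0])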
  1382. """
  1383. def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
  1384. num_parallel_workers=None, python_multiprocessing=False):
  1385. super().__init__(num_parallel_workers)
  1386. self.input.append(input_dataset)
  1387. if input_columns is not None and not isinstance(input_columns, list):
  1388. input_columns = [input_columns]
  1389. self.input_columns = input_columns
  1390. if operations is not None and not isinstance(operations, list):
  1391. operations = [operations]
  1392. self.operations = operations
  1393. if output_columns is not None and not isinstance(output_columns, list):
  1394. output_columns = [output_columns]
  1395. self.output_columns = output_columns
  1396. self.columns_order = columns_order
  1397. if self.input_columns and self.output_columns \
  1398. and len(self.input_columns) != len(self.output_columns) \
  1399. and self.columns_order is None:
  1400. raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.")
  1401. input_dataset.output.append(self)
  1402. self._input_indexs = input_dataset.input_indexs
  1403. self.python_multiprocessing = python_multiprocessing
  1404. self.process_pool = None
  1405. def get_args(self):
  1406. args = super().get_args()
  1407. args["input_columns"] = self.input_columns
  1408. args["operations"] = self.operations
  1409. args["output_columns"] = self.output_columns
  1410. return args
  1411. def get_dataset_size(self):
  1412. """
  1413. Get the number of batches in an epoch.
  1414. Return:
  1415. Number, number of batches.
  1416. """
  1417. return self.input[0].get_dataset_size()
  1418. def __deepcopy__(self, memodict):
  1419. if id(self) in memodict:
  1420. return memodict[id(self)]
  1421. cls = self.__class__
  1422. new_op = cls.__new__(cls)
  1423. memodict[id(self)] = new_op
  1424. new_op.input = copy.deepcopy(self.input, memodict)
  1425. new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
  1426. new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
  1427. new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
  1428. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  1429. new_op.output = copy.deepcopy(self.output, memodict)
  1430. new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
  1431. new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
  1432. new_op.operations = self.operations
  1433. return new_op
  1434. # Iterator bootstrap will be called on iterator construction.
  1435. # A deep copy of Dataset object is created prior of iterator_bootstrap.
  1436. # This method will create per iterator process pool and bind pyfunc execution to the pool.
  1437. def iterator_bootstrap(self):
  1438. """
  1439. Per iterator bootstrap callback.
  1440. """
  1441. if self.python_multiprocessing:
  1442. iter_specific_operations = []
  1443. callable_list = []
  1444. # Pass #1, look for python callables and build list
  1445. for op in self.operations:
  1446. if callable(op):
  1447. callable_list.append(op)
  1448. if callable_list:
  1449. # Construct pool with the callable list
  1450. # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses
  1451. self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
  1452. initializer=_pyfunc_worker_init,
  1453. initargs=(callable_list,))
  1454. # Pass #2
  1455. idx = 0
  1456. for op in self.operations:
  1457. if callable(op):
  1458. # Wrap python callable into _PythonCallable
  1459. iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
  1460. idx += 1
  1461. else:
  1462. # CPP ops remain the same
  1463. iter_specific_operations.append(op)
  1464. self.operations = iter_specific_operations
  1465. def __del__(self):
  1466. if hasattr(self, 'process_pool') and self.process_pool is not None:
  1467. self.process_pool.terminate()
  1468. class FilterDataset(DatasetOp):
  1469. """
  1470. The result of applying filter predicate to the input Dataset.
  1471. Args:
1472. input_dataset (Dataset): Input Dataset to be filtered.
1473. predicate (callable): Python callable which returns a boolean value; rows for which it returns False are filtered out.
1474. input_columns (list[str], optional): List of names of the input columns
1475. (default=None, the predicate will be applied to all columns in the dataset).
  1476. num_parallel_workers (int, optional): Number of workers to process the Dataset
  1477. in parallel (default=None).
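Examples:
>>> # a minimal sketch: assumes data is a Dataset with a scalar "label" column
>>> data = data.filter(predicate=lambda label: label != 0, input_columns=["label"])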
  1478. """
  1479. def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
  1480. super().__init__(num_parallel_workers)
  1481. self.predicate = lambda *args: bool(predicate(*args))
  1482. self.input.append(input_dataset)
  1483. input_dataset.output.append(self)
  1484. if input_columns is not None and not isinstance(input_columns, list):
  1485. input_columns = [input_columns]
  1486. self.input_columns = input_columns
  1487. def get_args(self):
  1488. args = super().get_args()
  1489. args["predicate"] = self.predicate
  1490. args["input_columns"] = self.input_columns
  1491. return args
  1492. def get_dataset_size(self):
  1493. """
  1494. Get the number of batches in an epoch.
1495. The size cannot be determined before the pipeline is run.
  1496. Return:
  1497. 0
  1498. """
  1499. return 0
  1500. class RepeatDataset(DatasetOp):
  1501. """
  1502. The result of applying Repeat operator to the input Dataset.
  1503. Args:
  1504. input_dataset (Dataset): Input Dataset to be repeated.
  1505. count (int): Number of times the dataset should be repeated.
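Examples:
>>> # a minimal sketch: repeat the dataset twice; assumes data is a Dataset
>>> data = data.repeat(2)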
  1506. """
  1507. def __init__(self, input_dataset, count):
  1508. super().__init__()
  1509. if count is None:
  1510. self.count = -1
  1511. else:
  1512. self.count = count
  1513. self.input.append(input_dataset)
  1514. input_dataset.output.append(self)
  1515. self._input_indexs = input_dataset.input_indexs
  1516. def get_args(self):
  1517. args = super().get_args()
  1518. args["count"] = self.count
  1519. return args
  1520. def get_dataset_size(self):
  1521. """
  1522. Get the number of batches in an epoch.
  1523. Return:
  1524. Number, number of batches.
  1525. """
  1526. child_size = self.input[0].get_dataset_size()
  1527. if child_size is not None:
  1528. return child_size
  1529. return None
  1530. def get_repeat_count(self):
  1531. """
  1532. Get the replication times in RepeatDataset.
  1533. Return:
  1534. Number, the count of repeat.
  1535. """
  1536. return self.count
  1537. class SkipDataset(DatasetOp):
  1538. """
  1539. The result of applying Skip operator to the input Dataset.
  1540. Args:
1541. input_dataset (Dataset): Input dataset to have rows skipped.
1542. count (int): Number of rows in the dataset to be skipped.
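Examples:
>>> # a minimal sketch: drop the first 5 rows; assumes data is a Dataset
>>> data = data.skip(5)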
  1543. """
  1544. def __init__(self, input_dataset, count):
  1545. super().__init__()
  1546. self.count = count
  1547. self.input.append(input_dataset)
  1548. input_dataset.output.append(self)
  1549. self._input_indexs = input_dataset.input_indexs
  1550. def get_args(self):
  1551. args = super().get_args()
  1552. args["count"] = self.count
  1553. return args
  1554. def get_dataset_size(self):
  1555. """
  1556. Get the number of batches in an epoch.
  1557. Return:
  1558. Number, number of batches.
  1559. """
  1560. child_size = self.input[0].get_dataset_size()
  1561. output_size = 0
1562. if 0 <= self.count < child_size:
  1563. output_size = child_size - self.count
  1564. return output_size
  1565. class TakeDataset(DatasetOp):
  1566. """
  1567. The result of applying Take operator to the input Dataset.
  1568. Args:
  1569. input_dataset (Dataset): Input Dataset to be taken element from.
  1570. count (int): Number of elements to be taken from the dataset.
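Examples:
>>> # a minimal sketch: keep only the first 10 rows; assumes data is a Dataset
>>> data = data.take(10)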
  1571. """
  1572. def __init__(self, input_dataset, count):
  1573. super().__init__()
  1574. self.count = count
  1575. self.input.append(input_dataset)
  1576. input_dataset.output.append(self)
  1577. self._input_indexs = input_dataset.input_indexs
  1578. def get_args(self):
  1579. args = super().get_args()
  1580. args["count"] = self.count
  1581. return args
  1582. def get_dataset_size(self):
  1583. """
  1584. Get the number of batches in an epoch.
  1585. Return:
  1586. Number, number of batches.
  1587. """
  1588. child_size = self.input[0].get_dataset_size()
  1589. if child_size < self.count:
  1590. return child_size
  1591. return self.count
  1592. class ZipDataset(DatasetOp):
  1593. """
  1594. The result of applying Zip operator to the input Dataset.
  1595. Args:
  1596. datasets (tuple): A tuple of datasets to be zipped together.
  1597. Raises:
  1598. TypeError: If dataset is not an instance of Dataset.
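Examples:
>>> import mindspore.dataset as ds
>>> # a minimal sketch: data1 and data2 are assumed Datasets with disjoint column names
>>> data = ds.zip((data1, data2))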
  1599. """
  1600. def __init__(self, datasets):
  1601. super().__init__()
  1602. for dataset in datasets:
  1603. if not isinstance(dataset, Dataset):
  1604. raise TypeError("The parameter %s of zip has type error!" % (dataset))
  1605. self.datasets = datasets
  1606. for data in datasets:
  1607. self.input.append(data)
  1608. data.output.append(self)
  1609. def get_dataset_size(self):
  1610. """
  1611. Get the number of batches in an epoch.
  1612. Return:
  1613. Number, number of batches.
  1614. """
  1615. children_sizes = [c.get_dataset_size() for c in self.input]
  1616. if all(c is not None for c in children_sizes):
  1617. return min(children_sizes)
  1618. return None
  1619. def num_classes(self):
  1620. """
  1621. Get the number of classes in a dataset.
  1622. Return:
  1623. Number, number of classes.
  1624. """
  1625. return None
  1626. def is_sync(self):
1627. return any(c.is_sync() for c in self.input)
  1628. def get_args(self):
  1629. args = super().get_args()
  1630. return args
  1631. class ConcatDataset(DatasetOp):
  1632. """
  1633. The result of applying concat dataset operator to the input Dataset.
  1634. Args:
  1635. datasets (list): A list of datasets to be concatenated together.
  1636. Raises:
  1637. TypeError: If dataset is not an instance of Dataset.
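Examples:
>>> # a minimal sketch: data1 and data2 are assumed Datasets with identical columns;
>>> # concat is assumed to be exposed on Dataset, mirroring this operator
>>> data = data1.concat(data2)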
  1638. """
  1639. def __init__(self, datasets):
  1640. super().__init__()
  1641. for dataset in datasets:
  1642. if not isinstance(dataset, Dataset):
  1643. raise TypeError("The parameter %s of concat has type error!" % (dataset))
  1644. self.datasets = datasets
  1645. for data in datasets:
  1646. self.input.append(data)
  1647. data.output.append(self)
  1648. def get_dataset_size(self):
  1649. """
  1650. Get the number of batches in an epoch.
  1651. Return:
  1652. Number, number of batches.
  1653. """
  1654. children_sizes = [c.get_dataset_size() for c in self.input]
  1655. dataset_size = np.sum(children_sizes)
  1656. return dataset_size
  1657. class RenameDataset(DatasetOp):
  1658. """
  1659. The result of applying Rename operator to the input Dataset.
  1660. Args:
1661. input_dataset (Dataset): Input Dataset to be renamed.
1662. input_columns (list[str]): List of names of the input columns.
1663. output_columns (list[str]): List of names of the output columns.
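Examples:
>>> # a minimal sketch: assumes data is a Dataset with a column "col1"
>>> data = data.rename(input_columns=["col1"], output_columns=["col1_renamed"])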
  1664. """
  1665. def __init__(self, input_dataset, input_columns, output_columns):
  1666. super().__init__()
  1667. if not isinstance(input_columns, list):
  1668. input_columns = [input_columns]
  1669. if not isinstance(output_columns, list):
  1670. output_columns = [output_columns]
  1671. self.input_column_names = input_columns
  1672. self.output_column_names = output_columns
  1673. self.input.append(input_dataset)
  1674. input_dataset.output.append(self)
  1675. self._input_indexs = input_dataset.input_indexs
  1676. def get_args(self):
  1677. args = super().get_args()
  1678. args["input_columns"] = self.input_column_names
  1679. args["output_columns"] = self.output_column_names
  1680. return args
  1681. class ProjectDataset(DatasetOp):
  1682. """
  1683. The result of applying Project operator to the input Dataset.
  1684. Args:
1685. input_dataset (Dataset): Input Dataset to be projected.
  1686. columns (list[str]): List of names of the columns to project.
  1687. prefetch_size (int, optional): Prefetch number of records ahead of the
  1688. user's request (default=None).
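Examples:
>>> # a minimal sketch: keep only these columns, in this order; assumes data is a Dataset
>>> data = data.project(columns=["image", "label"])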
  1689. """
  1690. def __init__(self, input_dataset, columns, prefetch_size=None):
  1691. super().__init__()
  1692. if not isinstance(columns, list):
  1693. columns = [columns]
  1694. self.columns = columns
  1695. self.input.append(input_dataset)
  1696. self.prefetch_size = prefetch_size
  1697. input_dataset.output.append(self)
  1698. self._input_indexs = input_dataset.input_indexs
  1699. def get_args(self):
  1700. args = super().get_args()
  1701. args["columns"] = self.columns
  1702. args["prefetch_size"] = self.prefetch_size
  1703. return args
  1704. class TransferDataset(DatasetOp):
  1705. """
  1706. The result of applying TDT operator to the input Dataset.
  1707. Args:
  1708. input_dataset (Dataset): Input Dataset to be transferred.
  1709. queue_name (str): Name of device queue.
  1710. device_id (int): Id of device.
  1711. device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
1712. num_batch (int): Limit on the number of batches to be sent to the device (default=None).
  1713. """
  1714. def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
  1715. super().__init__()
  1716. self.input.append(input_dataset)
  1717. input_dataset.output.append(self)
  1718. self.queue_name = queue_name
  1719. self._input_indexs = input_dataset.input_indexs
  1720. self._device_type = device_type
  1721. self._device_id = device_id
  1722. self.__num_batch = num_batch
  1723. self.iterator = None
  1724. def get_args(self):
  1725. args = super().get_args()
  1726. args["queue_name"] = self.queue_name
  1727. args["device_type"] = self._device_type
  1728. args["device_id"] = self._device_id
  1729. args["num_batch"] = self.__num_batch
  1730. return args
  1731. def create_dict_iterator(self):
  1732. raise RuntimeError("TransferDataset is not iterable")
  1733. def create_tuple_iterator(self, columns=None):
  1734. raise RuntimeError("TransferDataset is not iterable")
  1735. def __iter__(self):
  1736. raise RuntimeError("TransferDataset is not iterable")
  1737. def output_shapes(self):
  1738. raise RuntimeError("TransferDataset does not support output_shapes")
  1739. def output_types(self):
  1740. raise RuntimeError("TransferDataset does not support output_types")
  1741. def send(self):
  1742. # need to keep iterator alive so the executionTree is not destroyed
  1743. self.iterator = TupleIterator(self)
  1744. class RangeDataset(MappableDataset):
  1745. """
1746. A source dataset that generates a sequence of numbers from start to stop with the given step.
  1747. Args:
  1748. start (int): starting index.
  1749. stop (int): ending index.
  1750. step (int): step size in a range.
  1751. """
  1752. def __init__(self, start, stop, step):
  1753. super().__init__()
  1754. self.start = start
  1755. self.stop = stop
  1756. self.step = step
  1757. def get_args(self):
  1758. args = super().get_args()
  1759. args["start"] = self.start
  1760. args["stop"] = self.stop
  1761. args["step"] = self.step
  1762. return args
  1763. def is_shuffled(self):
  1764. return False
  1765. def is_sharded(self):
  1766. return False
  1767. def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
  1768. """
  1769. Create sampler based on user input.
  1770. Args:
  1771. num_samples (int): Number of samples.
  1772. input_sampler (Iterable / Sampler): Sampler from user.
  1773. shuffle (bool): Shuffle.
  1774. num_shards (int): Number of shard for sharding.
  1775. shard_id (int): Shard ID.
  1776. """
  1777. if shuffle is None:
  1778. if input_sampler is not None:
  1779. # If shuffle is not specified, user provided sampler, use user's sampler
  1780. return input_sampler
  1781. if num_shards is not None:
  1782. # If shuffle is not specified, sharding enabled, use distributed random sampler
  1783. shuffle = True
  1784. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
  1785. # If shuffle is not specified, sharding disabled, use random sampler
  1786. if num_samples is not None:
  1787. return samplers.RandomSampler(replacement=True, num_samples=num_samples)
  1788. return samplers.RandomSampler()
  1789. if shuffle is True:
  1790. if num_shards is not None:
  1791. # If shuffle enabled, sharding enabled, use distributed random sampler
  1792. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
  1793. # If shuffle enabled, sharding disabled, use random sampler
  1794. if num_samples is not None:
  1795. return samplers.RandomSampler(replacement=True, num_samples=num_samples)
  1796. return samplers.RandomSampler()
  1797. if num_shards is not None:
  1798. # If shuffle disabled, sharding enabled, use distributed sequential sampler
  1799. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
  1800. # If shuffle disabled, sharding disabled, use sequential sampler
  1801. return samplers.SequentialSampler()
  1802. class ImageFolderDatasetV2(MappableDataset):
  1803. """
  1804. A source dataset that reads images from a tree of directories.
  1805. All images within one folder have the same label.
  1806. The generated dataset has two columns ['image', 'label'].
  1807. The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
  1808. otherwise.
  1809. The type of the image tensor is uint8. The label is just a scalar uint64
  1810. tensor.
  1811. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  1812. below shows what input args are allowed and their expected behavior.
  1813. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  1814. :widths: 25 25 50
  1815. :header-rows: 1
  1816. * - Parameter 'sampler'
  1817. - Parameter 'shuffle'
  1818. - Expected Order Behavior
  1819. * - None
  1820. - None
  1821. - random order
  1822. * - None
  1823. - True
  1824. - random order
  1825. * - None
  1826. - False
  1827. - sequential order
  1828. * - Sampler object
  1829. - None
  1830. - order defined by sampler
  1831. * - Sampler object
  1832. - True
  1833. - not allowed
  1834. * - Sampler object
  1835. - False
  1836. - not allowed
  1837. Args:
  1838. dataset_dir (str): Path to the root directory that contains the dataset.
  1839. num_samples (int, optional): The number of images to be included in the dataset
  1840. (default=None, all images).
  1841. num_parallel_workers (int, optional): Number of workers to read the data
  1842. (default=None, set in the config).
  1843. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  1844. (default=None, expected order behavior shown in the table).
  1845. sampler (Sampler, optional): Object used to choose samples from the
  1846. dataset (default=None, expected order behavior shown in the table).
  1847. extensions (list[str], optional): List of file extensions to be
  1848. included in the dataset (default=None).
  1849. class_indexing (dict, optional): A str-to-int mapping from folder name to index
  1850. (default=None, the folder names will be sorted
  1851. alphabetically and each class will be given a
  1852. unique index starting from 0).
  1853. decode (bool, optional): decode the images after reading (default=False).
  1854. num_shards (int, optional): Number of shards that the dataset should be divided
  1855. into (default=None).
  1856. shard_id (int, optional): The shard ID within num_shards (default=None). This
  1857. argument should be specified only when num_shards is also specified.
  1858. Raises:
  1859. RuntimeError: If sampler and shuffle are specified at the same time.
  1860. RuntimeError: If sampler and sharding are specified at the same time.
  1861. RuntimeError: If num_shards is specified but shard_id is None.
  1862. RuntimeError: If shard_id is specified but num_shards is None.
  1863. RuntimeError: If class_indexing is not a dictionary.
  1864. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  1865. Examples:
  1866. >>> import mindspore.dataset as ds
  1867. >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
  1868. >>> dataset_dir = "/path/to/imagefolder_directory"
  1869. >>> # 1) read all samples (image files) in dataset_dir with 8 threads
  1870. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8)
  1871. >>> # 2) read all samples (image files) from folder cat and folder dog with label 0 and 1
1872. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir, class_indexing={"cat": 0, "dog": 1})
  1873. >>> # 3) read all samples (image files) in dataset_dir with extensions .JPEG and .png (case sensitive)
  1874. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir, extensions={".JPEG",".png"})
  1875. """
  1876. @check_imagefolderdatasetv2
  1877. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  1878. shuffle=None, sampler=None, extensions=None, class_indexing=None,
  1879. decode=False, num_shards=None, shard_id=None):
  1880. super().__init__(num_parallel_workers)
  1881. self.dataset_dir = dataset_dir
  1882. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  1883. self.num_samples = num_samples
  1884. self.shuffle_level = shuffle
  1885. self.extensions = extensions
  1886. self.class_indexing = class_indexing
  1887. self.decode = decode
  1888. self.num_shards = num_shards
  1889. self.shard_id = shard_id
  1890. def get_args(self):
  1891. args = super().get_args()
  1892. args["dataset_dir"] = self.dataset_dir
  1893. args["num_samples"] = self.num_samples
  1894. args["sampler"] = self.sampler
  1895. args["shuffle"] = self.shuffle_level
  1896. args["extensions"] = self.extensions
  1897. args["class_indexing"] = self.class_indexing
  1898. args["decode"] = self.decode
  1899. args["num_shards"] = self.num_shards
  1900. args["shard_id"] = self.shard_id
  1901. return args
  1902. def get_dataset_size(self):
  1903. """
  1904. Get the number of batches in an epoch.
  1905. Return:
  1906. Number, number of batches.
  1907. """
  1908. if self.num_samples is None:
  1909. num_samples = 0
  1910. else:
  1911. num_samples = self.num_samples
  1912. num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
  1913. return get_num_rows(num_rows, self.num_shards)
  1914. def num_classes(self):
  1915. """
  1916. Get the number of classes in dataset.
  1917. Return:
  1918. Number, number of classes.
  1919. """
  1920. if self.num_samples is None:
  1921. num_samples = 0
  1922. else:
  1923. num_samples = self.num_samples
  1924. return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1]
  1925. def is_shuffled(self):
  1926. if self.shuffle_level is None:
  1927. return True
  1928. return self.shuffle_level or self.sampler.is_shuffled()
  1929. def is_sharded(self):
  1930. if self.num_shards is not None:
  1931. return self.num_shards > 1
  1932. return self.sampler.is_sharded()
  1933. class MnistDataset(MappableDataset):
  1934. """
  1935. A source dataset for reading and parsing the Mnist dataset.
  1936. The generated dataset has two columns ['image', 'label'].
  1937. The type of the image tensor is uint8. The label is just a scalar uint32 tensor.
  1938. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  1939. below shows what input args are allowed and their expected behavior.
  1940. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  1941. :widths: 25 25 50
  1942. :header-rows: 1
  1943. * - Parameter 'sampler'
  1944. - Parameter 'shuffle'
  1945. - Expected Order Behavior
  1946. * - None
  1947. - None
  1948. - random order
  1949. * - None
  1950. - True
  1951. - random order
  1952. * - None
  1953. - False
  1954. - sequential order
  1955. * - Sampler object
  1956. - None
  1957. - order defined by sampler
  1958. * - Sampler object
  1959. - True
  1960. - not allowed
  1961. * - Sampler object
  1962. - False
  1963. - not allowed
  1964. Args:
  1965. dataset_dir (str): Path to the root directory that contains the dataset.
  1966. num_samples (int, optional): The number of images to be included in the dataset
  1967. (default=None, all images).
  1968. num_parallel_workers (int, optional): Number of workers to read the data
1969. (default=None, set in the config).
  1970. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  1971. (default=None, expected order behavior shown in the table).
  1972. sampler (Sampler, optional): Object used to choose samples from the
  1973. dataset (default=None, expected order behavior shown in the table).
  1974. num_shards (int, optional): Number of shards that the dataset should be divided
  1975. into (default=None).
  1976. shard_id (int, optional): The shard ID within num_shards (default=None). This
  1977. argument should be specified only when num_shards is also specified.
  1978. Raises:
  1979. RuntimeError: If sampler and shuffle are specified at the same time.
  1980. RuntimeError: If sampler and sharding are specified at the same time.
  1981. RuntimeError: If num_shards is specified but shard_id is None.
  1982. RuntimeError: If shard_id is specified but num_shards is None.
  1983. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  1984. Examples:
  1985. >>> import mindspore.dataset as ds
  1986. >>> dataset_dir = "/path/to/mnist_folder"
  1987. >>> # 1) read 3 samples from mnist_dataset
  1988. >>> mnist_dataset = ds.MnistDataset(dataset_dir=dataset_dir, num_samples=3)
  1989. >>> # in mnist_dataset dataset, each dictionary has keys "image" and "label"
  1990. """
  1991. @check_mnist_cifar_dataset
  1992. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  1993. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  1994. super().__init__(num_parallel_workers)
  1995. self.dataset_dir = dataset_dir
  1996. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  1997. self.num_samples = num_samples
  1998. self.shuffle_level = shuffle
  1999. self.num_shards = num_shards
  2000. self.shard_id = shard_id
  2001. def get_args(self):
  2002. args = super().get_args()
  2003. args["dataset_dir"] = self.dataset_dir
  2004. args["num_samples"] = self.num_samples
  2005. args["shuffle"] = self.shuffle_level
  2006. args["sampler"] = self.sampler
  2007. args["num_shards"] = self.num_shards
  2008. args["shard_id"] = self.shard_id
  2009. return args
  2010. def get_dataset_size(self):
  2011. """
  2012. Get the number of batches in an epoch.
  2013. Return:
  2014. Number, number of batches.
  2015. """
  2016. if self.num_samples is None:
  2017. num_samples = 0
  2018. else:
  2019. num_samples = self.num_samples
  2020. num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
  2021. return get_num_rows(num_rows, self.num_shards)
  2022. def is_shuffled(self):
  2023. if self.shuffle_level is None:
  2024. return True
  2025. return self.shuffle_level or self.sampler.is_shuffled()
  2026. def is_sharded(self):
  2027. if self.num_shards is not None:
  2028. return self.num_shards > 1
  2029. return self.sampler.is_sharded()
  2030. class MindDataset(SourceDataset):
  2031. """
2032. A source dataset that reads from MindRecord shard files and their index database.
  2033. Args:
  2034. dataset_file (str, list[str]): One of file names or file list in dataset.
  2035. columns_list (list[str], optional): List of columns to be read (default=None).
  2036. num_parallel_workers (int, optional): The number of readers (default=None).
  2037. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  2038. (default=None, performs shuffle).
  2039. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
  2040. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2041. argument should be specified only when num_shards is also specified.
  2042. block_reader (bool, optional): Whether read data by block mode (default=False).
  2043. sampler (Sampler, optional): Object used to choose samples from the
  2044. dataset (default=None, sampler is exclusive
2045. with shuffle and block_reader). Supported samplers: SubsetRandomSampler,
2046. PKSampler.
  2047. Raises:
  2048. ValueError: If num_shards is specified but shard_id is None.
  2049. ValueError: If shard_id is specified but num_shards is None.
  2050. ValueError: If block reader is true but partition is specified.
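Examples:
>>> import mindspore.dataset as ds
>>> # a minimal sketch: the file path and column names below are placeholders
>>> data = ds.MindDataset("/path/to/data.mindrecord", columns_list=["data", "label"])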
  2051. """
  2052. @check_minddataset
  2053. def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
  2054. shuffle=None, num_shards=None, shard_id=None,
  2055. block_reader=False, sampler=None):
  2056. super().__init__(num_parallel_workers)
  2057. if isinstance(dataset_file, list):
  2058. self.load_dataset = False
  2059. else:
  2060. self.load_dataset = True
  2061. self.dataset_file = dataset_file
  2062. self.columns_list = columns_list
  2063. self.global_shuffle = shuffle
  2064. self.distribution = ""
  2065. self.sampler = sampler
  2066. if num_shards is None or shard_id is None:
  2067. self.partitions = None
  2068. else:
  2069. self.partitions = [num_shards, shard_id]
2070. if block_reader is True and self.partitions is not None:
2071. raise ValueError("block_reader not allowed when partitions are used")
2072. if block_reader is True and shuffle is True:
2073. raise ValueError("block_reader not allowed when shuffle is used")
2074. if block_reader is True:
2075. logger.warning("Global shuffle is not used.")
2076. if sampler is not None:
2077. if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
2078. isinstance(sampler, samplers.PKSampler) is False:
2079. raise ValueError("This sampler is not supported yet.")
2080. # sampler exclusive
2081. if block_reader is True and sampler is not None:
2082. raise ValueError("block_reader not allowed when a sampler is used")
2083. if shuffle is not None and sampler is not None:
2084. raise ValueError("shuffle not allowed when a sampler is used")
2085. if block_reader is False and sampler is None:
2086. self.global_shuffle = shuffle is not False
  2087. self.num_shards = num_shards
  2088. self.shard_id = shard_id
  2089. self.block_reader = block_reader
  2090. def get_args(self):
  2091. args = super().get_args()
  2092. args["dataset_file"] = self.dataset_file
  2093. args["load_dataset"] = self.load_dataset
  2094. args["columns_list"] = self.columns_list
  2095. args["global_shuffle"] = self.global_shuffle
  2096. args["partitions"] = self.partitions
  2097. args["block_reader"] = self.block_reader
  2098. args["num_shards"] = self.num_shards
  2099. args["shard_id"] = self.shard_id
  2100. args["sampler"] = self.sampler
  2101. return args
  2102. def get_dataset_size(self):
  2103. """
  2104. Get the number of batches in an epoch.
  2105. Return:
  2106. Number, number of batches.
  2107. """
  2108. if self.load_dataset:
  2109. dataset_file = [self.dataset_file]
  2110. else:
  2111. dataset_file = self.dataset_file
  2112. num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler)
  2113. if self.partitions is not None and self.partitions[0] > 0:
  2114. if num_rows % self.partitions[0] == 0:
  2115. num_rows = num_rows // self.partitions[0]
  2116. else:
  2117. num_rows = num_rows // self.partitions[0] + 1
  2118. return num_rows
  2119. def is_shuffled(self):
  2120. if self.global_shuffle is None:
  2121. return True
  2122. return self.global_shuffle or self.sampler.is_shuffled()
  2123. def is_sharded(self):
  2124. if self.num_shards is not None:
  2125. return self.num_shards > 1
  2126. return self.sampler.is_sharded()
  2127. def _iter_fn(dataset, num_samples):
  2128. """
  2129. Generator function wrapper for iterable dataset.
  2130. """
  2131. if num_samples is not None:
  2132. ds_iter = iter(dataset)
  2133. for _ in range(num_samples):
  2134. try:
  2135. val = next(ds_iter)
  2136. except StopIteration:
  2137. return
  2138. # convert output tensors to ndarrays
  2139. yield tuple([np.array(x, copy=False) for x in val])
  2140. else:
  2141. for val in dataset:
  2142. # convert output tensors to ndarrays
  2143. yield tuple([np.array(x, copy=False) for x in val])
  2144. def _generator_fn(generator, num_samples):
  2145. """
  2146. Generator function wrapper for generator function dataset.
  2147. """
  2148. if num_samples is not None:
  2149. gen_iter = generator()
  2150. for _ in range(num_samples):
  2151. try:
  2152. val = next(gen_iter)
  2153. except StopIteration:
  2154. return
  2155. yield val
  2156. else:
  2157. gen_iter = generator()
  2158. for val in gen_iter:
  2159. yield val
  2160. def _py_sampler_fn(sampler, num_samples, dataset):
  2161. """
  2162. Generator function wrapper for mappable dataset with python sampler.
  2163. """
  2164. if num_samples is not None:
  2165. sampler_iter = iter(sampler)
  2166. for _ in range(num_samples):
  2167. try:
  2168. idx = next(sampler_iter)
  2169. except StopIteration:
  2170. return
  2171. val = dataset[idx]
  2172. # convert output tensors to ndarrays
  2173. yield tuple([np.array(x, copy=False) for x in val])
  2174. else:
  2175. for i in sampler:
  2176. val = dataset[i]
  2177. # convert output tensors to ndarrays
  2178. yield tuple([np.array(x, copy=False) for x in val])
  2179. def _cpp_sampler_fn(sampler, dataset):
  2180. """
  2181. Generator function wrapper for mappable dataset with cpp sampler.
  2182. """
  2183. indices = sampler.get_indices()
  2184. for i in indices:
  2185. val = dataset[i]
  2186. # convert output tensors to ndarrays
  2187. yield tuple([np.array(x, copy=False) for x in val])
  2188. def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
  2189. """
  2190. Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
  2191. """
  2192. indices = sampler.get_indices()
  2193. return _sampler_fn_mp(indices, dataset, num_worker)
  2194. def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
  2195. """
  2196. Multiprocessing generator function wrapper for mappable dataset with python sampler.
  2197. """
  2198. indices = _fetch_py_sampler_indices(sampler, num_samples)
  2199. return _sampler_fn_mp(indices, dataset, num_worker)
  2200. def _fetch_py_sampler_indices(sampler, num_samples):
  2201. """
  2202. Indices fetcher for python sampler.
  2203. """
  2204. if num_samples is not None:
  2205. sampler_iter = iter(sampler)
  2206. ret = []
  2207. for _ in range(num_samples):
  2208. try:
  2209. val = next(sampler_iter)
  2210. ret.append(val)
  2211. except StopIteration:
  2212. break
  2213. return ret
  2214. return [i for i in sampler]
  2215. def _fill_worker_indices(workers, indices, idx):
  2216. """
  2217. Worker index queue filler, fill worker index queue in round robin order.
  2218. """
  2219. num_worker = len(workers)
  2220. while idx < len(indices):
  2221. try:
  2222. workers[idx % num_worker].put(indices[idx])
  2223. idx += 1
  2224. except queue.Full:
  2225. break
  2226. return idx
  2227. def _sampler_fn_mp(indices, dataset, num_worker):
  2228. """
  2229. Multiprocessing generator function wrapper master process.
  2230. """
  2231. workers = []
  2232. # Event for end of epoch
  2233. eoe = multiprocessing.Event()
  2234. # Create workers
  2235. for _ in range(num_worker):
  2236. worker = _GeneratorWorker(dataset, eoe)
  2237. worker.daemon = True
  2238. workers.append(worker)
  2239. # Fill initial index queues
  2240. idx_cursor = 0
  2241. idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
  2242. # Start all workers
  2243. for w in workers:
  2244. w.start()
  2245. # Fetch results
  2246. for i in range(len(indices)):
  2247. # Fetch result and put index
  2248. try:
  2249. result = workers[i % num_worker].get()
  2250. except queue.Empty:
  2251. raise Exception("Generator worker process timeout")
  2252. except KeyboardInterrupt:
  2253. for w in workers:
  2254. w.terminate()
  2255. w.join()
  2256. raise Exception("Generator worker receives KeyboardInterrupt")
  2257. if idx_cursor < len(indices):
  2258. idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
  2259. # Set eoe event once all indices are sent
  2260. if idx_cursor == len(indices) and not eoe.is_set():
  2261. eoe.set()
  2262. yield tuple([np.array(x, copy=False) for x in result])
  2263. def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
  2264. """
  2265. Multiprocessing generator worker process loop.
  2266. """
  2267. while True:
  2268. # Fetch index, block
  2269. try:
  2270. idx = idx_queue.get()
  2271. except KeyboardInterrupt:
  2272. raise Exception("Generator worker receives KeyboardInterrupt")
  2273. if idx is None:
  2274. # When the queue is out of scope from master process, a None item can be fetched from the queue.
  2275. # Upon receiving None, worker process should check if EOE is set.
  2276. assert eoe.is_set(), ""
  2277. return
  2278. # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
  2279. result = dataset[idx]
  2280. # Send data, block
  2281. try:
  2282. result_queue.put(result)
  2283. except KeyboardInterrupt:
  2284. raise Exception("Generator worker receives KeyboardInterrupt")
  2285. del result, idx
  2286. class _GeneratorWorker(multiprocessing.Process):
  2287. """
  2288. Worker process for multiprocess Generator.
  2289. """
  2290. def __init__(self, dataset, eoe):
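# bounded queues (maxsize 16) provide back-pressure between the master process and this worker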
  2291. self.idx_queue = multiprocessing.Queue(16)
  2292. self.res_queue = multiprocessing.Queue(16)
  2293. super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe))
  2294. def put(self, item):
  2295. """
  2296. Put function for worker index queue. Never block. Raise queue.Full on failure.
  2297. """
  2298. self.idx_queue.put_nowait(item)
  2299. def get(self):
  2300. """
  2301. Get function for worker result queue. Block with timeout.
  2302. """
  2303. return self.res_queue.get(timeout=5)
  2304. def __del__(self):
  2305. self.terminate()
  2306. class GeneratorDataset(MappableDataset):
  2307. """
2308. A source dataset that generates data from Python by invoking the Python data source each epoch.
  2309. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2310. below shows what input args are allowed and their expected behavior.
  2311. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2312. :widths: 25 25 50
  2313. :header-rows: 1
  2314. * - Parameter 'sampler'
  2315. - Parameter 'shuffle'
  2316. - Expected Order Behavior
  2317. * - None
  2318. - None
  2319. - random order
  2320. * - None
  2321. - True
  2322. - random order
  2323. * - None
  2324. - False
  2325. - sequential order
  2326. * - Sampler object
  2327. - None
  2328. - order defined by sampler
  2329. * - Sampler object
  2330. - True
  2331. - not allowed
  2332. * - Sampler object
  2333. - False
  2334. - not allowed
  2335. Args:
  2336. source (Callable/Iterable/Random Accessible):
  2337. A generator callable object, an iterable python object or a random accessible python object.
2338. Callable source is required to return a tuple of numpy arrays as a row of the dataset on next(source()).
2339. Iterable source is required to return a tuple of numpy arrays as a row of the dataset on next(iter(source)).
2340. Random accessible source is required to return a tuple of numpy arrays as a row of the dataset on
  2341. source[idx].
  2342. column_names (list[str], optional): List of column names of the dataset (default=None). Users are required to
  2343. provide either column_names or schema.
  2344. column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
  2345. If provided, sanity check will be performed on generator output.
  2346. schema (Schema/String, optional): Path to the json schema file or schema object (default=None). Users are
  2347. required to provide either column_names or schema. If both are provided, schema will be used.
  2348. num_samples (int, optional): The number of samples to be included in the dataset
2349. (default=None, all samples).
  2350. num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
  2351. shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
  2352. (default=None, expected order behavior shown in the table).
  2353. sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
  2354. required (default=None, expected order behavior shown in the table).
  2355. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
2356. This argument should be specified only when num_samples is None. Random accessible input is required.
  2357. shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
  2358. when num_shards is also specified. Random accessible input is required.
  2359. Examples:
  2360. >>> import mindspore.dataset as ds
  2361. >>> # 1) Multidimensional generator function as callable input
  2362. >>> def generator_md():
  2363. >>> for i in range(64):
  2364. >>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
  2365. >>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
  2366. >>> multi_dimension_generator_dataset = ds.GeneratorDataset(generator_md, ["multi_dimensional_data"])
  2367. >>> # 2) Multi-column generator function as callable input
  2368. >>> def generator_mc(maxid = 64):
  2369. >>> for i in range(maxid):
  2370. >>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
  2371. >>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
  2372. >>> multi_column_generator_dataset = ds.GeneratorDataset(generator_mc, ["col1", "col2"])
  2373. >>> # 3) Iterable dataset as iterable input
  2374. >>> class MyIterable():
  2375. >>> def __iter__(self):
  2376. >>> return # User implementation
  2377. >>> # create iterable_generator_dataset with MyIterable object
  2378. >>> iterable_generator_dataset = ds.GeneratorDataset(MyIterable(), ["col1"])
  2379. >>> # 4) Random accessible dataset as Random accessible input
  2380. >>> class MyRA():
  2381. >>> def __getitem__(self, index):
  2382. >>> return # User implementation
  2383. >>> # create ra_generator_dataset with MyRA object
  2384. >>> ra_generator_dataset = ds.GeneratorDataset(MyRA(), ["col1"])
  2385. >>> # List/Dict/Tuple is also random accessible
2386. >>> list_generator = ds.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"])
  2387. >>> # 5) Built-in Sampler
  2388. >>> my_generator = ds.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
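>>> # 6) Sharded input for distributed training (a sketch; my_ds as in example 5)
>>> sharded_generator = ds.GeneratorDataset(my_ds, ["img", "label"], num_shards=2, shard_id=0)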
  2389. >>>
  2390. """
  2391. @check_generatordataset
  2392. def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
  2393. num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
  2394. super().__init__(num_parallel_workers)
  2395. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2396. if self.sampler is not None and hasattr(source, "__getitem__"):
  2397. if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  2398. samplers.RandomSampler, samplers.SubsetRandomSampler,
  2399. samplers.WeightedRandomSampler, samplers.Sampler)):
  2400. if num_samples is None:
  2401. num_samples = len(source)
  2402. sampler_instance = self.sampler.create()
  2403. sampler_instance.set_num_rows(len(source))
  2404. sampler_instance.set_num_samples(num_samples)
  2405. sampler_instance.initialize()
  2406. if num_parallel_workers > 1:
  2407. self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers))
  2408. else:
  2409. self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
  2410. else:
  2411. if num_parallel_workers > 1:
  2412. self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers))
  2413. else:
  2414. self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
  2415. else:
  2416. try:
  2417. iter(source)
  2418. except TypeError:
  2419. # Use generator function if input callable
  2420. self.source = (lambda: _generator_fn(source, num_samples))
  2421. else:
  2422. # Use iterator function if input is iterable
  2423. # Random accessible input is also iterable
  2424. self.source = (lambda: _iter_fn(source, num_samples))
  2425. if column_names is not None and not isinstance(column_names, list):
  2426. column_names = [column_names]
  2427. self.column_names = column_names
  2428. if column_types is not None:
  2429. self.column_types = mstypelist_to_detypelist(column_types)
  2430. else:
  2431. self.column_types = column_types
  2432. if schema is not None:
  2433. self.schema = schema
  2434. if not isinstance(schema, Schema):
  2435. self.schema = Schema(schema)
  2436. self.column_names = []
  2437. self.column_types = []
  2438. for col in self.schema.columns:
  2439. self.column_names.append(col["name"])
  2440. self.column_types.append(DataType(col["type"]))
  2441. def get_args(self):
  2442. args = super().get_args()
  2443. args["source"] = self.source
  2444. args["column_names"] = self.column_names
  2445. args["column_types"] = self.column_types
  2446. return args
  2447. def get_dataset_size(self):
  2448. """
  2449. Get the number of batches in an epoch.
  2450. Return:
  2451. Number, number of batches.
  2452. """
  2453. return self._dataset_size
  2454. # manually set dataset_size as a temporary solution.
  2455. def set_dataset_size(self, value):
  2456. if value >= 0:
  2457. self._dataset_size = value
  2458. else:
  2459. raise ValueError('set dataset_size with negative value {}'.format(value))
  2460. def __deepcopy__(self, memodict):
  2461. if id(self) in memodict:
  2462. return memodict[id(self)]
  2463. cls = self.__class__
  2464. new_op = cls.__new__(cls)
  2465. memodict[id(self)] = new_op
  2466. new_op.input = copy.deepcopy(self.input, memodict)
  2467. new_op.output = copy.deepcopy(self.output, memodict)
  2468. new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
  2469. new_op.column_types = copy.deepcopy(self.column_types, memodict)
  2470. new_op.column_names = copy.deepcopy(self.column_names, memodict)
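# source and sampler are shared rather than deep-copied: they may wrap lambdas or other
# user objects that do not support deepcopy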
  2471. new_op.source = self.source
  2472. new_op.sampler = self.sampler
  2473. return new_op
  2474. def is_shuffled(self):
  2475. return self.sampler.is_shuffled()
  2476. def is_sharded(self):
  2477. return self.sampler.is_sharded()
  2478. class TFRecordDataset(SourceDataset):
  2479. """
  2480. A source dataset that reads and parses datasets stored on disk in TFData format.
  2481. Args:
  2482. dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
  2483. files. The list will be sorted in a lexicographical order.
  2484. schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
  2485. If the schema is not provided, the meta data from the TFData file is considered the schema.
  2486. columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
2487. num_samples (int, optional): number of samples (rows) to read (default=None).
2488. If num_samples is None and numRows (parsed from schema) does not exist, read the full dataset;
2489. If num_samples is None and numRows (parsed from schema) is greater than 0, read numRows rows;
2490. If both num_samples and numRows (parsed from schema) are greater than 0, read num_samples rows.
  2491. num_parallel_workers (int, optional): number of workers to read the data
  2492. (default=None, number set in the config).
  2493. shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
  2494. If shuffle is False, no shuffling will be performed;
2495. If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
2496. otherwise, there are two levels of shuffling:
  2497. - Shuffle.GLOBAL: Shuffle both the files and samples.
  2498. - Shuffle.FILES: Shuffle files only.
  2499. num_shards (int, optional): Number of shards that the dataset should be divided
  2500. into (default=None).
  2501. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2502. argument should be specified only when num_shards is also specified.
2503. shard_equal_rows (bool, optional): Whether to get equal rows for all shards (default=False). If shard_equal_rows
2504. is False, the number of rows in each shard may differ.
  2505. Examples:
  2506. >>> import mindspore.dataset as ds
  2507. >>> import mindspore.common.dtype as mstype
2508. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more TF data files
  2509. >>> # 1) get all rows from dataset_files with no explicit schema:
  2510. >>> # The meta-data in the first row will be used as a schema.
  2511. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files)
  2512. >>> # 2) get all rows from dataset_files with user-defined schema:
  2513. >>> schema = ds.Schema()
2514. >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
  2515. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema=schema)
  2516. >>> # 3) get all rows from dataset_files with schema file "./schema.json":
  2517. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
  2518. """
  2519. @check_tfrecorddataset
  2520. def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
  2521. shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False):
  2522. super().__init__(num_parallel_workers)
  2523. self.dataset_files = self._find_files(dataset_files)
  2524. self.dataset_files.sort()
  2525. self.num_shards = num_shards
  2526. self.shard_id = shard_id
  2527. schema_obj = None
  2528. if (schema is not None) and (not isinstance(schema, Schema)):
  2529. schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
  2530. self.schema = schema
  2531. self.columns_list = columns_list
  2532. self.num_samples = num_samples
  2533. if schema_obj is not None and num_samples is None:
  2534. self.num_samples = schema_obj.num_rows
  2535. if not isinstance(shuffle, (bool, Shuffle)):
2536. raise TypeError("shuffle should be a boolean or an enum 'Shuffle'.")
  2537. if not isinstance(shuffle, Shuffle):
  2538. if shuffle:
  2539. self.shuffle_level = Shuffle.GLOBAL
  2540. self.shuffle_files = True
  2541. else:
  2542. self.shuffle_level = None
  2543. self.shuffle_files = False
  2544. else:
  2545. self.shuffle_level = shuffle
  2546. self.shuffle_files = True
  2547. self.shard_equal_rows = shard_equal_rows
  2548. def get_args(self):
  2549. args = super().get_args()
  2550. args["dataset_files"] = self.dataset_files
  2551. if self.schema is not None:
  2552. if isinstance(self.schema, Schema):
2553. self.schema.dataset_type = 'TF'  # Schema stores this in dataset_type; a datasetType attribute would be ignored by to_json()
  2554. if self.num_samples is not None:
  2555. self.schema.num_rows = self.num_samples
  2556. args["schema_json_string"] = self.schema.to_json()
  2557. else:
  2558. args["schema_file_path"] = self.schema
  2559. args["schema"] = self.schema
  2560. args["columns_list"] = self.columns_list
  2561. args["num_samples"] = self.num_samples
  2562. if self.shuffle_files is not None:
  2563. args["shuffle_files"] = self.shuffle_files
  2564. args["shuffle"] = self.shuffle_level
  2565. args["num_shards"] = self.num_shards
  2566. args["shard_id"] = self.shard_id
  2567. args["shard_equal_rows"] = self.shard_equal_rows
  2568. return args
  2569. def get_dataset_size(self, estimate=False):
  2570. """
  2571. Get the number of batches in an epoch.
  2572. Args:
  2573. estimate (bool, optional): Fast estimation of the dataset size instead of a full scan.
  2574. Return:
  2575. Number, number of batches.
  2576. """
  2577. if self._dataset_size is None:
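# count rows by scanning the files (the 8 is presumably a reader thread count); estimate=True trades accuracy for speed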
  2578. num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
  2579. num_rows = get_num_rows(num_rows, self.num_shards)
  2580. if self.num_samples is None:
  2581. return num_rows
  2582. return min(self.num_samples, num_rows)
  2583. return self._dataset_size
2584. # manually set dataset_size as a temporary solution.
  2585. def set_dataset_size(self, value):
  2586. logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
  2587. if value >= 0:
  2588. self._dataset_size = value
  2589. else:
  2590. raise ValueError('set dataset_size with negative value {}'.format(value))
  2591. def is_shuffled(self):
  2592. return self.shuffle_files
  2593. def is_sharded(self):
  2594. if self.num_shards is not None:
  2595. return self.num_shards > 1
  2596. return False
  2597. class ManifestDataset(MappableDataset):
  2598. """
  2599. A source dataset that reads images from a manifest file.
  2600. The generated dataset has two columns ['image', 'label'].
  2601. The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
  2602. otherwise.
  2603. The type of the image tensor is uint8. The label is just a scalar uint64
  2604. tensor.
  2605. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2606. below shows what input args are allowed and their expected behavior.
  2607. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2608. :widths: 25 25 50
  2609. :header-rows: 1
  2610. * - Parameter 'sampler'
  2611. - Parameter 'shuffle'
  2612. - Expected Order Behavior
  2613. * - None
  2614. - None
  2615. - random order
  2616. * - None
  2617. - True
  2618. - random order
  2619. * - None
  2620. - False
  2621. - sequential order
  2622. * - Sampler object
  2623. - None
  2624. - order defined by sampler
  2625. * - Sampler object
  2626. - True
  2627. - not allowed
  2628. * - Sampler object
  2629. - False
  2630. - not allowed
  2631. Args:
  2632. dataset_file (str): File to be read.
2633. usage (str, optional): The data subset to read, one of "train", "eval" or "inference" (default="train").
  2634. num_samples (int, optional): The number of images to be included in the dataset.
  2635. (default=None, all images).
  2636. num_parallel_workers (int, optional): Number of workers to read the data
  2637. (default=None, number set in the config).
  2638. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  2639. order behavior shown in the table).
  2640. sampler (Sampler, optional): Object used to choose samples from the
  2641. dataset (default=None, expected order behavior shown in the table).
  2642. class_indexing (dict, optional): A str-to-int mapping from label name to index
  2643. (default=None, the folder names will be sorted alphabetically and each
  2644. class will be given a unique index starting from 0).
2645. decode (bool, optional): decode the images after reading (default=False).
  2646. num_shards (int, optional): Number of shards that the dataset should be divided
  2647. into (default=None).
  2648. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2649. argument should be specified only when num_shards is also specified.
  2650. Raises:
  2651. RuntimeError: If sampler and shuffle are specified at the same time.
  2652. RuntimeError: If sampler and sharding are specified at the same time.
  2653. RuntimeError: If num_shards is specified but shard_id is None.
  2654. RuntimeError: If shard_id is specified but num_shards is None.
  2655. RuntimeError: If class_indexing is not a dictionary.
  2656. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2657. Examples:
  2658. >>> import mindspore.dataset as ds
  2659. >>> dataset_file = "/path/to/manifest_file.manifest"
  2660. >>> # 1) read all samples specified in manifest_file dataset with 8 threads for training:
  2661. >>> manifest_dataset = ds.ManifestDataset(dataset_file, usage="train", num_parallel_workers=8)
  2662. >>> # 2) reads samples (specified in manifest_file.manifest) for shard 0 in a 2-way distributed training setup:
  2663. >>> manifest_dataset = ds.ManifestDataset(dataset_file, num_shards=2, shard_id=0)
  2664. """
  2665. @check_manifestdataset
  2666. def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None,
  2667. shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None):
  2668. super().__init__(num_parallel_workers)
  2669. self.dataset_file = dataset_file
  2670. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2671. if class_indexing is not None and not isinstance(class_indexing, dict):
  2672. raise RuntimeError("class_indexing should be a dictionary.")
  2673. self.num_samples = num_samples
  2674. self.class_indexing = class_indexing
  2675. self.decode = decode
  2676. self.usage = usage
  2677. self.shuffle_level = shuffle
  2678. self.num_shards = num_shards
  2679. self.shard_id = shard_id
  2680. def get_args(self):
  2681. args = super().get_args()
  2682. args["dataset_file"] = self.dataset_file
  2683. args["usage"] = self.usage
  2684. args["num_samples"] = self.num_samples
  2685. args["shuffle"] = self.shuffle_level
  2686. args["sampler"] = self.sampler
  2687. args["class_indexing"] = self.class_indexing
  2688. args["decode"] = self.decode
  2689. args["num_shards"] = self.num_shards
  2690. args["shard_id"] = self.shard_id
  2691. return args
  2692. def get_dataset_size(self):
  2693. """
  2694. Get the number of batches in an epoch.
  2695. Return:
  2696. Number, number of batches.
  2697. """
  2698. if self.num_samples is None:
  2699. num_samples = 0
  2700. else:
  2701. num_samples = self.num_samples
  2702. if self.class_indexing is None:
  2703. class_indexing = dict()
  2704. else:
  2705. class_indexing = self.class_indexing
  2706. num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
  2707. return get_num_rows(num_rows, self.num_shards)
  2708. def num_classes(self):
  2709. """
  2710. Get the number of classes in a dataset.
  2711. Return:
  2712. Number, number of classes.
  2713. """
  2714. if self.num_samples is None:
  2715. num_samples = 0
  2716. else:
  2717. num_samples = self.num_samples
  2718. if self.class_indexing is None:
  2719. class_indexing = dict()
  2720. else:
  2721. class_indexing = self.class_indexing
  2722. return ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[1]
  2723. def get_class_indexing(self):
  2724. """
  2725. Get the class index.
  2726. Return:
  2727. Dict, A str-to-int mapping from label name to index.
  2728. """
  2729. if self.num_samples is None:
  2730. num_samples = 0
  2731. else:
  2732. num_samples = self.num_samples
  2733. if self.class_indexing is None:
  2734. class_indexing = dict()
  2735. else:
  2736. class_indexing = self.class_indexing
  2737. return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage)
  2738. def is_shuffled(self):
  2739. if self.shuffle_level is None:
  2740. return True
  2741. return self.shuffle_level or self.sampler.is_shuffled()
  2742. def is_sharded(self):
  2743. if self.num_shards is not None:
  2744. return self.num_shards > 1
  2745. return self.sampler.is_sharded()
  2746. class Cifar10Dataset(MappableDataset):
  2747. """
  2748. A source dataset that reads cifar10 data.
  2749. The generated dataset has two columns ['image', 'label'].
  2750. The type of the image tensor is uint8. The label is just a scalar uint32
  2751. tensor.
  2752. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2753. below shows what input args are allowed and their expected behavior.
  2754. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2755. :widths: 25 25 50
  2756. :header-rows: 1
  2757. * - Parameter 'sampler'
  2758. - Parameter 'shuffle'
  2759. - Expected Order Behavior
  2760. * - None
  2761. - None
  2762. - random order
  2763. * - None
  2764. - True
  2765. - random order
  2766. * - None
  2767. - False
  2768. - sequential order
  2769. * - Sampler object
  2770. - None
  2771. - order defined by sampler
  2772. * - Sampler object
  2773. - True
  2774. - not allowed
  2775. * - Sampler object
  2776. - False
  2777. - not allowed
  2778. Args:
  2779. dataset_dir (str): Path to the root directory that contains the dataset.
  2780. num_samples (int, optional): The number of images to be included in the dataset.
  2781. (default=None, all images).
  2782. num_parallel_workers (int, optional): Number of workers to read the data
  2783. (default=None, number set in the config).
  2784. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  2785. order behavior shown in the table).
  2786. sampler (Sampler, optional): Object used to choose samples from the
  2787. dataset (default=None, expected order behavior shown in the table).
  2788. num_shards (int, optional): Number of shards that the dataset should be divided
  2789. into (default=None).
  2790. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2791. argument should be specified only when num_shards is also specified.
  2792. Raises:
  2793. RuntimeError: If sampler and shuffle are specified at the same time.
  2794. RuntimeError: If sampler and sharding are specified at the same time.
  2795. RuntimeError: If num_shards is specified but shard_id is None.
  2796. RuntimeError: If shard_id is specified but num_shards is None.
  2797. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2798. Examples:
  2799. >>> import mindspore.dataset as ds
  2800. >>> dataset_dir = "/path/to/cifar10_dataset_directory"
  2801. >>> # 1) get all samples from CIFAR10 dataset in sequence:
2802. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir, shuffle=False)
2803. >>> # 2) randomly select 350 samples from CIFAR10 dataset:
2804. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir, num_samples=350, shuffle=True)
2805. >>> # 3) get samples from CIFAR10 dataset for shard 0 in a 2 way distributed training:
2806. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir, num_shards=2, shard_id=0)
  2807. >>> # in CIFAR10 dataset, each dictionary has keys "image" and "label"
  2808. """
  2809. @check_mnist_cifar_dataset
  2810. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  2811. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  2812. super().__init__(num_parallel_workers)
  2813. self.dataset_dir = dataset_dir
  2814. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2815. self.num_samples = num_samples
  2816. self.num_shards = num_shards
  2817. self.shard_id = shard_id
  2818. self.shuffle_level = shuffle
  2819. def get_args(self):
  2820. args = super().get_args()
  2821. args["dataset_dir"] = self.dataset_dir
  2822. args["num_samples"] = self.num_samples
  2823. args["sampler"] = self.sampler
  2824. args["num_shards"] = self.num_shards
  2825. args["shard_id"] = self.shard_id
  2826. args["shuffle"] = self.shuffle_level
  2827. return args
  2828. def get_dataset_size(self):
  2829. """
  2830. Get the number of batches in an epoch.
  2831. Return:
  2832. Number, number of batches.
  2833. """
  2834. if self.num_samples is None:
  2835. num_samples = 0
  2836. else:
  2837. num_samples = self.num_samples
  2838. num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
  2839. return get_num_rows(num_rows, self.num_shards)
  2840. def is_shuffled(self):
  2841. if self.shuffle_level is None:
  2842. return True
  2843. return self.shuffle_level or self.sampler.is_shuffled()
  2844. def is_sharded(self):
  2845. if self.num_shards is not None:
  2846. return self.num_shards > 1
  2847. return self.sampler.is_sharded()
  2848. class Cifar100Dataset(MappableDataset):
  2849. """
  2850. A source dataset that reads cifar100 data.
  2851. The generated dataset has three columns ['image', 'coarse_label', 'fine_label'].
2852. The type of the image tensor is uint8. The coarse_label and fine_label are each a scalar uint32
  2853. tensor.
  2854. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2855. below shows what input args are allowed and their expected behavior.
  2856. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2857. :widths: 25 25 50
  2858. :header-rows: 1
  2859. * - Parameter 'sampler'
  2860. - Parameter 'shuffle'
  2861. - Expected Order Behavior
  2862. * - None
  2863. - None
  2864. - random order
  2865. * - None
  2866. - True
  2867. - random order
  2868. * - None
  2869. - False
  2870. - sequential order
  2871. * - Sampler object
  2872. - None
  2873. - order defined by sampler
  2874. * - Sampler object
  2875. - True
  2876. - not allowed
  2877. * - Sampler object
  2878. - False
  2879. - not allowed
  2880. Args:
  2881. dataset_dir (str): Path to the root directory that contains the dataset.
  2882. num_samples (int, optional): The number of images to be included in the dataset.
  2883. (default=None, all images).
  2884. num_parallel_workers (int, optional): Number of workers to read the data
  2885. (default=None, number set in the config).
  2886. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  2887. order behavior shown in the table).
  2888. sampler (Sampler, optional): Object used to choose samples from the
  2889. dataset (default=None, expected order behavior shown in the table).
  2890. num_shards (int, optional): Number of shards that the dataset should be divided
  2891. into (default=None).
  2892. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2893. argument should be specified only when num_shards is also specified.
  2894. Raises:
  2895. RuntimeError: If sampler and shuffle are specified at the same time.
  2896. RuntimeError: If sampler and sharding are specified at the same time.
  2897. RuntimeError: If num_shards is specified but shard_id is None.
  2898. RuntimeError: If shard_id is specified but num_shards is None.
  2899. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2900. Examples:
  2901. >>> import mindspore.dataset as ds
  2902. >>> dataset_dir = "/path/to/cifar100_dataset_directory"
  2903. >>> # 1) get all samples from CIFAR100 dataset in sequence:
2904. >>> cifar100_dataset = ds.Cifar100Dataset(dataset_dir=dataset_dir, shuffle=False)
2905. >>> # 2) randomly select 350 samples from CIFAR100 dataset:
2906. >>> cifar100_dataset = ds.Cifar100Dataset(dataset_dir=dataset_dir, num_samples=350, shuffle=True)
  2907. >>> # in CIFAR100 dataset, each dictionary has 3 keys: "image", "fine_label" and "coarse_label"
  2908. """
  2909. @check_mnist_cifar_dataset
  2910. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  2911. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  2912. super().__init__(num_parallel_workers)
  2913. self.dataset_dir = dataset_dir
  2914. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2915. self.num_samples = num_samples
  2916. self.num_shards = num_shards
  2917. self.shard_id = shard_id
  2918. self.shuffle_level = shuffle
  2919. def get_args(self):
  2920. args = super().get_args()
  2921. args["dataset_dir"] = self.dataset_dir
  2922. args["num_samples"] = self.num_samples
  2923. args["sampler"] = self.sampler
  2924. args["num_shards"] = self.num_shards
  2925. args["shard_id"] = self.shard_id
  2926. args["shuffle"] = self.shuffle_level
  2927. return args
  2928. def get_dataset_size(self):
  2929. """
  2930. Get the number of batches in an epoch.
  2931. Return:
  2932. Number, number of batches.
  2933. """
  2934. if self.num_samples is None:
  2935. num_samples = 0
  2936. else:
  2937. num_samples = self.num_samples
  2938. num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
  2939. return get_num_rows(num_rows, self.num_shards)
  2940. def is_shuffled(self):
  2941. if self.shuffle_level is None:
  2942. return True
  2943. return self.shuffle_level or self.sampler.is_shuffled()
  2944. def is_sharded(self):
  2945. if self.num_shards is not None:
  2946. return self.num_shards > 1
  2947. return self.sampler.is_sharded()
  2948. class RandomDataset(SourceDataset):
  2949. """
  2950. A source dataset that generates random data.
  2951. Args:
  2952. num_samples (int): number of samples to generate.
  2953. schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
2954. If the schema is not provided, the columns to generate are determined randomly.
  2955. columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
  2956. num_parallel_workers (int, optional): number of workers to read the data
  2957. (default=None, number set in the config).
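Examples:
>>> import mindspore.dataset as ds
>>> import mindspore.common.dtype as mstype
>>> # a minimal sketch: generate 8 random rows that match a user-defined schema
>>> schema = ds.Schema()
>>> schema.add_column('data', de_type=mstype.int32, shape=[2])
>>> random_dataset = ds.RandomDataset(schema=schema, num_samples=8)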
  2958. """
  2959. def __init__(self, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None):
  2960. super().__init__(num_parallel_workers)
  2961. schema_obj = None
  2962. if (schema is not None) and (not isinstance(schema, Schema)):
  2963. schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
  2964. self.schema = schema
  2965. self.columns_list = columns_list
  2966. self.num_samples = num_samples
  2967. if schema_obj is not None and num_samples is None:
  2968. self.num_samples = schema_obj.num_rows
  2969. def get_args(self):
  2970. args = super().get_args()
  2971. if self.schema is not None:
  2972. if isinstance(self.schema, Schema):
2973. self.schema.dataset_type = 'Random'  # see Schema.to_json(), which reads dataset_type
  2974. if self.num_samples is not None:
  2975. self.schema.num_rows = self.num_samples
  2976. args["schema_json_string"] = self.schema.to_json()
  2977. else:
  2978. args["schema_file_path"] = self.schema
  2979. args["schema"] = self.schema
  2980. if self.columns_list is not None:
  2981. args["columns_list"] = self.columns_list
  2982. if self.num_samples is not None:
  2983. args["num_samples"] = self.num_samples
  2984. return args
  2985. def get_dataset_size(self):
  2986. """
  2987. Get the number of batches in an epoch.
  2988. Return:
  2989. Number, number of batches.
  2990. """
2991. return self.num_samples
  2992. def is_shuffled(self):
  2993. return True
  2994. def is_sharded(self):
  2995. return False
  2996. class Schema:
  2997. """
  2998. Class to represent a schema of dataset.
  2999. Args:
  3000. schema_file(str): Path of schema file (default=None).
  3001. Return:
  3002. Schema object, schema info about dataset.
  3003. Raises:
  3004. RuntimeError: If schema file failed to load.
  3005. Example:
  3006. >>> import mindspore.dataset as ds
  3007. >>> import mindspore.common.dtype as mstype
  3008. >>> # create schema, specify column name, mindspore.dtype and shape of the column
  3009. >>> schema = ds.Schema()
3010. >>> schema.add_column('col1', de_type=mstype.int64, shape=[2])
  3011. """
  3012. def __init__(self, schema_file=None):
  3013. self.num_rows = None
  3014. if schema_file is None:
  3015. self.columns = []
  3016. self.dataset_type = ''
3017. else:
# start from None so the missing-field checks in from_json can fire
self.columns = None
self.dataset_type = None
  3018. if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
  3019. raise ValueError("The file %s does not exist or permission denied!" % schema_file)
  3020. try:
  3021. with open(schema_file, 'r') as load_f:
  3022. json_obj = json.load(load_f)
  3023. except json.decoder.JSONDecodeError:
  3024. raise RuntimeError("Schema file failed to load.")
  3025. except UnicodeDecodeError:
  3026. raise RuntimeError("Schema file failed to decode.")
  3027. except Exception:
  3028. raise RuntimeError("Schema file failed to open.")
  3029. self.from_json(json_obj)
  3030. @check_add_column
  3031. def add_column(self, name, de_type, shape=None):
  3032. """
  3033. Add new column to the schema.
  3034. Args:
  3035. name (str): name of the column.
  3036. de_type (str): data type of the column.
  3037. shape (list[int], optional): shape of the column
  3038. (default=None, [-1] which is an unknown shape of rank 1).
  3039. Raises:
  3040. ValueError: If column type is unknown.
  3041. """
  3042. new_column = dict()
  3043. new_column["name"] = name
  3044. if isinstance(de_type, typing.Type):
  3045. de_type = mstype_to_detype(de_type)
  3046. new_column["type"] = str(de_type)
  3047. else:
  3048. new_column["type"] = str(DataType(de_type))
  3049. if shape is not None:
  3050. new_column["shape"] = shape
  3051. new_column["rank"] = len(shape)
  3052. else:
  3053. new_column["rank"] = 1
  3054. self.columns.append(new_column)
  3055. def to_json(self):
  3056. """
  3057. Get a JSON string of the schema.
  3058. Returns:
  3059. Str, JSON string of the schema.
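Example:
>>> # a minimal sketch: serialize a one-column schema
>>> schema = Schema()
>>> schema.add_column('col1', de_type='int8', shape=[2])
>>> print(schema.to_json())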
  3060. """
  3061. json_file = dict()
  3062. json_file["columns"] = self.columns
  3063. if self.dataset_type:
  3064. json_file["datasetType"] = self.dataset_type
  3065. if self.num_rows:
  3066. json_file["numRows"] = self.num_rows
  3067. return json.dumps(json_file, indent=2)
  3068. def parse_columns(self, columns):
  3069. """
  3070. Parse the columns and add it to self.
  3071. Args:
  3072. columns (dict or list[dict]): dataset attribution information, decoded from schema file.
  3073. - list[dict], 'name' and 'type' must be in keys, 'shape' optional.
  3074. - dict, columns.keys() as name, columns.values() is dict, and 'type' inside, 'shape' optional.
  3075. Raises:
  3076. RuntimeError: If failed to parse columns.
  3077. RuntimeError: If unknown items in columns.
  3078. RuntimeError: If column's name field is missing.
  3079. RuntimeError: If column's type field is missing.
  3080. Example:
  3081. >>> schema = Schema()
  3082. >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
  3083. >>> {'name': 'label', 'type': 'int8', 'shape': [1]}]
  3084. >>> schema.parse_columns(columns1)
  3085. >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
  3086. >>> schema.parse_columns(columns2)
  3087. """
  3088. self.columns = []
  3089. if isinstance(columns, list):
  3090. for column in columns:
  3091. try:
  3092. name = column.pop("name")
  3093. except KeyError:
  3094. raise RuntimeError("Column's name is missing")
  3095. try:
  3096. de_type = column.pop("type")
  3097. except KeyError:
  3098. raise RuntimeError("Column' type is missing")
  3099. shape = column.pop("shape", None)
  3100. column.pop("t_impl", None)
  3101. column.pop("rank", None)
  3102. if column:
  3103. raise RuntimeError("Unknown field {}".format(",".join(column.keys())))
  3104. self.add_column(name, de_type, shape)
  3105. elif isinstance(columns, dict):
  3106. for key, value in columns.items():
  3107. name = key
  3108. try:
  3109. de_type = value.pop("type")
  3110. except KeyError:
  3111. raise RuntimeError("Column' type is missing")
  3112. shape = value.pop("shape", None)
  3113. value.pop("t_impl", None)
  3114. value.pop("rank", None)
  3115. if value:
  3116. raise RuntimeError("Unknown field {}".format(",".join(value.keys())))
  3117. self.add_column(name, de_type, shape)
  3118. else:
  3119. raise RuntimeError("columns must be dict or list, columns contain name, type, shape(optional).")
  3120. def from_json(self, json_obj):
  3121. """
  3122. Get schema file from json file.
  3123. Args:
  3124. json_obj(dictionary): object of json parsed.
  3125. Raises:
  3126. RuntimeError: if there is unknown item in the object.
  3127. RuntimeError: if dataset type is missing in the object.
  3128. RuntimeError: if columns are missing in the object.
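Example:
>>> schema = Schema()
>>> # a minimal sketch of an already-parsed schema object
>>> schema.from_json({'columns': [{'name': 'col1', 'type': 'int8'}], 'datasetType': 'TF', 'numRows': 2})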
  3129. """
3130. if not isinstance(json_obj, dict) or not json_obj:
  3131. raise ValueError("Expected non-empty dict.")
  3132. for k, v in json_obj.items():
  3133. if k == "datasetType":
  3134. self.dataset_type = v
  3135. elif k == "numRows":
  3136. self.num_rows = v
  3137. elif k == "columns":
  3138. self.parse_columns(v)
  3139. else:
  3140. raise RuntimeError("Unknown field %s" % k)
  3141. if self.dataset_type is None:
  3142. raise RuntimeError("DatasetType field is missing.")
  3143. if self.columns is None:
  3144. raise RuntimeError("Columns are missing.")
  3145. if self.num_rows is not None:
  3146. if not isinstance(self.num_rows, int) or self.num_rows <= 0:
  3147. raise ValueError("numRows must be greater than 0")
  3148. def __str__(self):
  3149. return self.to_json()
  3150. class VOCDataset(MappableDataset):
  3151. """
  3152. A source dataset for reading and parsing VOC dataset.
  3153. The generated dataset has two columns ['image', 'target'].
3154. The shape of both columns is [image_size] if decode flag is False, or [H, W, C]
3155. otherwise.
3156. The type of both tensors is uint8.
  3157. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  3158. below shows what input args are allowed and their expected behavior.
  3159. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  3160. :widths: 25 25 50
  3161. :header-rows: 1
  3162. * - Parameter 'sampler'
  3163. - Parameter 'shuffle'
  3164. - Expected Order Behavior
  3165. * - None
  3166. - None
  3167. - random order
  3168. * - None
  3169. - True
  3170. - random order
  3171. * - None
  3172. - False
  3173. - sequential order
  3174. * - Sampler object
  3175. - None
  3176. - order defined by sampler
  3177. * - Sampler object
  3178. - True
  3179. - not allowed
  3180. * - Sampler object
  3181. - False
  3182. - not allowed
  3183. Args:
  3184. dataset_dir (str): Path to the root directory that contains the dataset.
3185. task (str): The task type for reading VOC data; only "Segmentation" and "Detection" are supported
3186. (default="Segmentation").
3187. mode (str): The name of the data list txt file to be read (default="train").
  3188. class_indexing (dict, optional): A str-to-int mapping from label name to index
  3189. (default=None, the folder names will be sorted alphabetically and each
  3190. class will be given a unique index starting from 0).
  3191. num_samples (int, optional): The number of images to be included in the dataset
  3192. (default=None, all images).
  3193. num_parallel_workers (int, optional): Number of workers to read the data
  3194. (default=None, number set in the config).
  3195. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  3196. order behavior shown in the table).
  3197. decode (bool, optional): Decode the images after reading (default=False).
  3198. sampler (Sampler, optional): Object used to choose samples from the dataset
  3199. (default=None, expected order behavior shown in the table).
  3200. num_shards (int, optional): Number of shards that the dataset should be divided
  3201. into (default=None).
  3202. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3203. argument should be specified only when num_shards is also specified.
  3204. Raises:
3205. RuntimeError: If the XML of an Annotations file has an invalid format.
3206. RuntimeError: If the XML of an Annotations file lacks the attribute "object".
3207. RuntimeError: If the XML of an Annotations file lacks the attribute "bndbox".
  3208. RuntimeError: If sampler and shuffle are specified at the same time.
  3209. RuntimeError: If sampler and sharding are specified at the same time.
  3210. RuntimeError: If num_shards is specified but shard_id is None.
  3211. RuntimeError: If shard_id is specified but num_shards is None.
3212. ValueError: If task is not 'Segmentation' or 'Detection'.
3213. ValueError: If task is 'Segmentation' but class_indexing is not None.
3214. ValueError: If the txt file related to mode does not exist.
  3215. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  3216. Examples:
  3217. >>> import mindspore.dataset as ds
  3218. >>> dataset_dir = "/path/to/voc_dataset_directory"
3219. >>> # 1) read VOC data for segmentation training
  3220. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", mode="train")
  3221. >>> # 2) read VOC data for detection train
  3222. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train")
  3223. >>> # 3) read all VOC dataset samples in dataset_dir with 8 threads in random order:
  3224. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", num_parallel_workers=8)
  3225. >>> # 4) read then decode all VOC dataset samples in dataset_dir in sequence:
  3226. >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", decode=True, shuffle=False)
  3227. >>> # in VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
  3228. >>> # in VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
  3229. """
  3230. @check_vocdataset
  3231. def __init__(self, dataset_dir, task="Segmentation", mode="train", class_indexing=None, num_samples=None,
  3232. num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
  3233. super().__init__(num_parallel_workers)
  3234. self.dataset_dir = dataset_dir
  3235. self.task = task
  3236. self.mode = mode
  3237. self.class_indexing = class_indexing
  3238. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3239. self.num_samples = num_samples
  3240. self.decode = decode
  3241. self.shuffle_level = shuffle
  3242. self.num_shards = num_shards
  3243. self.shard_id = shard_id
  3244. def get_args(self):
  3245. args = super().get_args()
  3246. args["dataset_dir"] = self.dataset_dir
  3247. args["task"] = self.task
  3248. args["mode"] = self.mode
  3249. args["class_indexing"] = self.class_indexing
  3250. args["num_samples"] = self.num_samples
  3251. args["sampler"] = self.sampler
  3252. args["decode"] = self.decode
  3253. args["shuffle"] = self.shuffle_level
  3254. args["num_shards"] = self.num_shards
  3255. args["shard_id"] = self.shard_id
  3256. return args
  3257. def get_dataset_size(self):
  3258. """
  3259. Get the number of batches in an epoch.
  3260. Return:
  3261. Number, number of batches.
  3262. """
  3263. return self.num_samples
  3264. def get_class_indexing(self):
  3265. """
  3266. Get the class index.
  3267. Return:
  3268. Dict, A str-to-int mapping from label name to index.
  3269. """
  3270. if self.task != "Detection":
  3271. raise NotImplementedError()
  3272. if self.num_samples is None:
  3273. num_samples = 0
  3274. else:
  3275. num_samples = self.num_samples
  3276. if self.class_indexing is None:
  3277. class_indexing = dict()
  3278. else:
  3279. class_indexing = self.class_indexing
  3280. return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
  3281. def is_shuffled(self):
  3282. if self.shuffle_level is None:
  3283. return True
  3284. return self.shuffle_level or self.sampler.is_shuffled()
  3285. def is_sharded(self):
  3286. if self.num_shards is not None:
  3287. return self.num_shards > 1
  3288. return self.sampler.is_sharded()
  3289. class CelebADataset(MappableDataset):
  3290. """
3291. A source dataset for reading and parsing the CelebA dataset. Currently only list_attr_celeba.txt is supported.
  3292. Note:
  3293. The generated dataset has two columns ['image', 'attr'].
3294. The type of the image tensor is uint8. The attr tensor is uint32 and one-hot encoded.
  3295. Args:
  3296. dataset_dir (str): Path to the root directory that contains the dataset.
  3297. num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config).
  3298. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None).
3299. dataset_type (str): one of 'all', 'train', 'valid' or 'test'.
  3300. sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
  3301. decode (bool, optional): decode the images after reading (default=False).
  3302. extensions (list[str], optional): List of file extensions to be
  3303. included in the dataset (default=None).
  3304. num_samples (int, optional): The number of images to be included in the dataset.
  3305. (default=None, all images).
  3306. num_shards (int, optional): Number of shards that the dataset should be divided
  3307. into (default=None).
  3308. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3309. argument should be specified only when num_shards is also specified.
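Examples:
>>> import mindspore.dataset as ds
>>> dataset_dir = "/path/to/celeba_directory"
>>> # a minimal sketch: dataset_dir is assumed to contain list_attr_celeba.txt and the images
>>> celeba_dataset = ds.CelebADataset(dataset_dir, dataset_type='train', decode=True)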
  3310. """
  3311. @check_celebadataset
  3312. def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, dataset_type='all',
  3313. sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None):
  3314. super().__init__(num_parallel_workers)
  3315. self.dataset_dir = dataset_dir
  3316. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  3317. self.num_parallel_workers = num_parallel_workers
  3318. self.decode = decode
  3319. self.extensions = extensions
  3320. self.num_samples = num_samples
  3321. self.dataset_type = dataset_type
  3322. self.num_shards = num_shards
  3323. self.shard_id = shard_id
  3324. self.shuffle_level = shuffle
  3325. def get_args(self):
  3326. args = super().get_args()
  3327. args["dataset_dir"] = self.dataset_dir
  3328. args["sampler"] = self.sampler
  3329. args["shuffle"] = self.shuffle_level
  3330. args["decode"] = self.decode
  3331. args["extensions"] = self.extensions
  3332. args["num_samples"] = self.num_samples
  3333. args["dataset_type"] = self.dataset_type
  3334. args["num_shards"] = self.num_shards
  3335. args["shard_id"] = self.shard_id
  3336. return args
  3337. def is_shuffled(self):
  3338. if self.shuffle_level is None:
  3339. return True
  3340. return self.shuffle_level or self.sampler.is_shuffled()
  3341. def is_sharded(self):
  3342. if self.num_shards is not None:
  3343. return self.num_shards > 1
  3344. return self.sampler.is_sharded()
  3345. class TextFileDataset(SourceDataset):
  3346. """
  3347. A source dataset that reads and parses datasets stored on disk in text format.
3348. The generated dataset has one column ['text'].
  3349. Args:
  3350. dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
  3351. files. The list will be sorted in a lexicographical order.
  3352. num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
  3353. num_parallel_workers (int, optional): number of workers to read the data
  3354. (default=None, number set in the config).
  3355. shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
  3356. If shuffle is False, no shuffling will be performed;
3357. If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL;
3358. otherwise, there are two levels of shuffling:
  3359. - Shuffle.GLOBAL: Shuffle both the files and samples.
  3360. - Shuffle.FILES: Shuffle files only.
  3361. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
  3362. shard_id (int, optional): The shard ID within num_shards (default=None). This
  3363. argument should be specified only when num_shards is also specified.
  3364. Examples:
  3365. >>> import mindspore.dataset as ds
3366. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains one or more text files
  3367. >>> dataset = ds.TextFileDataset(dataset_files=dataset_files)
  3368. """
  3369. @check_textfiledataset
  3370. def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None,
  3371. shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
  3372. super().__init__(num_parallel_workers)
  3373. self.dataset_files = self._find_files(dataset_files)
  3374. self.dataset_files.sort()
  3375. self.num_samples = num_samples
  3376. if not isinstance(shuffle, (bool, Shuffle)):
3377. raise TypeError("shuffle should be a boolean or an enum 'Shuffle'.")
  3378. if not isinstance(shuffle, Shuffle):
  3379. if shuffle:
  3380. self.shuffle_level = Shuffle.GLOBAL
  3381. self.shuffle_files = True
  3382. else:
  3383. self.shuffle_level = None
  3384. self.shuffle_files = False
  3385. else:
  3386. self.shuffle_level = shuffle
  3387. self.shuffle_files = True
  3388. self.num_shards = num_shards
  3389. self.shard_id = shard_id
  3390. def get_args(self):
  3391. args = super().get_args()
  3392. args["dataset_files"] = self.dataset_files
  3393. args["num_samples"] = self.num_samples
  3394. if self.shuffle_files is not None:
  3395. args["shuffle_files"] = self.shuffle_files
  3396. args["shuffle"] = self.shuffle_level
  3397. args["num_shards"] = self.num_shards
  3398. args["shard_id"] = self.shard_id
  3399. return args
  3400. def get_dataset_size(self):
  3401. """
  3402. Get the number of batches in an epoch.
  3403. Return:
  3404. Number, number of batches.
  3405. """
  3406. if self._dataset_size is None:
  3407. num_rows = TextFileOp.get_num_rows(self.dataset_files)
  3408. num_rows = get_num_rows(num_rows, self.num_shards)
  3409. if self.num_samples is None:
  3410. return num_rows
  3411. return min(self.num_samples, num_rows)
  3412. return self._dataset_size
  3413. def is_shuffled(self):
  3414. return self.shuffle_files
  3415. def is_sharded(self):
  3416. if self.num_shards is not None:
  3417. return self.num_shards > 1
  3418. return False