
datasets.py

  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. datasets.py supports various formats of datasets, including ImageNet, TFData,
  17. MNIST, Cifar10/100, Manifest, MindRecord, etc. This module loads data with
  18. high performance and parses it precisely. It also provides the following
  19. operations for users to preprocess data: shuffle, batch, repeat, map, and zip.
  20. """
  21. import glob
  22. import json
  23. import math
  24. import os
  25. import random
  26. import uuid
  27. from enum import Enum
  28. from importlib import import_module
  29. import numpy as np
  30. from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
  31. MindRecordOp, CBatchInfo
  32. from mindspore._c_expression import typing
  33. from mindspore import log as logger
  34. from . import samplers
  35. from .iterators import DictIterator, TupleIterator
  36. from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_zip, check_rename, \
  37. check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
  38. check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
  39. check_zip_dataset, check_add_column
  40. from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
  41. try:
  42. context = import_module("mindspore.context")
  43. except ModuleNotFoundError:
  44. context = None
  45. class Shuffle(str, Enum):
  46. GLOBAL: str = "global"
  47. FILES: str = "file"
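# Note (illustrative, an assumption not stated in this file): the Shuffle enum above is
# intended as the value of the `shuffle` argument of file-based datasets such as
# TFRecordDataset, for example:
#
#     data = ds.TFRecordDataset(dataset_files, schema_path, shuffle=ds.Shuffle.FILES)
#
# where Shuffle.GLOBAL shuffles both the files and the rows, and Shuffle.FILES shuffles
# only the order of the files.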
  48. @check_zip
  49. def zip(datasets):
  50. """
  51. Zips the datasets in the input tuple of datasets.
  52. Args:
  53. datasets (tuple of class Dataset): A tuple of datasets to be zipped together.
  54. The number of datasets should be more than 1.
  55. Returns:
  56. DatasetOp, ZipDataset.
  57. Raises:
  58. ValueError: If the number of datasets is 1.
  59. TypeError: If datasets is not a tuple.
  60. Examples:
  61. >>> import mindspore.dataset as ds
  62. >>>
  63. >>> dataset_dir1 = "path/to/imagefolder_directory1"
  64. >>> dataset_dir2 = "path/to/imagefolder_directory2"
  65. >>> ds1 = ds.ImageFolderDatasetV2(dataset_dir1, num_parallel_workers=8)
  66. >>> ds2 = ds.ImageFolderDatasetV2(dataset_dir2, num_parallel_workers=8)
  67. >>>
  68. >>> # creates a dataset which is the combination of ds1 and ds2
  69. >>> data = ds.zip((ds1, ds2))
  70. """
  71. if len(datasets) <= 1:
  72. raise ValueError(
  73. "Can't zip empty or just one dataset!")
  74. return ZipDataset(datasets)
  75. def get_num_rows(num_rows, num_shards):
  76. """
  77. Get the number of rows of the dataset according to the shards.
  78. Args:
  79. num_rows (int): The number of rows of the dataset.
  80. The number of rows should be more than 0.
  81. num_shards (int or None): Number of shards that the dataset should be divided into.
  82. The number of shards should be None or more than 1.
  83. Returns:
  84. Int, number of rows.
  85. Raises:
  86. ValueError: If num_rows is invalid (< 0).
  87. ValueError: If num_shards is invalid (<= 0).
  88. """
  89. if num_rows < 0:
  90. raise ValueError("num_rows is invalid (< 0)")
  91. if num_shards is not None:
  92. if num_shards <= 0:
  93. raise ValueError("num_shards is invalid (<= 0)")
  94. if num_rows % num_shards == 0:
  95. num_rows = num_rows // num_shards
  96. else:
  97. num_rows = num_rows // num_shards + 1
  98. return num_rows
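# Worked example of the sharding arithmetic above (illustrative only): with num_rows=10
# and num_shards=3, 10 is not evenly divisible by 3, so each shard gets 10 // 3 + 1 = 4
# rows; with num_shards=None the row count is returned unchanged.
#
#     get_num_rows(10, 3)     # -> 4
#     get_num_rows(12, 3)     # -> 4
#     get_num_rows(10, None)  # -> 10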
  99. class Dataset:
  100. """
  101. Abstract class to represent a dataset in DataEngine's data pipeline.
  102. This class is the base class of SourceDataset and DatasetOp, and represents
  103. a node in the data flow graph.
  104. Args:
  105. num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
  106. (default=None).
  107. """
  108. def __init__(self, num_parallel_workers=None):
  109. self.input = []
  110. self.output = []
  111. self.num_parallel_workers = num_parallel_workers
  112. self._device_iter = 0
  113. self._input_indexs = ()
  114. self._output_types = None
  115. self._output_shapes = None
  116. self._dataset_size = None
  117. self._batch_size = None
  118. self._num_classes = None
  119. self._repeat_count = None
  120. def get_args(self):
  121. """
  122. Returns attributes (member variables) related to the current class.
  123. Must include all arguments passed to the __init__() of the current class, excluding 'input_dataset'.
  124. Args:
  125. Returns:
  126. Python dictionary.
  127. """
  128. args = dict()
  129. args["num_parallel_workers"] = self.num_parallel_workers
  130. return args
  131. @check_batch
  132. def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
  133. input_columns=None):
  134. """
  135. Combines batch_size number of consecutive rows into batches.
  136. For any child node, a batch is treated as a single row.
  137. For any column, all the elements within that column must have the same shape.
  138. If a per_batch_map callable is provided, it will be applied to the batches of tensors.
  139. Note:
  140. The order of repeat and batch determines the number of batches. It is recommended
  141. that the repeat operation be applied after the batch operation.
  142. Args:
  143. batch_size (int or function): The number of rows each batch is created with. An
  144. int or callable which takes exactly 1 parameter, BatchInfo.
  145. drop_remainder (bool, optional): Determines whether or not to drop the last
  146. possibly incomplete batch (default=False). If True, and if there are less
  147. than batch_size rows available to make the last batch, then those rows will
  148. be dropped and not propagated to the child node.
  149. num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None).
  150. per_batch_map (callable, optional): Per batch map callable. A callable which takes
  151. (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of
  152. Tensors on a given column. The number of lists should match the number of entries in input_columns. The
  153. last parameter of the callable should always be a BatchInfo object.
  154. input_columns (list of string, optional): List of names of the input columns. The size of the list should
  155. match the signature of the per_batch_map callable.
  156. Returns:
  157. BatchDataset, dataset batched.
  158. Examples:
  159. >>> import mindspore.dataset as ds
  160. >>> # data is an instance of Dataset object.
  161. >>> # creates a dataset where every 100 rows is combined into a batch
  162. >>> # and drops the last incomplete batch if there is one.
  163. >>> data = data.batch(100, True)
  164. """
  165. return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns)
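# Minimal sketch of a per_batch_map callable for Dataset.batch above (illustrative,
# assuming each column batch arrives as a list of 1-D numpy arrays and the callable
# returns the transformed column batches as a tuple of lists; pad_to_longest is a
# hypothetical helper, not part of this module):
#
#     def pad_to_longest(col_batch, batch_info):
#         # col_batch corresponds to input_columns[0]; batch_info is a BatchInfo object.
#         longest = max(arr.shape[0] for arr in col_batch)
#         return ([np.pad(arr, (0, longest - arr.shape[0])) for arr in col_batch],)
#
#     data = data.batch(4, input_columns=["col0"], per_batch_map=pad_to_longest)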
  166. @check_shuffle
  167. def shuffle(self, buffer_size):
  168. """
  169. Randomly shuffles the rows of this dataset using the following algorithm:
  170. 1. Make a shuffle buffer that contains the first buffer_size rows.
  171. 2. Randomly select an element from the shuffle buffer to be the next row
  172. propagated to the child node.
  173. 3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
  174. 4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.
  175. A seed can be provided to be used on the first epoch. In every subsequent
  176. epoch, the seed is changed to a new, randomly generated value.
  177. Args:
  178. buffer_size (int): The size of the buffer (must be larger than 1) for
  179. shuffling. Setting buffer_size equal to the number of rows in the entire
  180. dataset will result in a global shuffle.
  181. Returns:
  182. ShuffleDataset, dataset shuffled.
  183. Examples:
  184. >>> import mindspore.dataset as ds
  185. >>> # data is an instance of Dataset object
  186. >>> # optionally set the seed for the first epoch
  187. >>> ds.config.set_seed(58)
  188. >>>
  189. >>> # creates a shuffled dataset using a shuffle buffer of size 4
  190. >>> data = data.shuffle(4)
  191. """
  192. return ShuffleDataset(self, buffer_size)
  193. @check_map
  194. def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
  195. num_parallel_workers=None):
  196. """
  197. Applies each operation in operations to this dataset.
  198. The order of operations is determined by the position of each operation in operations.
  199. operations[0] will be applied first, then operations[1], then operations[2], etc.
  200. Each operation will be passed one or more columns from the dataset as input, and zero or
  201. more columns will be outputted. The first operation will be passed the columns specified
  202. in input_columns as input. If there is more than one operator in operations, the outputted
  203. columns of the previous operation are used as the input columns for the next operation.
  204. The columns outputted by the very last operation will be assigned names specified by
  205. output_columns.
  206. Only the columns specified in columns_order will be propagated to the child node. These
  207. columns will be in the same order as specified in columns_order.
  208. Args:
  209. input_columns (list[str]): List of the names of the columns that will be passed to
  210. the first operation as input. The size of this list must match the number of
  211. input columns expected by the first operator. (default=None, the first
  212. operation will be passed as many columns as it requires, starting from
  213. the first column).
  214. operations (list[TensorOp] or Python list[functions]): List of operations to be
  215. applied on the dataset. Operations are applied in the order they appear in this list.
  216. output_columns (list[str], optional): List of names assigned to the columns outputted by
  217. the last operation. This parameter is mandatory if len(input_columns) !=
  218. len(output_columns). The size of this list must match the number of output
  219. columns of the last operation. (default=None, output columns will have the same
  220. name as the input columns, i.e., the columns will be replaced).
  221. columns_order (list[str], optional): list of all the desired columns to propagate to the
  222. child node. This list must be a subset of all the columns in the dataset after
  223. all operations are applied. The order of the columns in each row propagated to the
  224. child node follow the order they appear in this list. The parameter is mandatory
  225. if the len(input_columns) != len(output_columns). (default=None, all columns
  226. will be propagated to the child node, the order of the columns will remain the
  227. same).
  228. num_parallel_workers (int, optional): Number of threads used to process the dataset in
  229. parallel (default=None, the value from the config will be used).
  230. Returns:
  231. MapDataset, dataset after mapping operation.
  232. Examples:
  233. >>> import mindspore.dataset as ds
  234. >>> import mindspore.dataset.transforms.vision.c_transforms as c_transforms
  235. >>>
  236. >>> # data is an instance of Dataset which has 2 columns, "image" and "label".
  237. >>> # ds_pyfunc is an instance of Dataset which has 3 columns, "col0", "col1", and "col2". Each column is
  238. >>> # a 2d array of integers.
  239. >>>
  240. >>> # This config is a global setting, meaning that all future operations which
  241. >>> # use this config value will use 2 worker threads, unless specified
  242. >>> # otherwise in their constructor. set_num_parallel_workers can be called
  243. >>> # again later if a different number of worker threads is needed.
  244. >>> ds.config.set_num_parallel_workers(2)
  245. >>>
  246. >>> # Two operations, which takes 1 column for input and outputs 1 column.
  247. >>> decode_op = c_transforms.Decode(rgb_format=True)
  248. >>> random_jitter_op = c_transforms.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
  249. >>>
  250. >>> # 1) Simple map example
  251. >>>
  252. >>> operations = [decode_op]
  253. >>> input_columns = ["image"]
  254. >>>
  255. >>> # Applies decode_op on column "image". This column will be replaced by the outputted
  256. >>> # column of decode_op. Since columns_order is not provided, both columns "image"
  257. >>> # and "label" will be propagated to the child node in their original order.
  258. >>> ds_decoded = data.map(input_columns, operations)
  259. >>>
  260. >>> # Rename column "image" to "decoded_image"
  261. >>> output_columns = ["decoded_image"]
  262. >>> ds_decoded = data.map(input_columns, operations, output_columns)
  263. >>>
  264. >>> # Specify the order of the columns.
  265. >>> columns_order = ["label", "image"]
  266. >>> ds_decoded = data.map(input_columns, operations, None, columns_order)
  267. >>>
  268. >>> # Rename column "image" to "decoded_image" and also specify the order of the columns.
  269. >>> columns_order = ["label", "decoded_image"]
  270. >>> output_columns = ["decoded_image"]
  271. >>> ds_decoded = data.map(input_columns, operations, output_columns, columns_order)
  272. >>>
  273. >>> # Rename column "image" to "decoded_image" and keep only this column.
  274. >>> columns_order = ["decoded_image"]
  275. >>> output_columns = ["decoded_image"]
  276. >>> ds_decoded = data.map(input_columns, operations, output_columns, columns_order)
  277. >>>
  278. >>> # Simple example using pyfunc. Renaming columns and specifying column order
  279. >>> # work in the same way as the previous examples.
  280. >>> input_columns = ["col0"]
  281. >>> operations = [(lambda x: x + 1)]
  282. >>> ds_mapped = ds_pyfunc.map(input_columns, operations)
  283. >>>
  284. >>> # 2) Map example with more than one operation
  285. >>>
  286. >>> # If this list of operations is used with map, decode_op will be applied
  287. >>> # first, then random_jitter_op will be applied.
  288. >>> operations = [decode_op, random_jitter_op]
  289. >>>
  290. >>> input_columns = ["image"]
  291. >>>
  292. >>> # Creates a dataset where the images are decoded, then randomly color jittered.
  293. >>> # decode_op takes column "image" as input and outputs one column. The column
  294. >>> # outputted by decode_op is passed as input to random_jitter_op.
  295. >>> # random_jitter_op will output one column. Column "image" will be replaced by
  296. >>> # the column outputted by random_jitter_op (the very last operation). All other
  297. >>> # columns are unchanged. Since columns_order is not specified, the order of the
  298. >>> # columns will remain the same.
  299. >>> ds_mapped = data.map(input_columns, operations)
  300. >>>
  301. >>> # Creates a dataset that is identical to ds_mapped, except the column "image"
  302. >>> # that is outputted by random_jitter_op is renamed to "image_transformed".
  303. >>> # Specifying column order works in the same way as examples in 1).
  304. >>> output_columns = ["image_transformed"]
  305. >>> ds_mapped_and_renamed = data.map(input_columns, operations, output_columns)
  306. >>>
  307. >>> # Multiple operations using pyfunc. Renaming columns and specifying column order
  308. >>> # work in the same way as examples in 1).
  309. >>> input_columns = ["col0"]
  310. >>> operations = [(lambda x: x + x), (lambda x: x - 1)]
  311. >>> output_columns = ["col0_mapped"]
  312. >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns)
  313. >>>
  314. >>> # 3) Example where number of input columns is not equal to number of output columns
  315. >>>
  316. >>> # operations[0] is a lambda that takes 2 columns as input and outputs 3 columns.
  317. >>> # operations[1] is a lambda that takes 3 columns as input and outputs 1 column.
  318. >>> # operations[2] is a lambda that takes 1 column as input and outputs 4 columns.
  319. >>> #
  320. >>> # Note: the number of output columns of operation[i] must equal the number of
  321. >>> # input columns of operation[i+1]. Otherwise, this map call will result
  322. >>> # in an error.
  323. >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
  324. >>> (lambda x, y, z: x * y * z),
  325. >>> (lambda x: (x % 2, x % 3, x % 5, x % 7))]
  326. >>>
  327. >>> # Note: because the number of input columns is not the same as the number of
  328. >>> # output columns, the output_columns and columns_order parameter must be
  329. >>> # specified. Otherwise, this map call will also result in an error.
  330. >>> input_columns = ["col2", "col0"]
  331. >>> output_columns = ["mod2", "mod3", "mod5", "mod7"]
  332. >>>
  333. >>> # Propagate all columns to the child node in this order:
  334. >>> columns_order = ["col0", "col2", "mod2", "mod3", "mod5", "mod7", "col1"]
  335. >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
  336. >>>
  337. >>> # Propagate some columns to the child node in this order:
  338. >>> columns_order = ["mod7", "mod3", "col1"]
  339. >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
  340. """
  341. return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
  342. @check_repeat
  343. def repeat(self, count=None):
  344. """
  345. Repeats this dataset count times. Repeat indefinitely if the count is None or -1.
  346. Note:
  347. The order of using repeat and batch reflects the number of batches. Recommend that
  348. repeat operation should be used after batch operation.
  349. If dataset_sink_mode is False, the repeat operation is invalid here.
  350. If dataset_sink_mode is True, the repeat count should be equal to the number of training epochs. Otherwise,
  351. errors could occur since the amount of data is not what the training requires.
  352. Args:
  353. count (int): Number of times the dataset should be repeated (default=None).
  354. Returns:
  355. RepeatDataset, dataset repeated.
  356. Examples:
  357. >>> import mindspore.dataset as ds
  358. >>> # data is an instance of Dataset object.
  359. >>> # creates a dataset where the dataset is repeated for 50 epochs
  360. >>> repeated = data.repeat(50)
  361. >>>
  362. >>> # creates a dataset where each epoch is shuffled individually
  363. >>> shuffled_and_repeated = data.shuffle(10)
  364. >>> shuffled_and_repeated = shuffled_and_repeated.repeat(50)
  365. >>>
  366. >>> # creates a dataset where the dataset is first repeated for
  367. >>> # 50 epochs before shuffling. the shuffle operator will treat
  368. >>> # the entire 50 epochs as one big dataset.
  369. >>> repeat_and_shuffle = data.repeat(50)
  370. >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
  371. """
  372. return RepeatDataset(self, count)
  373. @check_zip_dataset
  374. def zip(self, datasets):
  375. """
  376. Zips the datasets in the input tuple of datasets. Columns in the input datasets must not have the same name.
  377. Args:
  378. datasets (tuple or class Dataset): A tuple of datasets or a single class Dataset
  379. to be zipped together with this dataset.
  380. Returns:
  381. ZipDataset, dataset zipped.
  382. Examples:
  383. >>> import mindspore.dataset as ds
  384. >>> # ds1 and ds2 are instances of Dataset object
  385. >>> # creates a dataset which is the combination of ds1 and ds2
  386. >>> data = ds1.zip(ds2)
  387. """
  388. if isinstance(datasets, tuple):
  389. datasets = (self, *datasets)
  390. elif isinstance(datasets, Dataset):
  391. datasets = (self, datasets)
  392. else:
  393. raise TypeError("The zip function %s type error!" % (datasets))
  394. return ZipDataset(datasets)
  395. @check_rename
  396. def rename(self, input_columns, output_columns):
  397. """
  398. Renames the columns in input datasets.
  399. Args:
  400. input_columns (list[str]): list of names of the input columns.
  401. output_columns (list[str]): list of names of the output columns.
  402. Returns:
  403. RenameDataset, dataset renamed.
  404. Examples:
  405. >>> import mindspore.dataset as ds
  406. >>> # data is an instance of Dataset object.
  407. >>> input_columns = ["input_col1", "input_col2", "input_col3"]
  408. >>> output_columns = ["output_col1", "output_col2", "output_col3"]
  409. >>>
  410. >>> # creates a dataset where input_col1 is renamed to output_col1, and
  411. >>> # input_col2 is renamed to output_col2, and input_col3 is renamed
  412. >>> # to output_col3.
  413. >>> data = data.rename(input_columns=input_columns, output_columns=output_columns)
  414. """
  415. return RenameDataset(self, input_columns, output_columns)
  416. @check_project
  417. def project(self, columns):
  418. """
  419. Projects certain columns in input datasets.
  420. The specified columns will be selected from the dataset and passed down
  421. the pipeline in the order specified. The other columns are discarded.
  422. Args:
  423. columns(list[str]): list of names of the columns to project.
  424. Returns:
  425. ProjectDataset, dataset projected.
  426. Examples:
  427. >>> import mindspore.dataset as ds
  428. >>> # data is an instance of Dataset object
  429. >>> columns_to_project = ["column3", "column1", "column2"]
  430. >>>
  431. >>> # creates a dataset that consist of column3, column1, column2
  432. >>> # in that order, regardless of the original order of columns.
  433. >>> data = data.project(columns=columns_to_project)
  434. """
  435. return ProjectDataset(self, columns)
  436. def apply(self, apply_func):
  437. """
  438. Apply a function to this dataset.
  439. The specified apply_func is a function that must take one 'Dataset' as an argument
  440. and return a preprocessed 'Dataset'.
  441. Args:
  442. apply_func (function): A function that must take one 'Dataset' as an argument and
  443. return a preprocessed 'Dataset'.
  444. Returns:
  445. Dataset, applied by the function.
  446. Examples:
  447. >>> import numpy as np
  448. >>> import mindspore.dataset as ds
  449. >>> # Generate 1d int numpy array from 0 - 6
  450. >>> def generator_1d():
  451. >>> for i in range(6):
  452. >>> yield (np.array([i]),)
  453. >>> # 1) get all data from dataset
  454. >>> data = ds.GeneratorDataset(generator_1d, ["data"])
  455. >>> # 2) declare a apply_func function
  456. >>> def apply_func(ds):
  457. >>> ds = ds.batch(2)
  458. >>> return ds
  459. >>> # 3) use apply to call apply_func
  460. >>> data = data.apply(apply_func)
  461. >>> for item in data.create_dict_iterator():
  462. >>> print(item["data"])
  463. Raises:
  464. TypeError: If apply_func is not a function.
  465. TypeError: If apply_func doesn't return a Dataset.
  466. """
  467. if not hasattr(apply_func, '__call__'):
  468. raise TypeError("apply_func must be a function.")
  469. dataset = apply_func(self)
  470. if not isinstance(dataset, Dataset):
  471. raise TypeError("apply_func must return a dataset.")
  472. return dataset
  473. def device_que(self, prefetch_size=None):
  474. """
  475. Return a TransferDataset that transfers data through the TDT channel.
  476. Args:
  477. prefetch_size (int, optional): prefetch number of records ahead of the
  478. user's request (default=None).
  479. Return:
  480. TransferDataset, dataset for transferring.
  481. """
  482. return self.to_device()
  483. def to_device(self, num_batch=None):
  484. """
  485. Transfers data through CPU, GPU or Ascend devices.
  486. Args:
  487. num_batch (int, optional): limit the number of batch to be sent to device (default=None).
  488. Returns:
  489. TransferDataset, dataset for transferring.
  490. Raises:
  491. TypeError: If device_type is empty.
  492. ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
  493. ValueError: If num_batch is None or 0 or larger than int_max.
  494. RuntimeError: If dataset is unknown.
  495. RuntimeError: If distribution file path is given but failed to read.
  496. """
  497. if num_batch is None:
  498. num_batch = self.get_dataset_size()
  499. repeat_count = self.get_repeat_count()
  500. num_batch = num_batch * repeat_count
  501. queue_name = str(uuid.uuid1())
  502. if context:
  503. device_type = context.get_context("device_target")
  504. else:
  505. device_type = "CPU"
  506. if device_type == "":
  507. raise TypeError("Please set device_type in context")
  508. if device_type not in ('Ascend', 'GPU', 'CPU'):
  509. raise ValueError("only support CPU, Ascend, GPU")
  510. if num_batch is None or num_batch == 0:
  511. raise ValueError("num_batch is None or 0.")
  512. def get_distribution(output_dataset):
  513. dev_id = 0
  514. if isinstance(output_dataset, (StorageDataset, GeneratorDataset, MindDataset)):
  515. return output_dataset.distribution, dev_id
  516. if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, ImageFolderDatasetV2,
  517. ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
  518. sampler = output_dataset.sampler
  519. if isinstance(sampler, samplers.DistributedSampler):
  520. dev_id = sampler.shard_id
  521. return "", dev_id
  522. if isinstance(output_dataset, TFRecordDataset):
  523. if output_dataset.shard_id is not None:
  524. dev_id = output_dataset.shard_id
  525. return "", dev_id
  526. if not output_dataset.input:
  527. raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
  528. input_dataset = output_dataset.input[0]
  529. return get_distribution(input_dataset)
  530. distribution_path, device_id = get_distribution(self)
  531. if distribution_path == "":
  532. return TransferDataset(self, queue_name, device_id, device_type, num_batch)
  533. try:
  534. with open(distribution_path, 'r') as distribution_f:
  535. dist = json.load(distribution_f)
  536. device_id = dist["deviceId"]
  537. except json.decoder.JSONDecodeError:
  538. raise RuntimeError("Json decode error when load distribution file")
  539. except Exception:
  540. raise RuntimeError("Distribution file failed to read")
  541. return TransferDataset(self, queue_name, device_id, device_type, num_batch)
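# Minimal usage sketch for the device transfer path above (illustrative, assuming an
# Ascend/GPU/CPU device_target is configured in the context):
#
#     data = data.batch(32).repeat(epoch_num)
#     transfer = data.device_que()  # wraps the pipeline in a TransferDataset
#     transfer.send()               # keeps an iterator alive and pushes batches to the device queue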
  542. def create_tuple_iterator(self, columns=None):
  543. """
  544. Create an iterator over the dataset. The data retrieved will be a list of ndarrays.
  545. To specify which columns to list and the order needed, use the columns parameter. If columns
  546. is not provided, the order of the columns will not be changed.
  547. Args:
  548. columns (list[str], optional): List of columns to be used to specify the order of columns
  549. (default=None, means all columns).
  550. Returns:
  551. Iterator, list of ndarray.
  552. Examples:
  553. >>> import mindspore.dataset as ds
  554. >>> # data is an instance of Dataset object
  555. >>> # creates an iterator. The columns in the data obtained by the
  556. >>> # iterator will not be changed.
  557. >>> iterator = data.create_tuple_iterator()
  558. >>> for item in iterator:
  559. >>> # convert the returned tuple to a list and print
  560. >>> print(list(item))
  561. """
  562. return TupleIterator(self, columns)
  563. def create_dict_iterator(self):
  564. """
  565. Create an Iterator over the dataset.
  566. The data retrieved will be a dictionary. The order
  567. of the columns in the dictionary may not be the same as the original order.
  568. Returns:
  569. Iterator, dictionary of column_name-ndarray pair.
  570. Examples:
  571. >>> import mindspore.dataset as ds
  572. >>> # data is an instance of Dataset object
  573. >>> # creates an iterator. The columns in the data obtained by the
  574. >>> # iterator might be changed.
  575. >>> iterator = data.create_dict_iterator()
  576. >>> for item in iterator:
  577. >>> # print the data in column1
  578. >>> print(item["column1"])
  579. """
  580. return DictIterator(self)
  581. def __iter__(self):
  582. """Create an Iterator over the dataset."""
  583. return self.create_tuple_iterator()
  584. @staticmethod
  585. def read_dir(dir_path, schema, columns_list=None, num_parallel_workers=None,
  586. deterministic_output=True, prefetch_size=None, shuffle=False, seed=None, distribution=""):
  587. """
  588. Append the paths of all files in dir_path to a StorageDataset.
  589. Args:
  590. dir_path (str): Path to the directory that contains the dataset.
  591. schema (str): Path to the json schema file.
  592. columns_list (list[str], optional): List of columns to be read (default=None).
  593. If not provided, read all columns.
  594. num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
  595. (default=None).
  596. deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
  597. or not (default=True). If True, performance might be affected.
  598. prefetch_size (int, optional): Prefetch number of records ahead of the
  599. user's request (default=None).
  600. shuffle (bool, optional): Shuffle the list of files in the directory (default=False).
  601. seed (int, optional): Create a random generator with a fixed seed. If set to None,
  602. create a random seed (default=None).
  603. distribution (str, optional): The path of distribution config file (default="").
  604. Returns:
  605. StorageDataset.
  606. Raises:
  607. ValueError: If dataset folder does not exist.
  608. ValueError: If permission to the dataset folder is denied.
  609. """
  610. logger.warning("WARN_DEPRECATED: The usage of read_dir is deprecated, please use TFRecordDataset with GLOB.")
  611. list_files = []
  612. if not os.path.isdir(dir_path):
  613. raise ValueError("The dataset folder does not exist!")
  614. if not os.access(dir_path, os.R_OK):
  615. raise ValueError("The dataset folder permission denied!")
  616. for root, _, files in os.walk(dir_path):
  617. for file in files:
  618. list_files.append(os.path.join(root, file))
  619. list_files.sort()
  620. if shuffle:
  621. rand = random.Random(seed)
  622. rand.shuffle(list_files)
  623. return StorageDataset(list_files, schema, distribution, columns_list, num_parallel_workers,
  624. deterministic_output, prefetch_size)
  625. @property
  626. def input_indexs(self):
  627. return self._input_indexs
  628. @input_indexs.setter
  629. def input_indexs(self, value):
  630. self._input_indexs = value
  631. def _get_pipeline_info(self):
  632. device_iter = TupleIterator(self)
  633. self._output_shapes = device_iter.get_output_shapes()
  634. self._output_types = device_iter.get_output_types()
  635. if self._dataset_size is None:
  636. self._dataset_size = device_iter.get_dataset_size()
  637. self._batch_size = device_iter.get_batch_size()
  638. self._num_classes = device_iter.num_classes()
  639. self._repeat_count = device_iter.get_repeat_count()
  640. device_iter.release()
  641. def output_shapes(self):
  642. """
  643. Get the shapes of output data.
  644. Return:
  645. List, list of shape of each column.
  646. """
  647. if self._output_shapes is None:
  648. self._get_pipeline_info()
  649. return self._output_shapes
  650. def output_types(self):
  651. """
  652. Get the types of output data.
  653. Return:
  654. List of data type.
  655. """
  656. if self._output_types is None:
  657. self._get_pipeline_info()
  658. return self._output_types
  659. def get_dataset_size(self):
  660. """
  661. Get the number of batches in an epoch.
  662. Return:
  663. Number, number of batches.
  664. """
  665. if self.input:
  666. return self.input[0].get_dataset_size()
  667. return None
  668. def num_classes(self):
  669. """
  670. Get the number of classes in a dataset.
  671. Return:
  672. Number, number of classes.
  673. """
  674. if self.input:
  675. return self.input[0].num_classes()
  676. return None
  677. def get_batch_size(self):
  678. """
  679. Get the size of a batch.
  680. Return:
  681. Number, the number of data in a batch.
  682. """
  683. if self.input:
  684. return self.input[0].get_batch_size()
  685. return 1
  686. def get_repeat_count(self):
  687. """
  688. Get the repeat count from a RepeatDataset, otherwise 1.
  689. Return:
  690. Number, the count of repeat.
  691. """
  692. if self.input:
  693. return self.input[0].get_repeat_count()
  694. return 1
  695. def get_class_indexing(self):
  696. """
  697. Get the class index.
  698. Return:
  699. Dict, A str-to-int mapping from label name to index.
  700. """
  701. if self.input:
  702. return self.input[0].get_class_indexing()
  703. raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))
  704. def reset(self):
  705. """Reset the dataset for next epoch"""
  706. class SourceDataset(Dataset):
  707. """
  708. Abstract class to represent a source dataset which produces content to the data pipeline.
  709. """
  710. # No need for __init__ since it is the same as the super's init
  711. class DatasetOp(Dataset):
  712. """
  713. Abstract class to represent an operation on a dataset.
  714. """
  715. # No need for __init__ since it is the same as the super's init
  716. class BatchDataset(DatasetOp):
  717. """
  718. The result of applying Batch operator to the input dataset.
  719. Args:
  720. input_dataset (Dataset): Input Dataset to be batched.
  721. batch_size (int): The size of the batch.
  722. drop_remainder (bool, optional): Whether to drop the remainder batch of data (default=False).
  723. If True, the last incomplete batch will be dropped.
  724. """
  725. def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None,
  726. per_batch_map=None, input_columns=None):
  727. super().__init__(num_parallel_workers)
  728. if BatchDataset._is_ancestor_of_repeat(input_dataset):
  729. logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
  730. self.batch_size = batch_size
  731. self.drop_remainder = drop_remainder
  732. self.per_batch_map = per_batch_map
  733. self.input_columns = input_columns
  734. self.input.append(input_dataset)
  735. input_dataset.output.append(self)
  736. self._input_indexs = input_dataset.input_indexs
  737. def get_args(self):
  738. args = super().get_args()
  739. args["batch_size"] = self.batch_size
  740. args["drop_remainder"] = self.drop_remainder
  741. args["per_batch_map"] = self.per_batch_map
  742. args["input_columns"] = self.input_columns
  743. return args
  744. def get_dataset_size(self):
  745. """
  746. Get the number of batches in an epoch.
  747. Return:
  748. Number, number of batches.
  749. """
  750. child_size = self.input[0].get_dataset_size()
  751. if child_size is not None:
  752. if self.drop_remainder:
  753. return math.floor(child_size / self.batch_size)
  754. return math.ceil(child_size / self.batch_size)
  755. return None
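# Worked example of the batch-count formula above: with 103 rows from the child node and
# batch_size=10, drop_remainder=False gives math.ceil(103 / 10) = 11 batches, while
# drop_remainder=True gives math.floor(103 / 10) = 10 batches.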
  756. def get_batch_size(self):
  757. """
  758. Get the size of a batch.
  759. Return:
  760. Number, the number of data in a batch.
  761. """
  762. return self.batch_size
  763. @staticmethod
  764. def _is_ancestor_of_repeat(dataset):
  765. """
  766. Utility function to find the case where repeat is used before batch.
  767. Args:
  768. dataset (Dataset): dataset to be checked
  769. Return:
  770. True or False
  771. """
  772. if isinstance(dataset, RepeatDataset):
  773. return True
  774. flag = False
  775. for input_dataset in dataset.input:
  776. flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
  777. return flag
  778. class BatchInfo(CBatchInfo):
  779. """
  780. The information object associated with the current batch of tensors.
  781. """
  782. def get_batch_num(self):
  783. """
  784. Return the batch number of the current batch.
  785. Return:
  786. Number, number of the current batch.
  787. """
  788. return
  789. def get_epoch_num(self):
  790. """
  791. Return the epoch number of the current batch.
  792. Return:
  793. Number, number of the current epoch.
  794. """
  795. return
  796. class ShuffleDataset(DatasetOp):
  797. """
  798. The result of applying Shuffle operator to the input Dataset.
  799. Args:
  800. input_dataset (Dataset): Input Dataset to be shuffled.
  801. buffer_size (int): The size of the buffer.
  802. """
  803. def __init__(self, input_dataset, buffer_size):
  804. super().__init__()
  805. self.buffer_size = buffer_size
  806. self.input.append(input_dataset)
  807. input_dataset.output.append(self)
  808. self._input_indexs = input_dataset.input_indexs
  809. def get_args(self):
  810. args = super().get_args()
  811. args["buffer_size"] = self.buffer_size
  812. return args
  813. class MapDataset(DatasetOp):
  814. """
  815. The result of applying Map operator to the input Dataset.
  816. Args:
  817. input_dataset (Dataset): Input Dataset to be mapped.
  818. input_columns (list[str]): List of names of the input columns
  819. (default=None, the operations will be applied on the first columns in the dataset).
  820. The size of the list should match the number of inputs of the first operator.
  821. operations (TensorOp): A function mapping a nested structure of tensors
  822. to another nested structure of tensor (default=None).
  823. output_columns (list[str], optional): list of names of the output columns.
  824. The size of the list should match the number of outputs of the last operator
  825. (default=None, output columns will be the input columns, i.e., the columns will
  826. be replaced).
  827. columns_order (list[str], optional): list of all the desired columns of the dataset (default=None).
  828. The argument is mandatory if len(input_columns) != len(output_columns).
  829. num_parallel_workers (int, optional): Number of workers to process the Dataset
  830. in parallel (default=None).
  831. Raises:
  832. ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
  833. """
  834. def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
  835. num_parallel_workers=None):
  836. super().__init__(num_parallel_workers)
  837. self.input.append(input_dataset)
  838. if input_columns is not None and not isinstance(input_columns, list):
  839. input_columns = [input_columns]
  840. self.input_columns = input_columns
  841. if operations is not None and not isinstance(operations, list):
  842. operations = [operations]
  843. self.operations = operations
  844. if output_columns is not None and not isinstance(output_columns, list):
  845. output_columns = [output_columns]
  846. self.output_columns = output_columns
  847. self.columns_order = columns_order
  848. if self.input_columns and self.output_columns \
  849. and len(self.input_columns) != len(self.output_columns) \
  850. and self.columns_order is None:
  851. raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.")
  852. input_dataset.output.append(self)
  853. self._input_indexs = input_dataset.input_indexs
  854. def get_args(self):
  855. args = super().get_args()
  856. args["input_columns"] = self.input_columns
  857. args["operations"] = self.operations
  858. args["output_columns"] = self.output_columns
  859. return args
  860. def get_dataset_size(self):
  861. """
  862. Get the number of batches in an epoch.
  863. Return:
  864. Number, number of batches.
  865. """
  866. return self.input[0].get_dataset_size()
  867. class RepeatDataset(DatasetOp):
  868. """
  869. The result of applying Repeat operator to the input Dataset.
  870. Args:
  871. input_dataset (Dataset): Input Dataset to be repeated.
  872. count (int): Number of times the dataset should be repeated.
  873. """
  874. def __init__(self, input_dataset, count):
  875. super().__init__()
  876. if count is None:
  877. self.count = -1
  878. else:
  879. self.count = count
  880. self.input.append(input_dataset)
  881. input_dataset.output.append(self)
  882. self._input_indexs = input_dataset.input_indexs
  883. def get_args(self):
  884. args = super().get_args()
  885. args["count"] = self.count
  886. return args
  887. def get_dataset_size(self):
  888. """
  889. Get the number of batches in an epoch.
  890. Return:
  891. Number, number of batches.
  892. """
  893. child_size = self.input[0].get_dataset_size()
  894. if child_size is not None:
  895. return child_size
  896. return None
  897. def get_repeat_count(self):
  898. """
  899. Get the replication times in RepeatDataset.
  900. Return:
  901. Number, the count of repeat.
  902. """
  903. return self.count
  904. class ZipDataset(DatasetOp):
  905. """
  906. The result of applying Zip operator to the input Dataset.
  907. Args:
  908. datasets (tuple): A tuple of datasets to be zipped together.
  909. Raises:
  910. TypeError: If dataset is not an instance of Dataset.
  911. """
  912. def __init__(self, datasets):
  913. super().__init__()
  914. for dataset in datasets:
  915. if not isinstance(dataset, Dataset):
  916. raise TypeError("The parameter %s of zip has type error!" % (dataset))
  917. self.datasets = datasets
  918. for data in datasets:
  919. self.input.append(data)
  920. data.output.append(self)
  921. def get_dataset_size(self):
  922. """
  923. Get the number of batches in an epoch.
  924. Return:
  925. Number, number of batches.
  926. """
  927. children_sizes = [c.get_dataset_size() for c in self.input]
  928. if all(c is not None for c in children_sizes):
  929. return min(children_sizes)
  930. return None
  931. def num_classes(self):
  932. """
  933. Get the number of classes in a dataset.
  934. Return:
  935. Number, number of classes.
  936. """
  937. return None
  938. def get_args(self):
  939. args = super().get_args()
  940. return args
  941. class RenameDataset(DatasetOp):
  942. """
  943. The result of applying Rename operator to the input Dataset.
  944. Args:
  945. input_dataset (Dataset): Input Dataset to be Renamed.
  946. input_column_names (list[str]): list of names of the input columns.
  947. output_column_names (list[str]): list of names of the output columns.
  948. """
  949. def __init__(self, input_dataset, input_columns, output_columns):
  950. super().__init__()
  951. if not isinstance(input_columns, list):
  952. input_columns = [input_columns]
  953. if not isinstance(output_columns, list):
  954. output_columns = [output_columns]
  955. self.input_column_names = input_columns
  956. self.output_column_names = output_columns
  957. self.input.append(input_dataset)
  958. input_dataset.output.append(self)
  959. self._input_indexs = input_dataset.input_indexs
  960. def get_args(self):
  961. args = super().get_args()
  962. args["input_columns"] = self.input_column_names
  963. args["output_columns"] = self.output_column_names
  964. return args
  965. class ProjectDataset(DatasetOp):
  966. """
  967. The result of applying Project operator to the input Dataset.
  968. Args:
  969. input_dataset (Dataset): Input Dataset to be projected.
  970. columns (list[str]): List of names of the columns to project.
  971. prefetch_size (int, optional): Prefetch number of records ahead of the
  972. user's request (default=None).
  973. """
  974. def __init__(self, input_dataset, columns, prefetch_size=None):
  975. super().__init__()
  976. if not isinstance(columns, list):
  977. columns = [columns]
  978. self.columns = columns
  979. self.input.append(input_dataset)
  980. self.prefetch_size = prefetch_size
  981. input_dataset.output.append(self)
  982. self._input_indexs = input_dataset.input_indexs
  983. def get_args(self):
  984. args = super().get_args()
  985. args["columns"] = self.columns
  986. args["prefetch_size"] = self.prefetch_size
  987. return args
  988. class TransferDataset(DatasetOp):
  989. """
  990. The result of applying TDT operator to the input Dataset.
  991. Args:
  992. input_dataset (Dataset): Input Dataset to be transferred.
  993. queue_name (str): Name of device queue.
  994. device_id (int): Id of device.
  995. device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
  996. num_batch (int): limit the number of batch to be sent to device (default=None).
  997. """
  998. def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
  999. super().__init__()
  1000. self.input.append(input_dataset)
  1001. input_dataset.output.append(self)
  1002. self.queue_name = queue_name
  1003. self._input_indexs = input_dataset.input_indexs
  1004. self._device_type = device_type
  1005. self._device_id = device_id
  1006. self.__num_batch = num_batch
  1007. self.iterator = None
  1008. def get_args(self):
  1009. args = super().get_args()
  1010. args["queue_name"] = self.queue_name
  1011. args["device_type"] = self._device_type
  1012. args["device_id"] = self._device_id
  1013. args["num_batch"] = self.__num_batch
  1014. return args
  1015. def create_dict_iterator(self):
  1016. raise RuntimeError("TransferDataset is not iterable")
  1017. def create_tuple_iterator(self, columns=None):
  1018. raise RuntimeError("TransferDataset is not iterable")
  1019. def __iter__(self):
  1020. raise RuntimeError("TransferDataset is not iterable")
  1021. def output_shapes(self):
  1022. raise RuntimeError("TransferDataset does not support output_shapes")
  1023. def output_types(self):
  1024. raise RuntimeError("TransferDataset does not support output_types")
  1025. def send(self):
  1026. # need to keep iterator alive so the executionTree is not destroyed
  1027. self.iterator = TupleIterator(self)
  1028. class StorageDataset(SourceDataset):
  1029. """
  1030. A source dataset that reads and parses datasets stored on disk in various formats, including TFData format.
  1031. Args:
  1032. dataset_files (list[str]): List of files to be read.
  1033. schema (str): Path to the json schema file.
  1034. distribution (str, optional): Path of distribution config file (default="").
  1035. columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
  1036. num_parallel_workers (int, optional): Number of parallel working threads (default=None).
  1037. deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
  1038. or not (default=True). If True, performance might be affected.
  1039. prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
  1040. Raises:
  1041. RuntimeError: If schema file failed to read.
  1042. RuntimeError: If distribution file path is given but failed to read.
  1043. """
  1044. @check
  1045. def __init__(self, dataset_files, schema, distribution="", columns_list=None, num_parallel_workers=None,
  1046. deterministic_output=None, prefetch_size=None):
  1047. super().__init__(num_parallel_workers)
  1048. logger.warning("WARN_DEPRECATED: The usage of StorageDataset is deprecated, please use TFRecordDataset.")
  1049. self.dataset_files = dataset_files
  1050. try:
  1051. with open(schema, 'r') as load_f:
  1052. json.load(load_f)
  1053. except json.decoder.JSONDecodeError:
raise RuntimeError("Json decode error when loading schema file")
  1055. except Exception:
  1056. raise RuntimeError("Schema file failed to load")
  1057. if distribution != "":
  1058. try:
  1059. with open(distribution, 'r') as load_d:
  1060. json.load(load_d)
  1061. except json.decoder.JSONDecodeError:
raise RuntimeError("Json decode error when loading distribution file")
  1063. except Exception:
  1064. raise RuntimeError("Distribution file failed to load")
  1065. if self.dataset_files is None:
  1066. schema = None
  1067. distribution = None
  1068. self.schema = schema
  1069. self.distribution = distribution
  1070. self.columns_list = columns_list
  1071. self.deterministic_output = deterministic_output
  1072. self.prefetch_size = prefetch_size
  1073. def get_args(self):
  1074. args = super().get_args()
  1075. args["dataset_files"] = self.dataset_files
  1076. args["schema"] = self.schema
  1077. args["distribution"] = self.distribution
  1078. args["columns_list"] = self.columns_list
  1079. args["deterministic_output"] = self.deterministic_output
  1080. args["prefetch_size"] = self.prefetch_size
  1081. return args
  1082. def get_dataset_size(self):
  1083. """
  1084. Get the number of batches in an epoch.
  1085. Return:
  1086. Number, number of batches.
  1087. """
  1088. if self._dataset_size is None:
  1089. self._get_pipeline_info()
  1090. return self._dataset_size
  1091. # manually set dataset_size as a temporary solution.
  1092. def set_dataset_size(self, value):
  1093. logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
  1094. if value >= 0:
  1095. self._dataset_size = value
  1096. else:
  1097. raise ValueError('set dataset_size with negative value {}'.format(value))
  1098. def num_classes(self):
  1099. """
  1100. Get the number of classes in dataset.
  1101. Return:
  1102. Number, number of classes.
  1103. Raises:
  1104. ValueError: If dataset type is invalid.
ValueError: If dataset is not an ImageNet or Manifest dataset.
  1106. RuntimeError: If schema file is given but failed to load.
  1107. """
  1108. cur_dataset = self
  1109. while cur_dataset.input:
  1110. cur_dataset = cur_dataset.input[0]
  1111. if not hasattr(cur_dataset, "schema"):
  1112. raise ValueError("Dataset type is invalid")
# Only IMAGENET/MANIFEST support num_classes
  1114. try:
  1115. with open(cur_dataset.schema, 'r') as load_f:
  1116. load_dict = json.load(load_f)
  1117. except json.decoder.JSONDecodeError:
raise RuntimeError("Json decode error when loading schema file")
  1119. except Exception:
  1120. raise RuntimeError("Schema file failed to load")
  1121. if load_dict["datasetType"] != "IMAGENET" and load_dict["datasetType"] != "MANIFEST":
  1122. raise ValueError("%s dataset does not support num_classes!" % (load_dict["datasetType"]))
  1123. if self._num_classes is None:
  1124. self._get_pipeline_info()
  1125. return self._num_classes
  1126. class RangeDataset(SourceDataset):
  1127. """
A source dataset that generates a sequence of numbers over a given range.
  1129. Args:
  1130. start (int): starting index.
  1131. stop (int): ending index.
  1132. step (int): step size in a range.
  1133. """
  1134. def __init__(self, start, stop, step):
  1135. super().__init__()
  1136. self.start = start
  1137. self.stop = stop
  1138. self.step = step
  1139. def get_args(self):
  1140. args = super().get_args()
  1141. args["start"] = self.start
  1142. args["stop"] = self.stop
  1143. args["step"] = self.step
  1144. return args
  1145. def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
  1146. """
  1147. Create sampler based on user input.
  1148. Args:
  1149. num_samples (int): Number of samples
  1150. input_sampler (Iterable / Sampler): Sampler from user
  1151. shuffle (bool): Shuffle
  1152. num_shards (int): Number of shard for sharding
  1153. shard_id (int): Shard ID
  1154. """
  1155. if shuffle is None:
  1156. if input_sampler is not None:
  1157. # If shuffle is not specified, user provided sampler, use user's sampler
  1158. return input_sampler
  1159. if num_shards is not None:
  1160. # If shuffle is not specified, sharding enabled, use distributed random sampler
  1161. shuffle = True
  1162. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
  1163. # If shuffle is not specified, sharding disabled, use random sampler
  1164. if num_samples is not None:
  1165. return samplers.RandomSampler(replacement=True, num_samples=num_samples)
  1166. return samplers.RandomSampler()
  1167. if shuffle is True:
  1168. if num_shards is not None:
  1169. # If shuffle enabled, sharding enabled, use distributed random sampler
  1170. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
  1171. # If shuffle enabled, sharding disabled, use random sampler
  1172. if num_samples is not None:
  1173. return samplers.RandomSampler(replacement=True, num_samples=num_samples)
  1174. return samplers.RandomSampler()
  1175. if num_shards is not None:
  1176. # If shuffle disabled, sharding enabled, use distributed sequential sampler
  1177. return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle)
  1178. # If shuffle disabled, sharding disabled, use sequential sampler
  1179. return samplers.SequentialSampler()
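# A few illustrative resolutions of the rules above (a sketch, not exhaustive):
#   _select_sampler(None, None, None, None, None)         -> RandomSampler()
#   _select_sampler(None, None, False, None, None)        -> SequentialSampler()
#   _select_sampler(None, None, None, 4, 0)               -> DistributedSampler(4, 0, shuffle=True)
#   _select_sampler(None, user_sampler, None, None, None) -> user_sampler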
  1180. class ImageFolderDatasetV2(SourceDataset):
  1181. """
  1182. A source dataset that reads images from a tree of directories.
  1183. All images within one folder have the same label.
  1184. The generated dataset has two columns ['image', 'label'].
  1185. The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
  1186. otherwise.
  1187. The type of the image tensor is uint8. The label is just a scalar uint64
  1188. tensor.
  1189. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  1190. below shows what input args are allowed and their expected behavior.
  1191. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  1192. :widths: 25 25 50
  1193. :header-rows: 1
  1194. * - Parameter 'sampler'
  1195. - Parameter 'shuffle'
  1196. - Expected Order Behavior
  1197. * - None
  1198. - None
  1199. - random order
  1200. * - None
  1201. - True
  1202. - random order
  1203. * - None
  1204. - False
  1205. - sequential order
  1206. * - Sampler object
  1207. - None
  1208. - order defined by sampler
  1209. * - Sampler object
  1210. - True
  1211. - not allowed
  1212. * - Sampler object
  1213. - False
  1214. - not allowed
  1215. Args:
  1216. dataset_dir (str): Path to the root directory that contains the dataset.
  1217. num_samples (int, optional): The number of images to be included in the dataset
  1218. (default=None, all images).
  1219. num_parallel_workers (int, optional): Number of workers to read the data
  1220. (default=None, set in the config).
  1221. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  1222. (default=None, expected order behavior shown in the table).
  1223. sampler (Sampler, optional): Object used to choose samples from the
  1224. dataset (default=None, expected order behavior shown in the table).
  1225. extensions (list[str], optional): List of file extensions to be
  1226. included in the dataset (default=None).
  1227. class_indexing (dict, optional): A str-to-int mapping from folder name to index
  1228. (default=None, the folder names will be sorted
  1229. alphabetically and each class will be given a
  1230. unique index starting from 0).
  1231. decode (bool, optional): decode the images after reading (default=False).
  1232. num_shards (int, optional): Number of shards that the dataset should be divided
  1233. into (default=None).
  1234. shard_id (int, optional): The shard ID within num_shards (default=None). This
  1235. argument should be specified only when num_shards is also specified.
  1236. Raises:
  1237. RuntimeError: If sampler and shuffle are specified at the same time.
  1238. RuntimeError: If sampler and sharding are specified at the same time.
  1239. RuntimeError: If num_shards is specified but shard_id is None.
  1240. RuntimeError: If shard_id is specified but num_shards is None.
  1241. RuntimeError: If class_indexing is not a dictionary.
  1242. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  1243. Examples:
  1244. >>> import mindspore.dataset as ds
  1245. >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
  1246. >>> dataset_dir = "/path/to/imagefolder_directory"
  1247. >>> # 1) read all samples (image files) in dataset_dir with 8 threads
  1248. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8)
  1249. >>> # 2) read all samples (image files) from folder cat and folder dog with label 0 and 1
  1250. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir,class_indexing={"cat":0,"dog":1})
  1251. >>> # 3) read all samples (image files) in dataset_dir with extensions .JPEG and .png (case sensitive)
  1252. >>> imagefolder_dataset = ds.ImageFolderDatasetV2(dataset_dir, extensions={".JPEG",".png"})
  1253. """
  1254. @check_imagefolderdatasetv2
  1255. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  1256. shuffle=None, sampler=None, extensions=None, class_indexing=None,
  1257. decode=False, num_shards=None, shard_id=None):
  1258. super().__init__(num_parallel_workers)
  1259. self.dataset_dir = dataset_dir
  1260. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  1261. self.num_samples = num_samples
  1262. self.shuffle_level = shuffle
  1263. self.extensions = extensions
  1264. self.class_indexing = class_indexing
  1265. self.decode = decode
  1266. self.num_shards = num_shards
  1267. self.shard_id = shard_id
  1268. def get_args(self):
  1269. args = super().get_args()
  1270. args["dataset_dir"] = self.dataset_dir
  1271. args["num_samples"] = self.num_samples
  1272. args["sampler"] = self.sampler
  1273. args["shuffle"] = self.shuffle_level
  1274. args["extensions"] = self.extensions
  1275. args["class_indexing"] = self.class_indexing
  1276. args["decode"] = self.decode
  1277. args["num_shards"] = self.num_shards
  1278. args["shard_id"] = self.shard_id
  1279. return args
  1280. def get_dataset_size(self):
  1281. """
  1282. Get the number of batches in an epoch.
  1283. Return:
  1284. Number, number of batches.
  1285. """
  1286. if self.num_samples is None:
  1287. num_samples = 0
  1288. else:
  1289. num_samples = self.num_samples
  1290. num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
  1291. return get_num_rows(num_rows, self.num_shards)
  1292. def num_classes(self):
  1293. """
  1294. Get the number of classes in dataset.
  1295. Return:
  1296. Number, number of classes.
  1297. """
  1298. if self.num_samples is None:
  1299. num_samples = 0
  1300. else:
  1301. num_samples = self.num_samples
  1302. return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1]
  1303. class MnistDataset(SourceDataset):
  1304. """
  1305. A source dataset for reading and parsing the Mnist dataset.
  1306. The generated dataset has two columns ['image', 'label'].
  1307. The type of the image tensor is uint8. The label is just a scalar uint32 tensor.
  1308. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  1309. below shows what input args are allowed and their expected behavior.
  1310. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  1311. :widths: 25 25 50
  1312. :header-rows: 1
  1313. * - Parameter 'sampler'
  1314. - Parameter 'shuffle'
  1315. - Expected Order Behavior
  1316. * - None
  1317. - None
  1318. - random order
  1319. * - None
  1320. - True
  1321. - random order
  1322. * - None
  1323. - False
  1324. - sequential order
  1325. * - Sampler object
  1326. - None
  1327. - order defined by sampler
  1328. * - Sampler object
  1329. - True
  1330. - not allowed
  1331. * - Sampler object
  1332. - False
  1333. - not allowed
  1334. Args:
  1335. dataset_dir (str): Path to the root directory that contains the dataset.
  1336. num_samples (int, optional): The number of images to be included in the dataset
  1337. (default=None, all images).
  1338. num_parallel_workers (int, optional): Number of workers to read the data
(default=None, set in the config).
  1340. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  1341. (default=None, expected order behavior shown in the table).
  1342. sampler (Sampler, optional): Object used to choose samples from the
  1343. dataset (default=None, expected order behavior shown in the table).
  1344. num_shards (int, optional): Number of shards that the dataset should be divided
  1345. into (default=None).
  1346. shard_id (int, optional): The shard ID within num_shards (default=None). This
  1347. argument should be specified only when num_shards is also specified.
  1348. Raises:
  1349. RuntimeError: If sampler and shuffle are specified at the same time.
  1350. RuntimeError: If sampler and sharding are specified at the same time.
  1351. RuntimeError: If num_shards is specified but shard_id is None.
  1352. RuntimeError: If shard_id is specified but num_shards is None.
  1353. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  1354. Examples:
  1355. >>> import mindspore.dataset as ds
  1356. >>> dataset_dir = "/path/to/mnist_folder"
  1357. >>> # 1) read 3 samples from mnist_dataset
  1358. >>> mnist_dataset = ds.MnistDataset(dataset_dir=dataset_dir, num_samples=3)
  1359. >>> # in mnist_dataset dataset, each dictionary has keys "image" and "label"
  1360. """
  1361. @check_mnist_cifar_dataset
  1362. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  1363. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  1364. super().__init__(num_parallel_workers)
  1365. self.dataset_dir = dataset_dir
  1366. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  1367. self.num_samples = num_samples
  1368. self.shuffle_level = shuffle
  1369. self.num_shards = num_shards
  1370. self.shard_id = shard_id
  1371. def get_args(self):
  1372. args = super().get_args()
  1373. args["dataset_dir"] = self.dataset_dir
  1374. args["num_samples"] = self.num_samples
  1375. args["shuffle"] = self.shuffle_level
  1376. args["sampler"] = self.sampler
  1377. args["num_shards"] = self.num_shards
  1378. args["shard_id"] = self.shard_id
  1379. return args
  1380. def get_dataset_size(self):
  1381. """
  1382. Get the number of batches in an epoch.
  1383. Return:
  1384. Number, number of batches.
  1385. """
  1386. if self.num_samples is None:
  1387. num_samples = 0
  1388. else:
  1389. num_samples = self.num_samples
  1390. num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
  1391. return get_num_rows(num_rows, self.num_shards)
  1392. class MindDataset(SourceDataset):
  1393. """
A source dataset that reads data from MindRecord shard files and their index database.
  1395. Args:
dataset_file (str): one of the file names in the dataset.
  1397. columns_list (list[str], optional): List of columns to be read (default=None).
  1398. num_parallel_workers (int, optional): The number of readers (default=None).
  1399. shuffle (bool, optional): Whether or not to perform shuffle on the dataset
  1400. (default=None, performs shuffle).
  1401. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
  1402. shard_id (int, optional): The shard ID within num_shards (default=None). This
  1403. argument should be specified only when num_shards is also specified.
block_reader (bool, optional): Whether to read data in block mode (default=False).
sampler (Sampler, optional): Object used to choose samples from the
dataset (default=None; sampler is mutually exclusive
with shuffle and block_reader). Supported sampler: SubsetRandomSampler.
  1408. Raises:
  1409. ValueError: If num_shards is specified but shard_id is None.
  1410. ValueError: If shard_id is specified but num_shards is None.
ValueError: If block_reader is true but partition is specified.
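Examples:
>>> import mindspore.dataset as ds
>>> # a minimal usage sketch; the file path and column names are placeholders
>>> dataset_file = "/path/to/example.mindrecord"
>>> mind_dataset = ds.MindDataset(dataset_file=dataset_file, columns_list=["data", "label"])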
  1412. """
  1413. @check_minddataset
  1414. def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None,
  1415. shuffle=None, num_shards=None, shard_id=None,
  1416. block_reader=False, sampler=None):
  1417. super().__init__(num_parallel_workers)
  1418. self.dataset_file = dataset_file
  1419. self.columns_list = columns_list
  1420. self.global_shuffle = shuffle
  1421. self.distribution = ""
  1422. self.sampler = sampler
  1423. if num_shards is None or shard_id is None:
  1424. self.partitions = None
  1425. else:
  1426. self.partitions = [num_shards, shard_id]
if block_reader is True and self.partitions is not None:
raise ValueError("block_reader is not allowed to be true when partitions are used")
if block_reader is True and shuffle is True:
raise ValueError("block_reader is not allowed to be true when shuffle is used")
  1431. if block_reader is True:
  1432. logger.warning("WARN: global shuffle is not used.")
if sampler is not None and not isinstance(sampler, samplers.SubsetRandomSampler):
raise ValueError("The sampler is not supported yet.")
# sampler is exclusive with shuffle and block_reader
if block_reader is True and sampler is not None:
raise ValueError("block_reader is not allowed to be true when a sampler is used")
if shuffle is True and sampler is not None:
raise ValueError("shuffle is not allowed to be true when a sampler is used")
  1440. if block_reader is False and sampler is None:
self.global_shuffle = shuffle is not False
  1442. self.num_shards = num_shards
  1443. self.shard_id = shard_id
  1444. self.block_reader = block_reader
  1445. def get_args(self):
  1446. args = super().get_args()
  1447. args["dataset_file"] = self.dataset_file
  1448. args["columns_list"] = self.columns_list
  1449. args["global_shuffle"] = self.global_shuffle
  1450. args["partitions"] = self.partitions
  1451. args["block_reader"] = self.block_reader
  1452. args["num_shards"] = self.num_shards
  1453. args["shard_id"] = self.shard_id
  1454. args["sampler"] = self.sampler
  1455. return args
  1456. def get_dataset_size(self):
  1457. """
  1458. Get the number of batches in an epoch.
  1459. Return:
  1460. Number, number of batches.
  1461. """
  1462. num_rows = MindRecordOp.get_num_rows(self.dataset_file)
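# When partitions are used, report the per-shard size, rounding up (ceil division)
# so that no row is left unassigned.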
  1463. if self.partitions is not None and self.partitions[0] > 0:
  1464. if num_rows % self.partitions[0] == 0:
  1465. num_rows = num_rows // self.partitions[0]
  1466. else:
  1467. num_rows = num_rows // self.partitions[0] + 1
  1468. return num_rows
  1469. def _iter_fn(dataset, num_samples):
  1470. """
  1471. Generator function wrapper for iterable dataset
  1472. """
  1473. if num_samples is not None:
  1474. ds_iter = iter(dataset)
  1475. for _ in range(num_samples):
  1476. try:
  1477. val = next(ds_iter)
  1478. except StopIteration:
  1479. return
  1480. # convert output tensors to ndarrays
  1481. yield tuple([np.array(x) for x in val])
  1482. else:
  1483. for val in dataset:
  1484. # convert output tensors to ndarrays
  1485. yield tuple([np.array(x) for x in val])
  1486. def _generator_fn(generator, num_samples):
  1487. """
  1488. Generator function wrapper for generator function dataset
  1489. """
  1490. if num_samples is not None:
  1491. gen_iter = generator()
  1492. for _ in range(num_samples):
  1493. try:
  1494. val = next(gen_iter)
  1495. except StopIteration:
  1496. return
  1497. yield val
  1498. else:
  1499. gen_iter = generator()
  1500. for val in gen_iter:
  1501. yield val
  1502. def _py_sampler_fn(sampler, num_samples, dataset):
  1503. """
  1504. Generator function wrapper for mappable dataset with python sampler
  1505. """
  1506. if num_samples is not None:
  1507. sampler_iter = iter(sampler)
  1508. for _ in range(num_samples):
  1509. try:
  1510. idx = next(sampler_iter)
  1511. except StopIteration:
  1512. return
  1513. val = dataset[idx]
  1514. # convert output tensors to ndarrays
  1515. yield tuple([np.array(x) for x in val])
  1516. else:
  1517. for i in sampler:
  1518. val = dataset[i]
  1519. # convert output tensors to ndarrays
  1520. yield tuple([np.array(x) for x in val])
  1521. def _cpp_sampler_fn(sampler, dataset):
  1522. """
  1523. Generator function wrapper for mappable dataset with cpp sampler
  1524. """
  1525. indices = sampler.get_indices()
  1526. for i in indices:
  1527. val = dataset[i]
  1528. # convert output tensors to ndarrays
  1529. yield tuple([np.array(x) for x in val])
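# The wrappers above normalize every supported source type (iterable, generator
# function, or random-accessible object with a sampler) into a zero-argument
# callable that yields tuples of NumPy arrays; GeneratorDataset below picks the
# appropriate wrapper based on the source and sampler it receives.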
  1530. class GeneratorDataset(SourceDataset):
  1531. """
A source dataset that generates data from Python by invoking a Python data source each epoch.
  1533. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  1534. below shows what input args are allowed and their expected behavior.
  1535. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  1536. :widths: 25 25 50
  1537. :header-rows: 1
  1538. * - Parameter 'sampler'
  1539. - Parameter 'shuffle'
  1540. - Expected Order Behavior
  1541. * - None
  1542. - None
  1543. - random order
  1544. * - None
  1545. - True
  1546. - random order
  1547. * - None
  1548. - False
  1549. - sequential order
  1550. * - Sampler object
  1551. - None
  1552. - order defined by sampler
  1553. * - Sampler object
  1554. - True
  1555. - not allowed
  1556. * - Sampler object
  1557. - False
  1558. - not allowed
  1559. Args:
  1560. source (Callable/Iterable/Random Accessible):
  1561. A generator callable object, an iterable python object or a random accessible python object.
Callable source is required to return a tuple of NumPy arrays as a row of the dataset on source().next().
Iterable source is required to return a tuple of NumPy arrays as a row of the dataset on iter(source).next().
Random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on
source[idx].
  1566. column_names (list[str]): List of column names of the dataset.
  1567. column_types (list[mindspore.dtype], optional): List of column data types of the dataset (default=None).
  1568. If provided, sanity check will be performed on generator output.
  1569. schema (Schema/String, optional): Path to the json schema file or schema object (default=None).
  1570. If the schema is not provided, the meta data from column_names and column_types is considered the schema.
  1571. num_samples (int, optional): The number of samples to be included in the dataset
(default=None, all samples).
  1573. shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
  1574. (default=None, expected order behavior shown in the table).
  1575. sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is
  1576. required.
  1577. (default=None, expected order behavior shown in the table).
  1578. num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
  1579. This argument should be specified only when 'num_samples' is "None". Random accessible input is required.
  1580. shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
  1581. when num_shards is also specified. Random accessible input is required.
  1582. Examples:
  1583. >>> import mindspore.dataengine as de
  1584. >>> # 1) Multidimensional generator function as callable input
  1585. >>> def generator_md():
  1586. >>> for i in range(64):
  1587. >>> yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
  1588. >>> # create multi_dimension_generator_dataset with GeneratorMD and column name "multi_dimensional_data"
  1589. >>> multi_dimension_generator_dataset = de.GeneratorDataset(generator_md, ["multi_dimensional_data"])
  1590. >>> # 2) Multi-column generator function as callable input
  1591. >>> def generator_mc(maxid = 64):
  1592. >>> for i in range(maxid):
  1593. >>> yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
  1594. >>> # create multi_column_generator_dataset with GeneratorMC and column names "col1" and "col2"
>>> multi_column_generator_dataset = de.GeneratorDataset(generator_mc, ["col1", "col2"])
  1596. >>> # 3) Iterable dataset as iterable input
  1597. >>> class MyIterable():
  1598. >>> def __iter__(self):
  1599. >>> return # User implementation
  1600. >>> # create iterable_generator_dataset with MyIterable object
  1601. >>> iterable_generator_dataset = de.GeneratorDataset(MyIterable(), ["col1"])
  1602. >>> # 4) Random accessible dataset as Random accessible input
  1603. >>> class MyRA():
  1604. >>> def __getitem__(self, index):
  1605. >>> return # User implementation
  1606. >>> # create ra_generator_dataset with MyRA object
  1607. >>> ra_generator_dataset = de.GeneratorDataset(MyRA(), ["col1"])
  1608. >>> # List/Dict/Tuple is also random accessible
>>> list_generator = de.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col1"])
  1610. >>> # 5) Built-in Sampler
  1611. >>> my_generator = de.GeneratorDataset(my_ds, ["img", "label"], sampler=samplers.RandomSampler())
  1612. >>>
  1613. """
  1614. @check_generatordataset
  1615. def __init__(self, source, column_names, column_types=None, schema=None, num_samples=None, num_parallel_workers=1,
  1616. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  1617. super().__init__(num_parallel_workers)
  1618. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
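# Random-accessible (mappable) source used together with a sampler: rows are read
# through the sampler's indices, either via a C++ sampler instance or a Python
# sampler wrapper.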
  1619. if self.sampler is not None and hasattr(source, "__getitem__"):
  1620. if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  1621. samplers.RandomSampler, samplers.SubsetRandomSampler,
  1622. samplers.WeightedRandomSampler)):
  1623. if num_samples is None:
  1624. num_samples = len(source)
  1625. sampler_instance = self.sampler.create()
  1626. sampler_instance.set_num_rows(len(source))
  1627. sampler_instance.set_num_samples(num_samples)
  1628. sampler_instance.initialize()
  1629. self.source = (lambda: _cpp_sampler_fn(sampler_instance, source))
  1630. else:
  1631. self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source))
  1632. else:
  1633. try:
  1634. iter(source)
  1635. except TypeError:
# Use generator function if input is callable
  1637. self.source = (lambda: _generator_fn(source, num_samples))
  1638. else:
  1639. # Use iterator function if input is iterable
  1640. # Random accessible input is also iterable
  1641. self.source = (lambda: _iter_fn(source, num_samples))
  1642. self.column_names = column_names
  1643. if column_types is not None:
  1644. self.column_types = mstypelist_to_detypelist(column_types)
  1645. else:
  1646. self.column_types = column_types
  1647. def get_args(self):
  1648. args = super().get_args()
  1649. args["source"] = self.source
  1650. args["column_names"] = self.column_names
  1651. args["column_types"] = self.column_types
  1652. return args
  1653. def get_dataset_size(self):
  1654. """
  1655. Get the number of batches in an epoch.
  1656. Return:
  1657. Number, number of batches.
  1658. """
  1659. return self._dataset_size
  1660. # manually set dataset_size as a temporary solution.
  1661. def set_dataset_size(self, value):
  1662. if value >= 0:
  1663. self._dataset_size = value
  1664. else:
  1665. raise ValueError('set dataset_size with negative value {}'.format(value))
  1666. class TFRecordDataset(SourceDataset):
  1667. """
  1668. A source dataset that reads and parses datasets stored on disk in TFData format.
  1669. Args:
  1670. dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
  1671. files. The list will be sorted in a lexicographical order.
  1672. schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
  1673. If the schema is not provided, the meta data from the TFData file is considered the schema.
  1674. columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
  1675. num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
  1676. num_parallel_workers (int, optional): number of workers to read the data
  1677. (default=None, number set in the config).
  1678. shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
  1679. If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL;
otherwise, there are two levels of shuffling:
  1682. - Shuffle.GLOBAL: Shuffle both the files and samples.
  1683. - Shuffle.FILES: Shuffle files only.
  1684. num_shards (int, optional): Number of shards that the dataset should be divided
  1685. into (default=None).
  1686. shard_id (int, optional): The shard ID within num_shards (default=None). This
  1687. argument should be specified only when num_shards is also specified.
shard_equal_rows (bool, optional): Get equal rows for all shards (default=False). If shard_equal_rows is False, the
number of rows in each shard may not be equal.
  1690. Examples:
  1691. >>> import mindspore.dataset as ds
  1692. >>> import mindspore.common.dtype as mstype
  1693. >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple tf data files
  1694. >>> # 1) get all rows from dataset_files with no explicit schema:
  1695. >>> # The meta-data in the first row will be used as a schema.
  1696. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files)
  1697. >>> # 2) get all rows from dataset_files with user-defined schema:
  1698. >>> schema = ds.Schema()
>>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
  1700. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema=schema)
  1701. >>> # 3) get all rows from dataset_files with schema file "./schema.json":
  1702. >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
  1703. """
  1704. @staticmethod
  1705. def _find_files(patterns):
  1706. """
  1707. Utility function to search for files with the given glob patterns.
  1708. Args:
  1709. patterns (str or list[str]): string or list of patterns to be searched.
  1710. Returns:
  1711. List, files.
  1712. """
  1713. def flat(lists):
  1714. return list(np.array(lists).flatten())
  1715. if not isinstance(patterns, list):
  1716. patterns = [patterns]
  1717. file_list = flat([glob.glob(file, recursive=True) for file in patterns])
  1718. if file_list: # not empty
  1719. return file_list
  1720. raise ValueError("The list of path names matching the patterns is empty.")
  1721. @check_tfrecorddataset
  1722. def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None,
  1723. shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False):
  1724. super().__init__(num_parallel_workers)
  1725. self.dataset_files = self._find_files(dataset_files)
  1726. self.dataset_files.sort()
  1727. self.num_shards = num_shards
  1728. self.shard_id = shard_id
  1729. schema_obj = None
  1730. if (schema is not None) and (not isinstance(schema, Schema)):
  1731. schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it
  1732. self.schema = schema
  1733. self.columns_list = columns_list
  1734. self.num_samples = num_samples
  1735. if schema_obj is not None and num_samples is None:
  1736. self.num_samples = schema_obj.num_rows
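# Normalize 'shuffle' into (shuffle_level, shuffle_files): a plain True maps to
# global shuffling, False disables both, and a Shuffle enum value is used as given.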
  1737. if not isinstance(shuffle, (bool, Shuffle)):
  1738. raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
  1739. if not isinstance(shuffle, Shuffle):
  1740. if shuffle:
  1741. self.shuffle_level = Shuffle.GLOBAL
  1742. self.shuffle_files = True
  1743. else:
  1744. self.shuffle_level = None
  1745. self.shuffle_files = False
  1746. else:
  1747. self.shuffle_level = shuffle
  1748. self.shuffle_files = True
  1749. self.shard_equal_rows = shard_equal_rows
  1750. def get_args(self):
  1751. args = super().get_args()
  1752. args["dataset_files"] = self.dataset_files
  1753. if self.schema is not None:
  1754. if isinstance(self.schema, Schema):
  1755. self.schema.datasetType = 'TF'
  1756. if self.num_samples is not None:
  1757. self.schema.num_rows = self.num_samples
  1758. args["schema_json_string"] = self.schema.to_json()
  1759. else:
  1760. args["schema_file_path"] = self.schema
  1761. args["schema"] = self.schema
  1762. args["columns_list"] = self.columns_list
  1763. args["num_samples"] = self.num_samples
  1764. if self.shuffle_files is not None:
  1765. args["shuffle_files"] = self.shuffle_files
  1766. args["shuffle"] = self.shuffle_level
  1767. args["num_shards"] = self.num_shards
  1768. args["shard_id"] = self.shard_id
  1769. args["shard_equal_rows"] = self.shard_equal_rows
  1770. return args
  1771. def get_dataset_size(self, estimate=False):
  1772. """
  1773. Get the number of batches in an epoch.
  1774. Args:
  1775. estimate (bool, optional): Fast estimation of the dataset size instead of a full scan.
  1776. Return:
  1777. Number, number of batches.
  1778. """
  1779. if self._dataset_size is None:
  1780. num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
  1781. num_rows = get_num_rows(num_rows, self.num_shards)
  1782. if self.num_samples is None:
  1783. return num_rows
  1784. return min(self.num_samples, num_rows)
  1785. return self._dataset_size
# manually set dataset_size as a temporary solution.
  1787. def set_dataset_size(self, value):
  1788. logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
  1789. if value >= 0:
  1790. self._dataset_size = value
  1791. else:
  1792. raise ValueError('set dataset_size with negative value {}'.format(value))
  1793. class ManifestDataset(SourceDataset):
  1794. """
  1795. A source dataset that reads images from a manifest file.
  1796. The generated dataset has two columns ['image', 'label'].
  1797. The shape of the image column is [image_size] if decode flag is False, or [H,W,C]
  1798. otherwise.
  1799. The type of the image tensor is uint8. The label is just a scalar uint64
  1800. tensor.
  1801. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  1802. below shows what input args are allowed and their expected behavior.
  1803. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  1804. :widths: 25 25 50
  1805. :header-rows: 1
  1806. * - Parameter 'sampler'
  1807. - Parameter 'shuffle'
  1808. - Expected Order Behavior
  1809. * - None
  1810. - None
  1811. - random order
  1812. * - None
  1813. - True
  1814. - random order
  1815. * - None
  1816. - False
  1817. - sequential order
  1818. * - Sampler object
  1819. - None
  1820. - order defined by sampler
  1821. * - Sampler object
  1822. - True
  1823. - not allowed
  1824. * - Sampler object
  1825. - False
  1826. - not allowed
  1827. Args:
  1828. dataset_file (str): File to be read.
usage (str, optional): Whether "train", "eval" or "inference" data should be used (default="train").
  1830. num_samples (int, optional): The number of images to be included in the dataset.
  1831. (default=None, all images).
  1832. num_parallel_workers (int, optional): Number of workers to read the data
  1833. (default=None, number set in the config).
  1834. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  1835. order behavior shown in the table).
  1836. sampler (Sampler, optional): Object used to choose samples from the
  1837. dataset (default=None, expected order behavior shown in the table).
  1838. class_indexing (dict, optional): A str-to-int mapping from label name to index
  1839. (default=None, the folder names will be sorted alphabetically and each
  1840. class will be given a unique index starting from 0).
decode (bool, optional): decode the images after reading (default=False).
  1842. num_shards (int, optional): Number of shards that the dataset should be divided
  1843. into (default=None).
  1844. shard_id (int, optional): The shard ID within num_shards (default=None). This
  1845. argument should be specified only when num_shards is also specified.
  1846. Raises:
  1847. RuntimeError: If sampler and shuffle are specified at the same time.
  1848. RuntimeError: If sampler and sharding are specified at the same time.
  1849. RuntimeError: If num_shards is specified but shard_id is None.
  1850. RuntimeError: If shard_id is specified but num_shards is None.
  1851. RuntimeError: If class_indexing is not a dictionary.
  1852. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  1853. Examples:
  1854. >>> import mindspore.dataset as ds
  1855. >>> dataset_file = "/path/to/manifest_file.manifest"
  1856. >>> # 1) read all samples specified in manifest_file dataset with 8 threads for training:
  1857. >>> manifest_dataset = ds.ManifestDataset(dataset_file, usage="train", num_parallel_workers=8)
  1858. >>> # 2) reads samples (specified in manifest_file.manifest) for shard 0 in a 2-way distributed training setup:
  1859. >>> manifest_dataset = ds.ManifestDataset(dataset_file, num_shards=2, shard_id=0)
  1860. """
  1861. @check_manifestdataset
  1862. def __init__(self, dataset_file, usage="train", num_samples=None, num_parallel_workers=None,
  1863. shuffle=None, sampler=None, class_indexing=None, decode=False, num_shards=None, shard_id=None):
  1864. super().__init__(num_parallel_workers)
  1865. self.dataset_file = dataset_file
  1866. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  1867. if class_indexing is not None and not isinstance(class_indexing, dict):
  1868. raise RuntimeError("class_indexing should be a dictionary.")
  1869. self.num_samples = num_samples
  1870. self.class_indexing = class_indexing
  1871. self.decode = decode
  1872. self.usage = usage
  1873. self.shuffle_level = shuffle
  1874. self.num_shards = num_shards
  1875. self.shard_id = shard_id
  1876. def get_args(self):
  1877. args = super().get_args()
  1878. args["dataset_file"] = self.dataset_file
  1879. args["usage"] = self.usage
  1880. args["num_samples"] = self.num_samples
  1881. args["shuffle"] = self.shuffle_level
  1882. args["sampler"] = self.sampler
  1883. args["class_indexing"] = self.class_indexing
  1884. args["decode"] = self.decode
  1885. args["num_shards"] = self.num_shards
  1886. args["shard_id"] = self.shard_id
  1887. return args
  1888. def get_dataset_size(self):
  1889. """
  1890. Get the number of batches in an epoch.
  1891. Return:
  1892. Number, number of batches.
  1893. """
  1894. if self.num_samples is None:
  1895. num_samples = 0
  1896. else:
  1897. num_samples = self.num_samples
  1898. if self.class_indexing is None:
  1899. class_indexing = dict()
  1900. else:
  1901. class_indexing = self.class_indexing
  1902. num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
  1903. return get_num_rows(num_rows, self.num_shards)
  1904. def num_classes(self):
  1905. """
  1906. Get the number of classes in a dataset.
  1907. Return:
  1908. Number, number of classes.
  1909. """
  1910. if self.num_samples is None:
  1911. num_samples = 0
  1912. else:
  1913. num_samples = self.num_samples
  1914. if self.class_indexing is None:
  1915. class_indexing = dict()
  1916. else:
  1917. class_indexing = self.class_indexing
  1918. return ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[1]
  1919. def get_class_indexing(self):
  1920. """
Get the class indexing.
  1922. Return:
  1923. Dict, A str-to-int mapping from label name to index.
  1924. """
  1925. if self.num_samples is None:
  1926. num_samples = 0
  1927. else:
  1928. num_samples = self.num_samples
  1929. if self.class_indexing is None:
  1930. class_indexing = dict()
  1931. else:
  1932. class_indexing = self.class_indexing
  1933. return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage)
  1934. class Cifar10Dataset(SourceDataset):
  1935. """
  1936. A source dataset that reads cifar10 data.
  1937. The generated dataset has two columns ['image', 'label'].
  1938. The type of the image tensor is uint8. The label is just a scalar uint32
  1939. tensor.
  1940. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  1941. below shows what input args are allowed and their expected behavior.
  1942. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  1943. :widths: 25 25 50
  1944. :header-rows: 1
  1945. * - Parameter 'sampler'
  1946. - Parameter 'shuffle'
  1947. - Expected Order Behavior
  1948. * - None
  1949. - None
  1950. - random order
  1951. * - None
  1952. - True
  1953. - random order
  1954. * - None
  1955. - False
  1956. - sequential order
  1957. * - Sampler object
  1958. - None
  1959. - order defined by sampler
  1960. * - Sampler object
  1961. - True
  1962. - not allowed
  1963. * - Sampler object
  1964. - False
  1965. - not allowed
  1966. Args:
  1967. dataset_dir (str): Path to the root directory that contains the dataset.
  1968. num_samples (int, optional): The number of images to be included in the dataset.
  1969. (default=None, all images).
  1970. num_parallel_workers (int, optional): Number of workers to read the data
  1971. (default=None, number set in the config).
  1972. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  1973. order behavior shown in the table).
  1974. sampler (Sampler, optional): Object used to choose samples from the
  1975. dataset (default=None, expected order behavior shown in the table).
  1976. num_shards (int, optional): Number of shards that the dataset should be divided
  1977. into (default=None).
  1978. shard_id (int, optional): The shard ID within num_shards (default=None). This
  1979. argument should be specified only when num_shards is also specified.
  1980. Raises:
  1981. RuntimeError: If sampler and shuffle are specified at the same time.
  1982. RuntimeError: If sampler and sharding are specified at the same time.
  1983. RuntimeError: If num_shards is specified but shard_id is None.
  1984. RuntimeError: If shard_id is specified but num_shards is None.
  1985. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  1986. Examples:
  1987. >>> import mindspore.dataset as ds
  1988. >>> dataset_dir = "/path/to/cifar10_dataset_directory"
  1989. >>> # 1) get all samples from CIFAR10 dataset in sequence:
  1990. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir,shuffle=False)
  1991. >>> # 2) randomly select 350 samples from CIFAR10 dataset:
  1992. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir,num_samples=350, shuffle=True)
  1993. >>> # 3) get samples from CIFAR10 dataset for shard 0 in a 2 way distributed training:
  1994. >>> dataset = ds.Cifar10Dataset(dataset_dir=dataset_dir,num_shards=2,shard_id=0)
  1995. >>> # in CIFAR10 dataset, each dictionary has keys "image" and "label"
  1996. """
  1997. @check_mnist_cifar_dataset
  1998. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  1999. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  2000. super().__init__(num_parallel_workers)
  2001. self.dataset_dir = dataset_dir
  2002. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2003. self.num_samples = num_samples
  2004. self.num_shards = num_shards
  2005. self.shard_id = shard_id
  2006. self.shuffle_level = shuffle
  2007. def get_args(self):
  2008. args = super().get_args()
  2009. args["dataset_dir"] = self.dataset_dir
  2010. args["num_samples"] = self.num_samples
  2011. args["sampler"] = self.sampler
  2012. args["num_shards"] = self.num_shards
  2013. args["shard_id"] = self.shard_id
  2014. args["shuffle"] = self.shuffle_level
  2015. return args
  2016. def get_dataset_size(self):
  2017. """
  2018. Get the number of batches in an epoch.
  2019. Return:
  2020. Number, number of batches.
  2021. """
  2022. if self.num_samples is None:
  2023. num_samples = 0
  2024. else:
  2025. num_samples = self.num_samples
  2026. num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
  2027. return get_num_rows(num_rows, self.num_shards)
  2028. class Cifar100Dataset(SourceDataset):
  2029. """
  2030. A source dataset that reads cifar100 data.
  2031. The generated dataset has three columns ['image', 'coarse_label', 'fine_label'].
The type of the image tensor is uint8. The coarse_label and fine_label are each a scalar uint32
tensor.
  2034. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2035. below shows what input args are allowed and their expected behavior.
  2036. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2037. :widths: 25 25 50
  2038. :header-rows: 1
  2039. * - Parameter 'sampler'
  2040. - Parameter 'shuffle'
  2041. - Expected Order Behavior
  2042. * - None
  2043. - None
  2044. - random order
  2045. * - None
  2046. - True
  2047. - random order
  2048. * - None
  2049. - False
  2050. - sequential order
  2051. * - Sampler object
  2052. - None
  2053. - order defined by sampler
  2054. * - Sampler object
  2055. - True
  2056. - not allowed
  2057. * - Sampler object
  2058. - False
  2059. - not allowed
  2060. Args:
  2061. dataset_dir (str): Path to the root directory that contains the dataset.
  2062. num_samples (int, optional): The number of images to be included in the dataset.
  2063. (default=None, all images).
  2064. num_parallel_workers (int, optional): Number of workers to read the data
  2065. (default=None, number set in the config).
  2066. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  2067. order behavior shown in the table).
  2068. sampler (Sampler, optional): Object used to choose samples from the
  2069. dataset (default=None, expected order behavior shown in the table).
  2070. num_shards (int, optional): Number of shards that the dataset should be divided
  2071. into (default=None).
  2072. shard_id (int, optional): The shard ID within num_shards (default=None). This
  2073. argument should be specified only when num_shards is also specified.
  2074. Raises:
  2075. RuntimeError: If sampler and shuffle are specified at the same time.
  2076. RuntimeError: If sampler and sharding are specified at the same time.
  2077. RuntimeError: If num_shards is specified but shard_id is None.
  2078. RuntimeError: If shard_id is specified but num_shards is None.
  2079. ValueError: If shard_id is invalid (< 0 or >= num_shards).
  2080. Examples:
  2081. >>> import mindspore.dataset as ds
  2082. >>> dataset_dir = "/path/to/cifar100_dataset_directory"
  2083. >>> # 1) get all samples from CIFAR100 dataset in sequence:
  2084. >>> cifar100_dataset = ds.Cifar100Dataset(dataset_dir=dataset_dir,shuffle=False)
  2085. >>> # 2) randomly select 350 samples from CIFAR100 dataset:
  2086. >>> cifar100_dataset = ds.Cifar100Dataset(dataset_dir=dataset_dir,num_samples=350, shuffle=True)
  2087. >>> # in CIFAR100 dataset, each dictionary has 3 keys: "image", "fine_label" and "coarse_label"
  2088. """
  2089. @check_mnist_cifar_dataset
  2090. def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
  2091. shuffle=None, sampler=None, num_shards=None, shard_id=None):
  2092. super().__init__(num_parallel_workers)
  2093. self.dataset_dir = dataset_dir
  2094. self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
  2095. self.num_samples = num_samples
  2096. self.num_shards = num_shards
  2097. self.shard_id = shard_id
  2098. self.shuffle_level = shuffle
  2099. def get_args(self):
  2100. args = super().get_args()
  2101. args["dataset_dir"] = self.dataset_dir
  2102. args["num_samples"] = self.num_samples
  2103. args["sampler"] = self.sampler
  2104. args["num_shards"] = self.num_shards
  2105. args["shard_id"] = self.shard_id
  2106. args["shuffle"] = self.shuffle_level
  2107. return args
  2108. def get_dataset_size(self):
  2109. """
  2110. Get the number of batches in an epoch.
  2111. Return:
  2112. Number, number of batches.
  2113. """
  2114. if self.num_samples is None:
  2115. num_samples = 0
  2116. else:
  2117. num_samples = self.num_samples
  2118. num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
  2119. return get_num_rows(num_rows, self.num_shards)
  2120. class Schema:
  2121. """
Class to represent the schema of a dataset.
  2123. Args:
  2124. schema_file(str): Path of schema file (default=None).
  2125. Return:
  2126. Schema object, schema info about dataset.
  2127. Raises:
  2128. RuntimeError: If schema file failed to load.
  2129. Example:
  2130. >>> import mindspore.dataset as ds
  2131. >>> import mindspore.common.dtype as mstype
  2132. >>> # create schema, specify column name, mindspore.dtype and shape of the column
  2133. >>> schema = ds.Schema()
>>> schema.add_column('col1', de_type=mstype.int64, shape=[2])
  2135. """
  2136. def __init__(self, schema_file=None):
  2137. if schema_file is None:
  2138. self.columns = []
  2139. self.dataset_type = ''
  2140. self.num_rows = 0
  2141. else:
  2142. if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
  2143. raise ValueError("The file %s does not exist or permission denied!" % schema_file)
  2144. try:
  2145. with open(schema_file, 'r') as load_f:
  2146. json_obj = json.load(load_f)
  2147. except json.decoder.JSONDecodeError:
  2148. raise RuntimeError("Schema file failed to load.")
  2149. except UnicodeDecodeError:
  2150. raise RuntimeError("Schema file failed to decode.")
  2151. except Exception:
  2152. raise RuntimeError("Schema file failed to open.")
  2153. self.from_json(json_obj)
  2154. @check_add_column
  2155. def add_column(self, name, de_type, shape=None):
  2156. """
  2157. Add new column to the schema.
  2158. Args:
  2159. name (str): name of the column.
  2160. de_type (str): data type of the column.
  2161. shape (list[int], optional): shape of the column
  2162. (default=None, [-1] which is an unknown shape of rank 1).
  2163. Raises:
  2164. ValueError: If column type is unknown.
  2165. """
  2166. new_column = dict()
  2167. new_column["name"] = name
  2168. if isinstance(de_type, typing.Type):
  2169. de_type = mstype_to_detype(de_type)
  2170. new_column["type"] = str(de_type)
  2171. else:
  2172. new_column["type"] = str(DataType(de_type))
  2173. if shape is not None:
  2174. new_column["shape"] = shape
  2175. new_column["rank"] = len(shape)
  2176. else:
  2177. new_column["rank"] = 1
  2178. self.columns.append(new_column)
  2179. def to_json(self):
  2180. """
  2181. Get a JSON string of the schema.
  2182. Returns:
  2183. Str, JSON string of the schema.
  2184. """
  2185. json_file = dict()
  2186. json_file["columns"] = self.columns
  2187. if self.dataset_type:
  2188. json_file["datasetType"] = self.dataset_type
  2189. if self.num_rows:
  2190. json_file["numRows"] = self.num_rows
  2191. return json.dumps(json_file, indent=2)
  2192. def parse_columns(self, columns):
  2193. """
Parse the columns and add them to self.
  2195. Args:
columns (dict or list[dict]): column definitions parsed from the schema file.
  2197. Raises:
  2198. RuntimeError: If failed to parse schema file.
  2199. RuntimeError: If unknown items in schema file.
  2200. RuntimeError: If column's name field is missing.
  2201. RuntimeError: If column's type field is missing.
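Example:
>>> # a minimal sketch of the dict form handled below; the type string and shape
>>> # are illustrative values in the schema-file format
>>> schema = Schema()
>>> schema.parse_columns({"col1": {"type": "int64", "shape": [2]}})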
  2202. """
  2203. if columns is None:
  2204. raise TypeError("Expected non-empty dict or string list.")
  2205. self.columns = []
  2206. for col in columns:
  2207. name = None
  2208. shape = None
  2209. data_type = None
  2210. col_details = None
  2211. if isinstance(columns, list):
  2212. col_details = col
  2213. if "name" in col:
  2214. name = col["name"]
  2215. elif isinstance(columns, dict):
  2216. col_details = columns[col]
  2217. name = col
  2218. else:
  2219. raise RuntimeError("Error parsing the schema file")
  2220. for k, v in col_details.items():
  2221. if k == "shape":
  2222. shape = v
  2223. elif k == "type":
  2224. data_type = v
  2225. elif k in ("t_impl", "rank"):
  2226. pass
  2227. else:
  2228. raise RuntimeError("Unknown field %s" % k)
  2229. if name is None:
  2230. raise RuntimeError("Column's name field is missing.")
  2231. if data_type is None:
  2232. raise RuntimeError("Column's type field is missing.")
  2233. self.add_column(name, data_type, shape)
  2234. def from_json(self, json_obj):
  2235. """
Get the schema from a parsed JSON object.
Args:
json_obj (dict): parsed JSON object.
  2239. Raises:
  2240. RuntimeError: if there is unknown item in the object.
  2241. RuntimeError: if dataset type is missing in the object.
  2242. RuntimeError: if columns are missing in the object.
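Example:
>>> # a sketch of the expected object layout, using the field names handled below
>>> schema = Schema()
>>> schema.from_json({"datasetType": "TF", "numRows": 3,
...                   "columns": {"col1": {"type": "int64", "shape": [2]}}})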
  2243. """
  2244. if not isinstance(json_obj, dict) or json_obj is None:
  2245. raise ValueError("Expected non-empty dict.")
  2246. for k, v in json_obj.items():
  2247. if k == "datasetType":
  2248. self.dataset_type = v
  2249. elif k == "numRows":
  2250. self.num_rows = v
  2251. elif k == "columns":
  2252. self.parse_columns(v)
  2253. else:
  2254. raise RuntimeError("Unknown field %s" % k)
  2255. if self.dataset_type is None:
  2256. raise RuntimeError("DatasetType field is missing.")
  2257. if self.columns is None:
  2258. raise RuntimeError("Columns are missing.")
  2259. def __str__(self):
  2260. return self.to_json()
  2261. class VOCDataset(SourceDataset):
  2262. """
  2263. A source dataset for reading and parsing VOC dataset.
  2264. The generated dataset has two columns ['image', 'target'].
The shape of both columns is [image_size] if decode flag is False, or [H, W, C]
otherwise.
The type of both tensors is uint8.
  2268. This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
  2269. below shows what input args are allowed and their expected behavior.
  2270. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
  2271. :widths: 25 25 50
  2272. :header-rows: 1
  2273. * - Parameter 'sampler'
  2274. - Parameter 'shuffle'
  2275. - Expected Order Behavior
  2276. * - None
  2277. - None
  2278. - random order
  2279. * - None
  2280. - True
  2281. - random order
  2282. * - None
  2283. - False
  2284. - sequential order
  2285. * - Sampler object
  2286. - None
  2287. - order defined by sampler
  2288. * - Sampler object
  2289. - True
  2290. - not allowed
  2291. * - Sampler object
  2292. - False
  2293. - not allowed
  2294. Args:
  2295. dataset_dir (str): Path to the root directory that contains the dataset.
  2296. num_samples (int, optional): The number of images to be included in the dataset
  2297. (default=None, all images).
  2298. num_parallel_workers (int, optional): Number of workers to read the data
  2299. (default=None, number set in the config).
  2300. shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
  2301. order behavior shown in the table).
  2302. decode (bool, optional): Decode the images after reading (default=False).
  2303. sampler (Sampler, optional): Object used to choose samples from the dataset
  2304. (default=None, expected order behavior shown in the table).
  2305. distribution (str, optional): Path to the json distribution file to configure
  2306. dataset sharding (default=None). This argument should be specified
  2307. only when no 'sampler' is used.
  2308. Raises:
  2309. RuntimeError: If distribution and sampler are specified at the same time.
  2310. RuntimeError: If distribution is failed to read.
  2311. RuntimeError: If shuffle and sampler are specified at the same time.
  2312. Examples:
  2313. >>> import mindspore.dataset as ds
  2314. >>> dataset_dir = "/path/to/voc_dataset_directory"
  2315. >>> # 1) read all VOC dataset samples in dataset_dir with 8 threads in random order:
  2316. >>> voc_dataset = ds.VOCDataset(dataset_dir, num_parallel_workers=8)
  2317. >>> # 2) read then decode all VOC dataset samples in dataset_dir in sequence:
  2318. >>> voc_dataset = ds.VOCDataset(dataset_dir, decode=True, shuffle=False)
  2319. >>> # in VOC dataset, each dictionary has keys "image" and "target"
  2320. """

    @check_vocdataset
    def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
                 shuffle=None, decode=False, sampler=None, distribution=None):
        super().__init__(num_parallel_workers)
        self.dataset_dir = dataset_dir
        self.sampler = sampler
        if distribution is not None:
            if sampler is not None:
                raise RuntimeError("Cannot specify distribution and sampler at the same time.")
            try:
                with open(distribution, 'r') as load_d:
                    json.load(load_d)
            except json.decoder.JSONDecodeError:
                raise RuntimeError("JSON decode error when loading the distribution file.")
            except Exception:
                raise RuntimeError("Distribution file has failed to load.")
        elif shuffle is not None:
            if sampler is not None:
                raise RuntimeError("Cannot specify shuffle and sampler at the same time.")
        self.num_samples = num_samples
        self.decode = decode
        self.distribution = distribution
        self.shuffle_level = shuffle

    def get_args(self):
        args = super().get_args()
        args["dataset_dir"] = self.dataset_dir
        args["num_samples"] = self.num_samples
        args["sampler"] = self.sampler
        args["decode"] = self.decode
        args["shuffle"] = self.shuffle_level
        args["distribution"] = self.distribution
        return args

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Return:
            Number, number of batches.
        """
        return self.num_samples
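
# Illustrative sketch (not part of the original source): iterating over a VOCDataset
# to access its 'image' and 'target' columns. The path is hypothetical and
# create_dict_iterator() is assumed to be available on the dataset object.
#
#     import mindspore.dataset as ds
#     voc_dataset = ds.VOCDataset("/path/to/voc_dataset_directory", decode=True, shuffle=False)
#     for item in voc_dataset.create_dict_iterator():
#         image, target = item["image"], item["target"]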


class CelebADataset(SourceDataset):
    """
    A source dataset for reading and parsing the CelebA dataset. Currently only list_attr_celeba.txt is supported.

    Note:
        The generated dataset has two columns ['image', 'attr'].
        The type of the image tensor is uint8. The attr tensor is uint32 and of one-hot type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config).
        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None).
        dataset_type (str): One of 'all', 'train', 'valid' or 'test'.
        sampler (Sampler, optional): Object used to choose samples from the dataset (default=None).
        decode (bool, optional): Decode the images after reading (default=False).
        extensions (list[str], optional): List of file extensions to be
            included in the dataset (default=None).
        num_samples (int, optional): The number of images to be included in the dataset
            (default=None, all images).
        num_shards (int, optional): Number of shards that the dataset should be divided
            into (default=None).
        shard_id (int, optional): The shard ID within num_shards (default=None). This
            argument should be specified only when num_shards is also specified.
    """

    @check_celebadataset
    def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, dataset_type='all',
                 sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None):
        super().__init__(num_parallel_workers)
        self.dataset_dir = dataset_dir
        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
        self.num_parallel_workers = num_parallel_workers
        self.decode = decode
        self.extensions = extensions
        self.num_samples = num_samples
        self.dataset_type = dataset_type
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.shuffle_level = shuffle

    def get_args(self):
        args = super().get_args()
        args["dataset_dir"] = self.dataset_dir
        args["sampler"] = self.sampler
        args["shuffle"] = self.shuffle_level
        args["decode"] = self.decode
        args["extensions"] = self.extensions
        args["num_samples"] = self.num_samples
        args["dataset_type"] = self.dataset_type
        args["num_shards"] = self.num_shards
        args["shard_id"] = self.shard_id
        return args
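
    # Illustrative sketch (not part of the original source): constructing a CelebADataset
    # and reading one sample. The path is hypothetical and create_dict_iterator() is
    # assumed to be available on the dataset object.
    #
    #     import mindspore.dataset as ds
    #     celeba_dataset = ds.CelebADataset("/path/to/celeba_dataset_directory",
    #                                       dataset_type='train', decode=True)
    #     for item in celeba_dataset.create_dict_iterator():
    #         image, attr = item["image"], item["attr"]
    #         break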