You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.Dataset.rst 47 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960
  1. .. py:method:: apply(apply_func)
  2. 对数据集对象执行给定操作函数。
  3. **参数:**
  4. `apply_func` (function):数据集处理函数,要求该函数的输入是一个 `Dataset` 对象,返回的是处理后的 `Dataset` 对象。
  5. **返回:**
  6. 执行了给定操作函数的数据集对象。
  7. **样例:**
  8. >>> # dataset是任意数据集对象的实例
  9. >>>
  10. >>> # 声明一个名为apply_func函数,对数据集对象执行batch操作,并返回值处理后的Dataset对象
  11. >>> # 注意apply_func函数的输入参数data需要为 `Dataset` 对象
  12. >>> def apply_func(data):
  13. ... data = data.batch(2)
  14. ... return data
  15. >>>
  16. >>> # 通过apply操作调用apply_func函数,得到处理后的数据集对象
  17. >>> dataset = dataset.apply(apply_func)
  18. **异常:**
  19. - **TypeError:** `apply_func` 的类型不是函数。
  20. - **TypeError:** `apply_func` 未返回Dataset对象。
  21. .. py:method:: batch(batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False)
  22. 将数据集中连续 `batch_size` 条数据合并为一个批处理数据。
  23. `batch` 操作要求每列中的数据具有相同的shape。如果指定了参数 `per_batch_map` ,该参数将作用于批处理后的数据。
  24. .. note::
  25. 执行 `repeat` 和 `batch` 操作的先后顺序,会影响批处理数据的数量及 `per_batch_map` 的结果。建议在 `batch` 操作完成后执行 `repeat` 操作。
  26. **参数:**
  27. - **batch_size** (Union[int, Callable]) - 指定每个批处理数据包含的数据条目。
  28. 如果 `batch_size` 为整型,则直接表示每个批处理数据大小;
  29. 如果为可调用对象,则可以通过自定义行为动态指定每个批处理数据大小,要求该可调用对象接收一个参数BatchInfo,返回一个整形代表批处理大小,用法请参考样例(3)。
  30. - **drop_remainder** (bool, 可选) - 当最后一个批处理数据包含的数据条目小于 `batch_size` 时,是否将该批处理丢弃,不传递给下一个操作。默认值:False,不丢弃。
  31. - **num_parallel_workers** (int, 可选) - 指定 `batch `操作的并发进程数/线程数(由参数 `python_multiprocessing` 决定当前为多进程模式或多线程模式)。
  32. 默认值:None,使用mindspore.dataset.config中配置的线程数。
  33. - **per_batch_map** (callable, 可选) - 可调用对象,以(list[Tensor], list[Tensor], ..., BatchInfo)作为输入参数,处理后返回(list[Tensor], list[Tensor],...)作为新的数据列。
  34. 输入参数中每个list[Tensor]代表给定数据列中的一批Tensor,list[Tensor]的个数应与 `input_columns` 中传入列名的数量相匹配,
  35. 在返回的(list[Tensor], list[Tensor], ...)中,list[Tensor]的个数应与输入相同,如果输出列数与输入列数不一致,则需要指定 `output_columns`。
  36. 该可调用对象的最后一个输入参数始终是BatchInfo,用于获取数据集的信息,用法参考样例(2)。
  37. - **input_columns** (Union[str, list[str]], 可选):指定 `batch` 操作的输入数据列。
  38. 如果 `per_batch_map` 不为None,列表中列名的个数应与 `per_batch_map` 中包含的列数匹配。默认值:None,不指定。
  39. - **output_columns** (Union[str, list[str]], 可选) - 指定 `batch` 操作的输出数据列。如果输入数据列与输入数据列的长度不相等,则必须指定此参数。
  40. 此列表中列名的数量必须与 `per_batch_map` 方法的返回值数量相匹配。默认值:None,输出列将与输入列具有相同的名称。
  41. - **column_order** (Union[str, list[str]], 可选) - 指定传递到下一个数据集操作的数据列顺序。
  42. 如果 `input_column` 长度不等于 `output_column` 长度,则此参数必须指定。
  43. 注意:列名不限定在 `input_columns` 和 `output_columns` 中指定的列,也可以是上一个操作输出的未被处理的数据列,详细可参阅使用样例(4)。默认值:None,按照原输入顺序排列。
  44. - **pad_info** (dict, 可选) - 对给定数据列进行填充。通过传入dict来指定列信息与填充信息,例如 `pad_info={"col1":([224,224],0)}` ,
  45. 则将列名为"col1"的数据列扩充到shape为[224,224]的Tensor,缺失的值使用0填充。默认值:None,不填充。
  46. - **python_multiprocessing** (bool, 可选) - 启动Python多进程模式并行执行 `per_batch_map` 。如果 `per_batch_map` 的计算量很大,此选项可能会很有用。默认值:False,不启用多进程。
  47. **返回:**
  48. Dataset, `batch` 操作后的数据集对象。
  49. **样例:**
  50. >>> # 1)创建一个数据集对象,每100条数据合并成一个批处理数据
  51. >>> # 如果最后一个批次数据小于给定的批次大小(batch_size),则丢弃这个批次
  52. >>> dataset = dataset.batch(100, True)
  53. >>>
  54. >>> # 2)根据批次编号调整图像大小,如果是第5批,则图像大小调整为(5^2, 5^2) = (25, 25)
  55. >>> def np_resize(col, BatchInfo):
  56. ... output = col.copy()
  57. ... s = (BatchInfo.get_batch_num() + 1) ** 2
  58. ... index = 0
  59. ... for c in col:
  60. ... img = Image.fromarray(c.astype('uint8')).convert('RGB')
  61. ... img = img.resize((s, s), Image.ANTIALIAS)
  62. ... output[index] = np.array(img)
  63. ... index += 1
  64. ... return (output,)
  65. >>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
  66. >>>
  67. >>> # 3)创建一个数据集对象,动态指定批处理大小
  68. >>> # 定义一个批处理函数,每次将batch_size加一
  69. >>> def add_one(BatchInfo):
  70. ... return BatchInfo.get_batch_num() + 1
  71. >>> dataset = dataset.batch(batch_size=add_one, drop_remainder=True)
  72. >>>
  73. >>> # 4)创建一个数据集对象,并执行batch数据,指定column_order更换数据列顺序
  74. >>> # 假设数据集对象的数据列原顺序是["image", "label"]
  75. >>> dataset = dataset.batch(32, column_order=["label", "image"])
  76. .. py:method:: bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None, pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False)
  77. 根据数据的长度进行分桶,每个桶将在数据填满的时候进行填充和批处理操作。
  78. 对数据集中的每一条数据进行长度计算,根据该条数据的长度计算结果和每个分桶的范围将该数据归类到特定的桶里面。
  79. 当某个分桶中数据条数达到指定的大小 `bucket_batch_sizes` 时,将根据 `pad_info` 的信息对分桶进行填充,再进行批处理。
  80. **参数:**
  81. - **column_names** (list[str]) - 传递给参数 `element_length_function` 的数据列,用于计算数据的长度。
  82. - **bucket_boundaries** (list[int]) - 指定各个分桶的上边界值,列表的数值必须严格递增。
  83. 如果有n个边界,则会创建n+1个桶,分配后桶的边界如下:[0, bucket_boundaries[0]),[bucket_boundaries[i], bucket_boundaries[i+1]),[bucket_boundaries[n-1], inf),其中,0<i<n-1。
  84. - **bucket_batch_sizes** (list[int]) - 指定每个分桶的批数据大小,必须包含 `len(bucket_boundaries)+1` 个元素。
  85. - **element_length_function** (Callable, 可选) - 长度计算函数。要求接收 `len(column_names)` 个输入参数,并返回一个整数代表该条数据的长度。
  86. 如果未指定该参数,则参数 `column_names` 的长度必须为1,此时该列数据的shape[0]值将被当做数据长度。默认值:None,不指定。
  87. - **pad_info** (dict, 可选) - 对指定数据列进行填充。通过传入dict来指定列信息与填充信息,要求dict的键是要填充的数据列名,dict的值是包含2个元素的元组。
  88. 元组中第1个元素表示要扩展至的目标shape,第2个元素表示要填充的值。
  89. 如果某一个数据列未指定将要填充后的shape和填充值,则该列中的每条数据都将填充至该批次中最长数据的长度,且填充值为0。
  90. 注意,`pad_info` 中任何填充shape为None的列,其每条数据长度都将被填充为当前批处理中最数据的长度,除非指定 `pad_to_bucket_boundary` 为True。默认值:None,不填充。
  91. - **pad_to_bucket_boundary** (bool, 可选) - 如果为True,则 `pad_info` 中填充shape为None的列,会被填充至由参数 `bucket_batch_sizes` 指定的对应分桶长度-1的长度。
  92. 如果有任何数据落入最后一个分桶中,则将报错。默认值:False。
  93. - **drop_remainder** (bool, 可选) - 当每个分桶中的最后一个批处理数据数据条目小于 `bucket_batch_sizes` 时,是否丢弃该批处理数据。默认值:False,不丢弃。
  94. **返回:**
  95. Dataset,按长度进行分桶和批处理操作后的数据集对象。
  96. **样例:**
  97. >>> # 创建一个数据集对象,其中给定条数的数据会被组成一个批次数据
  98. >>> # 如果最后一个批次数据小于给定的批次大小(batch_size),则丢弃这个批次
  99. >>> import numpy as np
  100. >>> def generate_2_columns(n):
  101. ... for i in range(n):
  102. ... yield (np.array([i]), np.array([j for j in range(i + 1)]))
  103. >>>
  104. >>> column_names = ["col1", "col2"]
  105. >>> dataset = ds.GeneratorDataset(generate_2_columns(8), column_names)
  106. >>> bucket_boundaries = [5, 10]
  107. >>> bucket_batch_sizes = [2, 1, 1]
  108. >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
  109. >>> # 将对列名为"col2"的列进行填充,填充后的shape为[bucket_boundaries[i]],其中i是当前正在批处理的桶的索引
  110. >>> pad_info = {"col2": ([None], -1)}
  111. >>> pad_to_bucket_boundary = True
  112. >>> dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
  113. ... bucket_batch_sizes,
  114. ... element_length_function, pad_info,
  115. ... pad_to_bucket_boundary)
  116. .. py:method:: build_sentencepiece_vocab(columns, vocab_size, character_coverage, model_type, params)
  117. 迭代源数据集对象获取数据并构建SentencePiece词汇表。
  118. **参数:**
  119. - **columns** (list[str]) - 指定 `build_sentencepiece_vocab` 操作的输入列,会从该列获取数据构造词汇表。
  120. - **vocab_size** (int) - 词汇表的容量。
  121. - **character_coverage** (float) - 模型涵盖的字符百分比,必须介于0.98和1.0之间。
  122. 对于具有丰富字符集的语言,如日语或中文字符集,推荐使用0.9995;对于其他字符集较小的语言,比如英语或拉丁文,推荐使用1.0。
  123. - **model_type** (SentencePieceModel) - 训练的SentencePiece模型类型,可取值为'SentencePieceModel.UNIGRAM'、'SentencePieceModel.BPE'、'SentencePieceModel.CHAR'或'SentencePieceModel.WORD'。
  124. 当取值为'SentencePieceModel.WORD'时,输入的数据必须进行预分词(pretokenize)。默认值:SentencePieceModel.UNIGRAM。
  125. - **params** (dict) - 如果希望使用SentencePiece的其他参数,可以构造一个dict进行传入,键为SentencePiece库接口的输入参数名,值为参数值。
  126. **返回:**
  127. 构建好的SentencePiece词汇表。
  128. **样例:**
  129. >>> from mindspore.dataset.text import SentencePieceModel
  130. >>>
  131. >>> # DE_C_INTER_SENTENCEPIECE_MODE 是一个映射字典
  132. >>> from mindspore.dataset.text.utils import DE_C_INTER_SENTENCEPIECE_MODE
  133. >>> dataset = ds.TextFileDataset("/path/to/sentence/piece/vocab/file", shuffle=False)
  134. >>> dataset = dataset.build_sentencepiece_vocab(["text"], 5000, 0.9995,
  135. ... DE_C_INTER_SENTENCEPIECE_MODE[SentencePieceModel.UNIGRAM],
  136. ... {})
  137. .. py:method:: build_vocab(columns, freq_range, top_k, special_tokens, special_first)
  138. 迭代源数据集对象获取数据并构建词汇表。
  139. 收集数据集中所有的不重复单词,并返回 `top_k` 个最常见的单词组成的词汇表(如果指定了 `top_k` )。
  140. **参数:**
  141. - **columns** (Union[str, list[str]]) :指定 `build_vocab` 操作的输入列,会从该列获取数据构造词汇表。
  142. - **freq_range** (tuple[int]) - 由(min_frequency, max_frequency)组成的整数元组,代表词汇出现的频率范围,在这个频率范围的词汇会被保存下来。
  143. 取值范围需满足:0 <= min_frequency <= max_frequency <= 单词总数,其中min_frequency、max_frequency的默认值分别设置为0、单词总数。
  144. - **top_k** (int) - 使用 `top_k` 个最常见的单词构建词汇表。 假如指定了参数 `freq_range` ,则优先统计给定频率范围内的词汇,再根据参数 `top_k` 选取最常见的单词构建词汇表。
  145. 如果 `top_k` 的值大于单词总数,则取所有单词构建词汇表。
  146. - **special_tokens** (list[str]) - 指定词汇表的特殊标记(special token),如'[UNK]'、'[SEP]'。
  147. - **special_first** (bool) - 是否将参数 `special_tokens` 指定的特殊标记添加到词汇表的开头。如果为True则放到开头,否则放到词汇表的结尾。
  148. **返回:**
  149. 构建好的词汇表。
  150. **样例:**
  151. >>> def gen_corpus():
  152. ... # key:单词,value:出现次数,键的取值采用字母表示有利于排序和显示。
  153. ... corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1}
  154. ... for k, v in corpus.items():
  155. ... yield (np.array([k] * v, dtype='S'),)
  156. >>> column_names = ["column1", "column2", "column3"]
  157. >>> dataset = ds.GeneratorDataset(gen_corpus, column_names)
  158. >>> dataset = dataset.build_vocab(columns=["column3", "column1", "column2"],
  159. ... freq_range=(1, 10), top_k=5,
  160. ... special_tokens=["<pad>", "<unk>"],
  161. ... special_first=True,vocab='vocab')
  162. .. py:method:: close_pool()
  163. 关闭数据集对象中的多进程池。如果您熟悉多进程库,可以将此视为进程池对象的析构函数。
  164. .. py:method:: concat(datasets)
  165. 对传入的多个数据集对象进行拼接操作,也可以使用"+"运算符来进行数据集进行拼接。
  166. .. note::用于拼接的多个数据集对象,每个数据集对象的列名、每列数据的维度(rank)和数据类型必须相同。
  167. **参数:**
  168. - **datasets** (Union[list, Dataset]) - 与当前数据集对象拼接的数据集对象列表或单个数据集对象。
  169. **返回:**
  170. Dataset,拼接后的数据集对象。
  171. **样例:**
  172. >>> # 使用"+"运算符拼接dataset_1和dataset_2,获得拼接后的数据集对象
  173. >>> dataset = dataset_1 + dataset_2
  174. >>> # 通过concat操作拼接dataset_1和dataset_2,获得拼接后的数据集对象
  175. >>> dataset = dataset_1.concat(dataset_2)
  176. .. py:method:: create_dict_iterator(num_epochs=-1, output_numpy=False)
  177. 基于数据集对象创建迭代器,输出的数据为字典类型。
  178. **参数:**
  179. - **num_epochs** (int, 可选) - 迭代器可以迭代的最大次数。默认值:-1,迭代器可以迭代无限次。
  180. - **output_numpy** (bool, 可选) - 输出的数据是否转为NumPy类型。如果为False,迭代器输出的每列数据类型为MindSpore.Tensor,否则为NumPy。默认值:False。
  181. **返回:**
  182. DictIterator,基于数据集对象创建的字典迭代器。
  183. **样例:**
  184. >>> # dataset是任意数据集对象的实例
  185. >>> iterator = dataset.create_dict_iterator()
  186. >>> for item in iterator:
  187. ... # item 是一个dict
  188. ... print(type(item))
  189. ... break
  190. <class 'dict'>
  191. .. py:method:: create_tuple_iterator(columns=None, num_epochs=-1, output_numpy=False, do_copy=True)
  192. 基于数据集对象创建迭代器,输出数据为ndarray组成的列表。
  193. 可以通过参数 `columns` 指定输出的所有列名及列的顺序。如果columns未指定,列的顺序将保持不变。
  194. **参数:**
  195. - **columns** (list[str], 可选) - 用于指定输出的数据列和列的顺序。默认值:None,输出所有数据列。
  196. - **num_epochs** (int, 可选) - 迭代器可以迭代的最大次数。默认值:-1,迭代器可以迭代无限次。
  197. - **output_numpy** (bool, 可选) - 输出的数据是否转为NumPy类型。如果为False,迭代器输出的每列数据类型为MindSpore.Tensor,否则为NumPy。默认值:False。
  198. - **do_copy** (bool, 可选) - 当参数 `output_numpy` 为False,即输出数据类型为mindspore.Tensor时,可以将此参数指定为False以减少拷贝,获得更好的性能。默认值:True。
  199. **返回:**
  200. TupleIterator,基于数据集对象创建的元组迭代器。
  201. **样例:**
  202. >>> # dataset是任意数据集对象的实例
  203. >>> iterator = dataset.create_tuple_iterator()
  204. >>> for item in iterator:
  205. ... # item 是一个list
  206. ... print(type(item))
  207. ... break
  208. <class 'list'>
  209. .. py:method:: device_que(send_epoch_end=True, create_data_info_queue=False)
  210. 将数据异步传输到Ascend/GPU设备上。
  211. **参数:**
  212. - **send_epoch_end** (bool, 可选) - 数据发送完成后是否发送结束标识到设备上,默认值:True。
  213. - **create_data_info_queue** (bool, 可选) - 是否创建一个队列,用于存储每条数据的数据类型和shape。默认值:False,不创建。
  214. .. note::
  215. 如果设备类型为Ascend,每次传输的数据大小限制为256MB。
  216. **返回:**
  217. Dataset,用于帮助发送数据到设备上的数据集对象。
  218. .. py:method:: dynamic_min_max_shapes()
  219. 当数据集对象中的数据shape不唯一(动态shape)时,获取数据的最小shape和最大shape。
  220. **返回:**
  221. 两个列表代表最小shape和最大shape,每个列表中的shape按照数据列的顺序排列。
  222. **样例:**
  223. >>> import numpy as np
  224. >>>
  225. >>> def generator1():
  226. >>> for i in range(1, 100):
  227. >>> yield np.ones((16, i, 83)), np.array(i)
  228. >>>
  229. >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
  230. >>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []})
  231. >>> min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
  232. .. py:method:: filter(predicate, input_columns=None, num_parallel_workers=None)
  233. 通过自定义判断条件对数据集对象中的数据进行过滤。
  234. **参数:**
  235. - **predicate** (callable) - Python可调用对象。要求该对象接收n个入参,用于指代每个数据列的数据,最后返回值一个bool值。
  236. 如果返回值为False,则表示过滤掉该条数据。注意n的值与参数 `input_columns` 表示的输入列数量一致。
  237. - **input_columns** (Union[str, list[str]], 可选) - `filter` 操作的输入数据列。默认值:None,`predicate` 将应用于数据集中的所有列。
  238. - **num_parallel_workers** (int, 可选) - 指定 `filter` 操作的并发线程数。默认值:None,使用mindspore.dataset.config中配置的线程数。
  239. **返回:**
  240. Dataset,执行给定筛选过滤操作的数据集对象。
  241. **样例:**
  242. >>> # 生成一个list,其取值范围为(0,63)
  243. >>> def generator_1d():
  244. ... for i in range(64):
  245. ... yield (np.array(i),)
  246. >>> dataset = ds.GeneratorDataset(generator_1d, ["data"])
  247. >>> # 过滤掉数值大于或等于11的数据
  248. >>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"])
  249. .. py:method:: flat_map(func)
  250. 对数据集对象中每一条数据执行给定的数据处理,并将结果展平。
  251. **参数:**
  252. - **func** (function) - 数据处理函数,要求输入必须为一个'ndarray',返回值是一个'Dataset'对象。
  253. **返回:**
  254. 执行给定操作后的数据集对象。
  255. **样例:**
  256. >>> # 以NumpySlicesDataset为例
  257. >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
  258. >>>
  259. >>> def flat_map_func(array):
  260. ... # 使用数组创建NumpySlicesDataset
  261. ... dataset = ds.NumpySlicesDataset(array)
  262. ... # 将数据集对象中的数据重复两次
  263. ... dataset = dataset.repeat(2)
  264. ... return dataset
  265. >>>
  266. >>> dataset = dataset.flat_map(flat_map_func)
  267. >>> # [[0, 1], [0, 1], [2, 3], [2, 3]]
  268. **异常:**
  269. - **TypeError** - `func` 不是函数。
  270. - **TypeError** - `func` 的返回值不是数据集对象。
  271. .. py:method:: get_batch_size()
  272. 获得数据集对象定义的批处理大小,即一个批处理数据中包含的数据条数。
  273. **返回:**
  274. int,一个批处理数据中包含的数据条数。
  275. **样例:**
  276. >>> # dataset是任意数据集对象的实例
  277. >>> dataset = dataset.batch(16)
  278. >>> batch_size = dataset.get_batch_size()
  279. .. py:method:: get_class_indexing()
  280. 返回类别索引。
  281. **返回:**
  282. dict,描述类别名称到索引的键值对映射关系,通常为str-to-int格式。针对COCO数据集,类别名称到索引映射关系描述形式为str-to-list<int>格式,列表中的第二个数字表示超级类别。
  283. **样例:**
  284. >>> # dataset是数据集类的实例化对象
  285. >>> class_indexing = dataset.get_class_indexing()
  286. .. py:method:: get_col_names()
  287. 返回数据集对象中包含的列名。
  288. **返回:**
  289. list,数据集中所有列名组成列表。
  290. **样例:**
  291. >>> # dataset是数据集类的实例化对象
  292. >>> col_names = dataset.get_col_names()
  293. .. py:method:: get_dataset_size()
  294. 返回一个epoch中的batch数。
  295. **返回:**
  296. int,batch的数目。
  297. .. py:method:: get_repeat_count()
  298. 获取 `RepeatDataset` 中的repeat次数(默认为1)。
  299. **返回:**
  300. int,repeat次数。
  301. .. py:method:: input_indexs
  302. :property:
  303. 获取input index信息。
  304. **返回:**
  305. input index信息的元组。
  306. **样例:**
  307. >>> # dataset是Dataset对象的实例
  308. >>> # 设置input_indexs
  309. >>> dataset.input_indexs = 10
  310. >>> print(dataset.input_indexs)
  311. 10
  312. .. py:method:: map(operations, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None)
  313. 将operations列表中的每个operation作用于数据集。
  314. 作用的顺序由每个operation在operations参数中的位置决定。
  315. 将首先作用operation[0],然后operation[1],operation[2],以此类推。
  316. 每个operation将数据集中的一列或多列作为输入,并将输出零列或多列。
  317. 第一个operation将 `input_columns` 中指定的列作为输入。
  318. 如果operations列表中存在多个operation,则上一个operation的输出列将用作下一个operation的输入列。
  319. 最后一个operation输出列的列名由 `output_columns` 指定。
  320. 只有在 `column_order` 中指定的列才会传播到子节点,并且列的顺序将与 `column_order` 中指定的顺序相同。
  321. **参数:**
  322. - **operations** (Union[list[TensorOp], list[functions]]) - 要作用于数据集的operations列表。将按operations列表中显示的顺序作用在数据集。
  323. - **input_columns** (Union[str, list[str]], optional) - 第一个operation输入的列名列表。此列表的大小必须与第一个operation预期的输入列数相匹配。(默认为None,从第一列开始,无论多少列,都将传递给第一个operation)。
  324. - **output_columns** (Union[str, list[str]], optional) - 最后一个operation输出的列名列表。如果 `input_columns` 长度不等于 `output_columns` 长度,则此参数必选。此列表的大小必须与最后一个operation的输出列数相匹配(默认为None,输出列将与输入列具有相同的名称,例如,替换一些列)。
  325. - **column_order** (list[str], optional) - 指定整个数据集中所需的所有列的列表。当 `input_columns` 长度不等于 `output_columns` 长度时,则此参数必选。注意:这里的列表不仅仅是参数 `input_columns` 和 `output_columns` 中指定的列。
  326. - **num_parallel_workers** (int, optional) - 用于并行处理数据集的线程数(默认为None,将使用配置文件中的值)。
  327. - **python_multiprocessing** (bool, optional) - 将Python operations委托给多个工作进程进行并行处理。如果Python operations计算量很大,此选项可能会很有用(默认值为False)。
  328. - **cache** (DatasetCache, optional) - 使用Tensor缓存服务加快数据集处理速度(默认为None,即不使用缓存)。
  329. - **callbacks** (DSCallback, list[DSCallback], optional) - 要调用的Dataset回调函数列表(默认为None)。
  330. .. note::
  331. - `operations` 参数主要接收 `mindspore.dataset` 模块中c_transforms、py_transforms算子,以及用户定义的Python函数(PyFuncs)。
  332. - 不要将 `mindspore.nn` 和 `mindspore.ops` 或其他的网络计算算子添加到 `operations` 中。
  333. **返回:**
  334. MapDataset,map操作后的数据集。
  335. **样例:**
  336. >>> # dataset是Dataset的一个实例,它有2列,"image"和"label"。
  337. >>>
  338. >>> # 定义两个operation,每个operation接受1列输入,输出1列。
  339. >>> decode_op = c_vision.Decode(rgb=True)
  340. >>> random_jitter_op = c_vision.RandomColorAdjust(brightness=(0.8, 0.8), contrast=(1, 1),
  341. ... saturation=(1, 1), hue=(0, 0))
  342. >>>
  343. >>> # 1)简单的map示例。
  344. >>>
  345. >>> # 在列“image"上应用decode_op。此列将被
  346. >>> # decode_op的输出列替换。由于未指定column_order,因此两列“image"
  347. >>> # 和“label"将按其原始顺序传播到下一个节点。
  348. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"])
  349. >>>
  350. >>> # 解码列“image"并将其重命名为“decoded_image"。
  351. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], output_columns=["decoded_image"])
  352. >>>
  353. >>> # 指定输出列的顺序。
  354. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  355. ... output_columns=None, column_order=["label", "image"])
  356. >>>
  357. >>> # 将列“image"重命名为“decoded_image",并指定输出列的顺序。
  358. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  359. ... output_columns=["decoded_image"], column_order=["label", "decoded_image"])
  360. >>>
  361. >>> # 将列“image"重命名为“decoded_image",并只保留此列。
  362. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  363. ... output_columns=["decoded_image"], column_order=["decoded_image"])
  364. >>>
  365. >>> # 使用用户自定义Python函数的map简单示例。列重命名和指定列顺序
  366. >>> # 的方式同前面的示例相同。
  367. >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
  368. >>> dataset = dataset.map(operations=[(lambda x: x + 1)], input_columns=["data"])
  369. >>>
  370. >>> # 2)多个operation的map示例。
  371. >>>
  372. >>> # 创建一个数据集,图像被解码,并随机颜色抖动。
  373. >>> # decode_op以列“image"作为输入,并输出一列。将
  374. >>> # decode_op输出的列作为输入传递给random_jitter_op。
  375. >>> # random_jitter_op将输出一列。列“image"将替换为
  376. >>> # random_jitter_op(最后一个operation)输出的列。所有其他
  377. >>> # 列保持不变。由于未指定column_order,因此
  378. >>> # 列的顺序将保持不变。
  379. >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"])
  380. >>>
  381. >>> # 将random_jitter_op输出的列重命名为“image_mapped"。
  382. >>> # 指定列顺序的方式与1中的示例相同。
  383. >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"],
  384. ... output_columns=["image_mapped"])
  385. >>>
  386. >>> # 使用用户自定义Python函数的多个operation的map示例。列重命名和指定列顺序
  387. >>> # 的方式与1中的示例相同。
  388. >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
  389. >>> dataset = dataset.map(operations=[(lambda x: x * x), (lambda x: x - 1)], input_columns=["data"],
  390. ... output_columns=["data_mapped"])
  391. >>>
  392. >>> # 3)输入列数不等于输出列数的示例。
  393. >>>
  394. >>> # operation[0] 是一个 lambda,它以 2 列作为输入并输出 3 列。
  395. >>> # operations[1] 是一个 lambda,它以 3 列作为输入并输出 1 列。
  396. >>> # operations[2] 是一个 lambda,它以 1 列作为输入并输出 4 列。
  397. >>> #
  398. >>> # 注:operation[i]的输出列数必须等于
  399. >>> # operation[i+1]的输入列。否则,map算子会
  400. >>> # 出错。
  401. >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
  402. ... (lambda x, y, z: x * y * z),
  403. ... (lambda x: (x % 2, x % 3, x % 5, x % 7))]
  404. >>>
  405. >>> # 注:由于输入列数与
  406. >>> # 输出列数不相同,必须指定output_columns和column_order
  407. >>> # 参数。否则,此map算子也会出错。
  408. >>>
  409. >>> dataset = ds.NumpySlicesDataset(data=([[0, 1, 2]], [[3, 4, 5]]), column_names=["x", "y"])
  410. >>>
  411. >>> # 按以下顺序将所有列传播到子节点:
  412. >>> dataset = dataset.map(operations, input_columns=["x", "y"],
  413. ... output_columns=["mod2", "mod3", "mod5", "mod7"],
  414. ... column_order=["mod2", "mod3", "mod5", "mod7"])
  415. >>>
  416. >>> # 按以下顺序将某些列传播到子节点:
  417. >>> dataset = dataset.map(operations, input_columns=["x", "y"],
  418. ... output_columns=["mod2", "mod3", "mod5", "mod7"],
  419. ... column_order=["mod7", "mod3", "col2"])
  420. .. py:method:: num_classes()
  421. 获取数据集中的样本的class数目。
  422. **返回:**
  423. int,class数目。
  424. .. py:method:: output_shapes()
  425. 获取输出数据的shape。
  426. **返回:**
  427. list,每列shape的列表。
  428. .. py:method:: output_types()
  429. 获取输出数据类型。
  430. **返回:**
  431. list,每列类型的列表。
  432. .. py:method:: project(columns)
  433. 在输入数据集上投影某些列。
  434. 从数据集中选择列,并以指定的顺序传输到流水线中。
  435. 其他列将被丢弃。
  436. **参数:**
  437. - **columns** (Union[str, list[str]]) - 要投影列的列名列表。
  438. **返回:**
  439. ProjectDataset,投影后的数据集对象。
  440. **样例:**
  441. >>> # dataset是Dataset对象的实例
  442. >>> columns_to_project = ["column3", "column1", "column2"]
  443. >>>
  444. >>> # 创建一个数据集,无论列的原始顺序如何,依次包含column3, column1, column2。
  445. >>> dataset = dataset.project(columns=columns_to_project)
  446. .. py:method:: rename(input_columns, output_columns)
  447. 重命名输入数据集中的列。
  448. **参数:**
  449. - **input_columns** (Union[str, list[str]]) - 输入列的列名列表。
  450. - **output_columns** (Union[str, list[str]]) - 输出列的列名列表。
  451. **返回:**
  452. RenameDataset,重命名后数据集对象。
  453. **样例:**
  454. >>> # dataset是Dataset对象的实例
  455. >>> input_columns = ["input_col1", "input_col2", "input_col3"]
  456. >>> output_columns = ["output_col1", "output_col2", "output_col3"]
  457. >>>
  458. >>> # 创建一个数据集,其中input_col1重命名为output_col1,
  459. >>> # input_col2重命名为output_col2,input_col3重命名
  460. >>> # 为output_col3。
  461. >>> dataset = dataset.rename(input_columns=input_columns, output_columns=output_columns)
  462. .. py:method:: repeat(count=None)
  463. 重复此数据集 `count` 次。如果count为None或-1,则无限重复。
  464. .. note::
  465. repeat和batch的顺序反映了batch的数量。建议:repeat操作在batch操作之后使用。
  466. **参数:**
  467. - **count** (int) - 数据集重复的次数(默认为None)。
  468. **返回:**
  469. RepeatDataset,重复操作后的数据集对象。
  470. **样例:**
  471. >>> # dataset是Dataset对象的实例
  472. >>>
  473. >>> # 创建一个数据集,数据集重复50个epoch。
  474. >>> dataset = dataset.repeat(50)
  475. >>>
  476. >>> # 创建一个数据集,其中每个epoch都是单独打乱的。
  477. >>> dataset = dataset.shuffle(10)
  478. >>> dataset = dataset.repeat(50)
  479. >>>
  480. >>> # 创建一个数据集,打乱前先将数据集重复
  481. >>> # 50个epoch。shuffle算子将
  482. >>> # 整个50个epoch视作一个大数据集。
  483. >>> dataset = dataset.repeat(50)
  484. >>> dataset = dataset.shuffle(10)
  485. .. py:method:: reset()
  486. 重置下一个epoch的数据集。
  487. .. py:method:: save(file_name, num_files=1, file_type='mindrecord')
  488. 将流水线正在处理的数据保存为通用的数据集格式。支持的数据集格式:'mindrecord'。
  489. 将数据保存为'mindrecord'格式时存在隐式类型转换。转换表展示如何执行类型转换。
  490. .. list-table:: 保存为'mindrecord'格式时的隐式类型转换
  491. :widths: 25 25 50
  492. :header-rows: 1
  493. * - 'dataset'类型
  494. - 'mindrecord'类型
  495. - 详细
  496. * - bool
  497. - None
  498. - 不支持
  499. * - int8
  500. - int32
  501. -
  502. * - uint8
  503. - bytes(1D uint8)
  504. - Drop dimension
  505. * - int16
  506. - int32
  507. -
  508. * - uint16
  509. - int32
  510. -
  511. * - int32
  512. - int32
  513. -
  514. * - uint32
  515. - int64
  516. -
  517. * - int64
  518. - int64
  519. -
  520. * - uint64
  521. - None
  522. - 不支持
  523. * - float16
  524. - float32
  525. -
  526. * - float32
  527. - float32
  528. -
  529. * - float64
  530. - float64
  531. -
  532. * - string
  533. - string
  534. - 不支持多维字符串
  535. .. note::
  536. 1. 如需按顺序保存示例,请将数据集的shuffle设置为False,将 `num_files` 设置为1。
  537. 2. 在调用函数之前,不要使用batch算子、repeat算子或具有随机属性的数据增强的map算子。
  538. 3. 当数据的维度可变时,只支持1维数组或者在0维变化的多维数组。
  539. 4. 不支持DE_UINT64类型、多维的DE_UINT8类型、多维DE_STRING类型。
  540. **参数:**
  541. - **file_name** (str) - 数据集文件的路径。
  542. - **num_files** (int, optional) - 数据集文件的数量(默认为1)。
  543. - **file_type** (str, optional) - 数据集格式(默认为'mindrecord')。
  544. .. py:method:: set_dynamic_columns(columns=None)
  545. 设置源数据的动态shape信息,需要在定义数据处理流水线后设置。
  546. **参数:**
  547. - **columns** (dict) - 包含数据集中每列shape信息的字典。shape[i]为 `None` 表示shape[i]的数据长度是动态的。
  548. .. py:method:: shuffle(buffer_size)
  549. 使用以下策略随机打乱此数据集的行:
  550. 1. 生成一个shuffle缓冲区包含buffer_size条数据行。
  551. 2. 从shuffle缓冲区中随机选择一个元素,作为下一行传播到子节点。
  552. 3. 从父节点获取下一行(如果有的话),并将其放入shuffle缓冲区中。
  553. 4. 重复步骤2和3,直到打乱缓冲区中没有数据行为止。
  554. 可以提供随机种子,在第一个epoch中使用。在随后的每个epoch,种子都会被设置成一个新产生的随机值。
  555. **参数:**
  556. - **buffer_size** (int) - 用于shuffle的缓冲区大小(必须大于1)。将buffer_size设置为等于数据集大小将导致在全局shuffle。
  557. **返回:**
  558. ShuffleDataset,打乱后的数据集对象。
  559. **异常:**
  560. - **RuntimeError** - 打乱前存在同步操作。
  561. **样例:**
  562. >>> # dataset是Dataset对象的实例
  563. >>> # 可以选择设置第一个epoch的种子
  564. >>> ds.config.set_seed(58)
  565. >>> # 使用大小为4的shuffle缓冲区创建打乱后的数据集。
  566. >>> dataset = dataset.shuffle(4)
  567. .. py:method:: skip(count)
  568. 跳过此数据集的前N个元素。
  569. **参数:**
  570. - **count** (int) - 要跳过的数据集中的元素个数。
  571. **返回:**
  572. SkipDataset,减去跳过的行的数据集对象。
  573. **样例:**
  574. >>> # dataset是Dataset对象的实例
  575. >>> # 创建一个数据集,跳过前3个元素
  576. >>> dataset = dataset.skip(3)
  577. .. py:method:: split(sizes, randomize=True)
  578. 将数据集拆分为多个不重叠的数据集。
  579. 这是一个通用拆分函数,可以被数据处理流水线中的任何算子调用。
  580. 还有如果直接调用ds.split,其中 ds 是一个 MappableDataset,它将被自动调用。
  581. **参数:**
  582. - **sizes** (Union[list[int], list[float]]) - 如果指定了一列整数[s1, s2, …, sn],数据集将被拆分为n个大小为s1、s2、...、sn的数据集。如果所有输入大小的总和不等于原始数据集大小,则报错。如果指定了一列浮点数[f1, f2, …, fn],则所有浮点数必须介于0和1之间,并且总和必须为1,否则报错。数据集将被拆分为n个大小为round(f1*K)、round(f2*K)、...、round(fn*K)的数据集,其中K是原始数据集的大小。
  583. 如果舍入后:
  584. - 任何大小等于0,都将发生错误。
  585. - 如果拆分大小的总和<K,K - sigma(round(fi * k))的差值将添加到第一个子数据集。
  586. - 如果拆分大小的总和>K,sigma(round(fi * K)) - K的差值将从第一个足够大的拆分子集中删除,删除差值后至少有1行。
  587. - **randomize** (bool, optional) - 确定是否随机拆分数据(默认为True)。如果为True,则数据集将被随机拆分。否则,将使用数据集中的连续行创建每个拆分子集。
  588. .. note::
  589. 1. 如果要调用 split,则无法对数据集进行分片。
  590. 2. 强烈建议不要对数据集进行打乱,而是使用随机化(randomize=True)。对数据集进行打乱的结果具有不确定性,每个拆分子集中的数据在每个epoch可能都不同。
  591. **异常:**
  592. - **RuntimeError** - get_dataset_size返回None或此数据集不支持。
  593. - **RuntimeError** - sizes是整数列表,并且size中所有元素的总和不等于数据集大小。
  594. - **RuntimeError** - sizes是float列表,并且计算后存在大小为0的拆分子数据集。
  595. - **RuntimeError** - 数据集在调用拆分之前已进行分片。
  596. - **ValueError** - sizes是float列表,且并非所有float数都在0和1之间,或者float数的总和不等于1。
  597. **返回:**
  598. tuple(Dataset),拆分后子数据集对象的元组。
  599. **样例:**
  600. >>> # TextFileDataset不是可映射dataset,因此将调用通用拆分函数。
  601. >>> # 由于许多数据集默认都打开了shuffle,如需调用拆分函数,请将shuffle设置为False。
  602. >>> dataset = ds.TextFileDataset(text_file_dataset_dir, shuffle=False)
  603. >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
  604. .. py:method:: sync_update(condition_name, num_batch=None, data=None)
  605. 释放阻塞条件并使用给定数据触发回调函数。
  606. **参数:**
  607. - **condition_name** (str) - 用于切换发送下一行数据的条件名称。
  608. - **num_batch** (Union[int, None]) - 释放的batch(row)数。当 `num_batch` 为None时,将默认为 `sync_wait` 算子指定的值(默认为None)。
  609. - **data** (Any) - 用户自定义传递给回调函数的数据(默认为None)。
  610. .. py:method:: sync_wait(condition_name, num_batch=1, callback=None)
  611. 向输入数据集添加阻塞条件。 将应用同步操作。
  612. **参数:**
  613. - **condition_name** (str) - 用于切换发送下一行的条件名称。
  614. - **num_batch** (int) - 每个epoch开始时无阻塞的batch数。
  615. - **callback** (function) - `sync_update` 中将调用的回调函数。
  616. **返回:**
  617. SyncWaitDataset,添加了阻塞条件的数据集对象。
  618. **异常:**
  619. - **RuntimeError** - 条件名称已存在。
  620. **样例:**
  621. >>> import numpy as np
  622. >>> def gen():
  623. ... for i in range(100):
  624. ... yield (np.array(i),)
  625. >>>
  626. >>> class Augment:
  627. ... def __init__(self, loss):
  628. ... self.loss = loss
  629. ...
  630. ... def preprocess(self, input_):
  631. ... return input_
  632. ...
  633. ... def update(self, data):
  634. ... self.loss = data["loss"]
  635. >>>
  636. >>> batch_size = 4
  637. >>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
  638. >>>
  639. >>> aug = Augment(0)
  640. >>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
  641. >>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
  642. >>> dataset = dataset.batch(batch_size)
  643. >>> count = 0
  644. >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  645. ... assert data["input"][0] == count
  646. ... count += batch_size
  647. ... data = {"loss": count}
  648. ... dataset.sync_update(condition_name="policy", data=data)
  649. .. py:method:: take(count=-1)
  650. 从数据集中获取最多给定数量的元素。
  651. .. note::
  652. 1. 如果count大于数据集中的元素数或等于-1,则取数据集中的所有元素。
  653. 2. take和batch操作顺序很重要,如果take在batch操作之前,则取给定行数;否则取给定batch数。
  654. **参数:**
  655. - **count** (int, optional) - 要从数据集中获取的元素数(默认为-1)。
  656. **返回:**
  657. TakeDataset,取出指定数目的数据集对象。
  658. **样例:**
  659. >>> # dataset是Dataset对象的实例。
  660. >>> # 创建一个数据集,包含50个元素。
  661. >>> dataset = dataset.take(50)
  662. .. py:method:: to_device(send_epoch_end=True, create_data_info_queue=False)
  663. 将数据从CPU传输到GPU、Ascend或其他设备。
  664. **参数:**
  665. - **send_epoch_end** (bool, optional) - 是否将end of sequence发送到设备(默认为True)。
  666. - **create_data_info_queue** (bool, optional) - 是否创建存储数据类型和shape的队列(默认值为False)。
  667. .. note::
  668. 如果设备为Ascend,则逐个传输数据。每次传输的数据最大限制为256M。
  669. **返回:**
  670. TransferDataset,用于传输的数据集对象。
  671. **异常:**
  672. - **RuntimeError** - 如果提供了分布式训练的文件路径但读取失败。
  673. .. py:method:: to_json(filename='')
  674. 将数据处理流水线序列化为JSON字符串,如果提供了文件名,则转储到文件中。
  675. **参数:**
  676. - **filename** (str) - 另存为JSON格式的文件名。
  677. **返回:**
  678. str,流水线的JSON字符串。