You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.Dataset.rst 45 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930
  1. .. py:method:: apply(apply_func)
  2. 对数据集对象执行给定操作函数。
  3. **参数:**
  4. `apply_func` (function):传入 `Dataset` 对象作为参数,并将返回处理后的 `Dataset` 对象。
  5. **返回:**
  6. 执行了给定操作函数的数据集对象。
  7. **样例:**
  8. >>> # dataset是数据集类的实例化对象
  9. >>>
  10. >>> # 声明一个名为apply_func函数,其返回值是一个Dataset对象
  11. >>> def apply_func(data):
  12. ... data = data.batch(2)
  13. ... return data
  14. >>>
  15. >>> # 通过apply操作调用apply_func函数
  16. >>> dataset = dataset.apply(apply_func)
  17. **异常:**
  18. - **TypeError:** `apply_func` 不是一个函数。
  19. - **TypeError:** `apply_func` 未返回Dataset对象。
  20. .. py:method:: batch(batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False)
  21. 将dataset中连续 `batch_size` 行数据合并为一个批处理数据。
  22. 对一个批处理数据执行给定操作与对条数据进行给定操作用法一致。对于任意列,batch操作要求该列中的各条数据shape必须相同。如果给定可执行函数 `per_batch_map` ,它将作用于批处理后的数据。
  23. .. note::
  24. 执行 `repeat` 和 `batch` 操作的顺序,会影响数据批次的数量及 `per_batch_map` 操作。建议在batch操作完成后执行repeat操作。
  25. **参数:**
  26. - **batch_size** (int or function) - 每个批处理数据包含的条数。参数需要是int或可调用对象,该对象接收1个参数,即BatchInfo。
  27. - **drop_remainder** (bool, optional) - 是否删除最后一个数据条数小于批处理大小的batch(默认值为False)。如果为True,并且最后一个批次中数据行数少于 `batch_size`,则这些数据将被丢弃,不会传递给后续的操作。
  28. - **num_parallel_workers** (int, optional) - 用于进行batch操作的的线程数(threads),默认值为None。
  29. - **per_batch_map** (callable, optional) - 是一个以(list[Tensor], list[Tensor], ..., BatchInfo)作为输入参数的可调用对象。每个list[Tensor]代表给定列上的一批Tensor。入参中list[Tensor]的个数应与 `input_columns` 中传入列名的数量相匹配。该可调用对象的最后一个参数始终是BatchInfo对象。`per_batch_map` 应返回(list[Tensor], list[Tensor], ...)。其出中list[Tensor]的个数应与输入相同。如果输出列数与输入列数不一致,则需要指定 `output_columns`。 - **input_columns** (Union[str, list[str]], optional):由输入列名组成的列表。如果 `per_batch_map` 不为None,列表中列名的个数应与 `per_batch_map` 中包含的列数匹配(默认为None)。
  30. - **output_columns** (Union[str, list[str]], optional) - 当前操作所有输出列的列名列表。如果len(input_columns) != len(output_columns),则此参数必须指定。此列表中列名的数量必须与给定操作的输出列数相匹配(默认为None,输出列将与输入列具有相同的名称)。
  31. - **column_order** (Union[str, list[str]], optional) - 指定整个数据集对象中包含的所有列名的顺序。如果len(input_column) != len(output_column),则此参数必须指定。 注意:这里的列名不仅仅是在 `input_columns` 和 `output_columns` 中指定的列。
  32. - **pad_info** (dict, optional) - 用于对给定列进行填充。例如 `pad_info={"col1":([224,224],0)}` ,则将列名为"col1"的列填充到大小为[224,224]的张量,并用0填充缺失的值(默认为None)。
  33. - **python_multiprocessing** (bool, optional) - 针对 `per_batch_map` 函数,使用Python多进执行的方式进行调用。如果函数计算量大,开启这个选项可能会很有帮助(默认值为False)。
  34. **返回:**
  35. 批处理后的数据集对象。
  36. **样例:**
  37. >>> # 创建一个数据集对象,每100条数据合并成一个批次
  38. >>> # 如果最后一个批次数据小于给定的批次大小(batch_size),则丢弃这个批次
  39. >>> dataset = dataset.batch(100, True)
  40. >>> # 根据批次编号调整图像大小,如果是第5批,则图像大小调整为(5^2, 5^2) = (25, 25)
  41. >>> def np_resize(col, batchInfo):
  42. ... output = col.copy()
  43. ... s = (batchInfo.get_batch_num() + 1) ** 2
  44. ... index = 0
  45. ... for c in col:
  46. ... img = Image.fromarray(c.astype('uint8')).convert('RGB')
  47. ... img = img.resize((s, s), Image.ANTIALIAS)
  48. ... output[index] = np.array(img)
  49. ... index += 1
  50. ... return (output,)
  51. >>> dataset = dataset.batch(batch_size=8, input_columns=["image"], per_batch_map=np_resize)
  52. .. py:method:: bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, element_length_function=None, pad_info=None, pad_to_bucket_boundary=False, drop_remainder=False)
  53. 依据数据中元素长度进行分桶。每个桶将在满了的时候进行元素填充和批处理操作。
  54. 对数据集中的每一条数据执行长度计算函数。然后,根据该条数据的长度和桶的边界将该数据归到特定的桶里面。当桶中数据条数达到指定的大小 `bucket_batch_sizes` 时,将根据 `pad_info` 对桶中元素进行填充,再进行批处理。这样每个批次都是满的,但也有特殊情况,每个桶的最后一个批次(batch)可能不满。
  55. **参数:**
  56. - **column_names** (list[str]) - 传递给长度计算函数的所有列名。
  57. - **bucket_boundaries** (list[int]) - 由各个桶的上边界值组成的列表,必须严格递增。如果有n个边界,则创建n+1个桶,分配后桶的边界如下:[0, bucket_boundaries[0]),[bucket_boundaries[i], bucket_boundaries[i+1])(其中,0<i<n-1),[bucket_boundaries[n-1], inf)。
  58. - **bucket_batch_sizes** (list[int]) - 由每个桶的批次大小组成的列表,必须包含 `len(bucket_boundaries)+1` 个元素。
  59. - **element_length_function** (Callable, optional) - 输入包含M个参数的函数,其中M等于 `len(column_names)` ,并返回一个整数。如果未指定该参数,则 `len(column_names)` 必须为1,并且该列数据第一维的shape值将用作长度(默认为None)。
  60. - **pad_info** (dict, optional) - 有关如何对指定列进行填充的字典对象。字典中键对应要填充的列名,值必须是包含2个元素的元组。元组中第一个元素对应要填充成的shape,第二个元素对应要填充的值。如果某一列未指定将要填充后的shape和填充值,则当前批次中该列上的每条数据都将填充至该批次中最长数据的长度,填充值为0。除非 `pad_to_bucket_boundary` 为True,否则 `pad_info` 中任何填充shape为None的列,其每条数据长度都将被填充为当前批处理中最数据的长度。如果不需要填充,请将 `pad_info` 设置为None(默认为None)。
  61. - **pad_to_bucket_boundary** (bool, optional) - 如果为True,则 `pad_info` 中填充shape为None的列,其长度都会被填充至 `bucket_boundary-1` 长度。如果有任何元素落入最后一个桶中,则将报错(默认为False)。
  62. - **drop_remainder** (bool, optional) - 如果为True,则丢弃每个桶中最后不足一个批次数据(默认为False)。
  63. **返回:**
  64. BucketBatchByLengthDataset,按长度进行分桶和批处理操作后的数据集对象。
  65. **样例:**
  66. >>> # 创建一个数据集对象,其中给定条数的数据会被组成一个批次数据
  67. >>> # 如果最后一个批次数据小于给定的批次大小(batch_size),则丢弃这个批次
  68. >>> import numpy as np
  69. >>> def generate_2_columns(n):
  70. ... for i in range(n):
  71. ... yield (np.array([i]), np.array([j for j in range(i + 1)]))
  72. >>>
  73. >>> column_names = ["col1", "col2"]
  74. >>> dataset = ds.GeneratorDataset(generate_2_columns(8), column_names)
  75. >>> bucket_boundaries = [5, 10]
  76. >>> bucket_batch_sizes = [2, 1, 1]
  77. >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
  78. >>> # 将对列名为"col2"的列进行填充,填充后的shape为[bucket_boundaries[i]],其中i是当前正在批处理的桶的索引
  79. >>> pad_info = {"col2": ([None], -1)}
  80. >>> pad_to_bucket_boundary = True
  81. >>> dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
  82. ... bucket_batch_sizes,
  83. ... element_length_function, pad_info,
  84. ... pad_to_bucket_boundary)
  85. .. py:method:: build_sentencepiece_vocab(columns, vocab_size, character_coverage, model_type, params)
  86. 用于从源数据集对象创建句子词表的函数。
  87. **参数:**
  88. - **columns** (list[str]) - 指定从哪一列中获取单词。
  89. - **vocab_size** (int) - 词汇表大小。
  90. - **character_coverage** (int) - 模型涵盖的字符百分比,必须介于0.98和1.0之间。默认值如0.9995,适用于具有丰富字符集的语言,如日语或中文字符集;1.0适用于其他字符集较小的语言,比如英语或拉丁文。
  91. - **model_type** (SentencePieceModel) - 模型类型,枚举值包括unigram(默认值)、bpe、char及word。当类型为word时,输入句子必须预先标记。
  92. - **params** (dict) - 依据原始数据内容构建祠表的附加参数,无附加参数时取值可以是空字典。
  93. **返回:**
  94. SentencePieceVocab,从数据集构建的词汇表。
  95. **样例:**
  96. >>> from mindspore.dataset.text import SentencePieceModel
  97. >>>
  98. >>> # DE_C_INTER_SENTENCEPIECE_MODE 是一个映射字典
  99. >>> from mindspore.dataset.text.utils import DE_C_INTER_SENTENCEPIECE_MODE
  100. >>> dataset = ds.TextFileDataset("/path/to/sentence/piece/vocab/file", shuffle=False)
  101. >>> dataset = dataset.build_sentencepiece_vocab(["text"], 5000, 0.9995,
  102. ... DE_C_INTER_SENTENCEPIECE_MODE[SentencePieceModel.UNIGRAM],
  103. ... {})
  104. .. py:method:: build_vocab(columns, freq_range, top_k, special_tokens, special_first)
  105. 基于数据集对象创建词汇表。
  106. 用于收集数据集中所有的唯一单词,并返回 `top_k` 个最常见的单词组成的词汇表(如果指定了 `top_k` )。
  107. **参数:**
  108. - **columns** (Union[str, list[str]]) :指定从数据集对象中哪一列中获取单词。
  109. - **freq_range** (tuple[int]) - 由(min_frequency, max_frequency)组成的整数元组,在这个频率范围的词汇会被保存下来。
  110. 取值范围需满足:0 <= min_frequency <= max_frequency <= total_words,其中min_frequency、max_frequency的默认值分别设置为0、total_words。
  111. - **top_k** (int) - 词汇表中包含的单词数,取 `top_k` 个最常见的单词。`top_k` 优先级低于 `freq_range`。如果 `top_k` 的值大于单词总数,则取所有单词。
  112. - **special_tokens** (list[str]) - 字符串列表,每个字符串都是一个特殊的标记。
  113. - **special_first** (bool) - 是否将 `special_tokens` 添加到词汇表首尾。如果指定了 `special_tokens` 且
  114. `special_first` 设置为默认值,则将 `special_tokens` 添加到词汇表最前面。
  115. **返回:**
  116. 从数据集对象中构建出的词汇表对象。
  117. **样例:**
  118. >>> def gen_corpus():
  119. ... # 键:单词,值:出现次数,键的取值采用字母表示有利于排序和显示。
  120. ... corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1}
  121. ... for k, v in corpus.items():
  122. ... yield (np.array([k] * v, dtype='S'),)
  123. >>> column_names = ["column1", "column2", "column3"]
  124. >>> dataset = ds.GeneratorDataset(gen_corpus, column_names)
  125. >>> dataset = dataset.build_vocab(columns=["column3", "column1", "column2"],
  126. ... freq_range=(1, 10), top_k=5,
  127. ... special_tokens=["<pad>", "<unk>"],
  128. ... special_first=True,vocab='vocab')
  129. .. py:method:: close_pool()
  130. 关闭数据集对象中的多进程池。如果您熟悉多进程库,可以将此视为进程池对象的析构函数。
  131. .. py:method:: concat(datasets)
  132. 对传入的多个数据集对象进行拼接操作。重载“+”运算符来进行数据集对象拼接操作。
  133. .. note::用于拼接的多个数据集对象,其列名、每列数据的维度(rank)和类型必须相同。
  134. **参数:**
  135. - **datasets** (Union[list, class Dataset]) - 与当前数据集对象拼接的数据集对象列表或单个数据集对象。
  136. **返回:**
  137. ConcatDataset,拼接后的数据集对象。
  138. **样例:**
  139. >>> # 通过使用“+”运算符拼接dataset_1和dataset_2,获得拼接后的数据集对象
  140. >>> dataset = dataset_1 + dataset_2
  141. >>> # 通过concat操作拼接dataset_1和dataset_2,获得拼接后的数据集对象
  142. >>> dataset = dataset_1.concat(dataset_2)
  143. .. py:method:: create_dict_iterator(num_epochs=-1, output_numpy=False)
  144. 基于数据集对象创建迭代器,输出数据为字典类型。
  145. 字典中列的顺序可能与数据集对象中原始顺序不同。
  146. **参数:**
  147. - **num_epochs** (int, optional) - 迭代器可以迭代的最多轮次数(默认为-1,迭代器可以迭代无限次)。
  148. - **output_numpy** (bool, optional) - 是否输出NumPy数据类型,如果 `output_numpy` 为False,迭代器输出的每列数据类型为MindSpore.Tensor(默认为False)。
  149. **返回:**
  150. DictIterator,基于数据集对象创建的字典迭代器。
  151. **样例:**
  152. >>> # dataset是数据集类的实例化对象
  153. >>> iterator = dataset.create_dict_iterator()
  154. >>> for item in iterator:
  155. ... # item 是一个dict
  156. ... print(type(item))
  157. ... break
  158. <class 'dict'>
  159. .. py:method:: create_tuple_iterator(columns=None, num_epochs=-1, output_numpy=False, do_copy=True)
  160. 基于数据集对象创建迭代器,输出数据为ndarray组成的列表。
  161. 可以使用columns指定输出的所有列名及列的顺序。如果columns未指定,列的顺序将保持不变。
  162. **参数:**
  163. - **columns** (list[str], optional) - 用于指定列顺序的列名列表(默认为None,表示所有列)。
  164. - **num_epochs** (int, optional) - 迭代器可以迭代的最多轮次数(默认为-1,迭代器可以迭代无限次)。
  165. - **output_numpy** (bool, optional) - 是否输出NumPy数据类型,如果output_numpy为False,迭代器输出的每列数据类型为MindSpore.Tensor(默认为False)。
  166. - **do_copy** (bool, optional) - 当输出数据类型为mindspore.Tensor时,通过此参数指定转换方法,采用False主要考虑以获得更好的性能(默认为True)。
  167. **返回:**
  168. TupleIterator,基于数据集对象创建的元组迭代器。
  169. **样例:**
  170. >>> # dataset是数据集类的实例化对象
  171. >>> iterator = dataset.create_tuple_iterator()
  172. >>> for item in iterator:
  173. ... # item 是一个列表
  174. ... print(type(item))
  175. ... break
  176. <class 'list'>
  177. .. py:method:: device_que(send_epoch_end=True, create_data_info_queue=False)
  178. 返回一个能将数据传输到设备上的数据集对象。
  179. **参数:**
  180. - **send_epoch_end** (bool, optional) - 数据发送完成后是否发送结束标识到设备上(默认值为True)。
  181. - **create_data_info_queue** (bool, optional) - 是否创建一个队列,用于存储每条数据的type和shape(默认值为False)。
  182. .. note::
  183. 如果设备类型为Ascend,数据的每一列将被依次单独传输,每次传输的数据大小限制为256M。
  184. **返回:**
  185. TransferDataset,用于帮助发送数据到设备上的数据集对象。
  186. .. py:method:: dynamic_min_max_shapes()
  187. 获取数据集对象中单条数据的最小和最大shape,用于图编译过程。
  188. **返回:**
  189. 列表,原始数据集对象中单条数据的最小和最大shape分别以list形式返回。
  190. **样例:**
  191. >>> import numpy as np
  192. >>>
  193. >>> def generator1():
  194. >>> for i in range(1, 100):
  195. >>> yield np.ones((16, i, 83)), np.array(i)
  196. >>>
  197. >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
  198. >>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []})
  199. >>> min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
  200. .. py:method:: filter(predicate, input_columns=None, num_parallel_workers=None)
  201. 通过判断条件对数据集对象中的数据进行过滤。
  202. .. note::
  203. 如果 `input_columns` 未指定或为空,则将使用所有列。
  204. **参数:**
  205. - **predicate** (callable) - Python可调用对象,返回值为Bool类型。如果为False,则过滤掉该条数据。
  206. - **input_columns** (Union[str, list[str]], optional) - 输入列名组成的列表,当取默认值None时,`predicate` 将应用于数据集中的所有列。
  207. - **num_parallel_workers** (int, optional) - 用于并行处理数据集的线程数(默认为None,将使用配置文件中的值)。
  208. **返回:**
  209. FilterDataset,执行给定筛选过滤操作的数据集对象。
  210. **样例:**
  211. >>> # 生成一个list,其取值范围为(0,63)
  212. >>> # 过滤掉数值大于或等于11的数据
  213. >>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"])
  214. .. py:method:: flat_map(func)
  215. 对数据集对象中每一条数据执行给定的 `func` 操作,并将结果展平。
  216. 指定的 `func` 是一个函数,输入必须为一个'ndarray',返回值是一个'Dataset'对象。
  217. **参数:**
  218. - **func** (function) - 输入'ndarray'并返回一个'Dataset'对象的函数。
  219. **返回:**
  220. 执行给定操作的数据集对象。
  221. **样例:**
  222. >>> # 以NumpySlicesDataset为例
  223. >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
  224. >>>
  225. >>> def flat_map_func(array):
  226. ... # 使用数组创建NumpySlicesDataset
  227. ... dataset = ds.NumpySlicesDataset(array)
  228. ... # 将数据集对象中的数据重复两次
  229. ... dataset = dataset.repeat(2)
  230. ... return dataset
  231. >>>
  232. >>> dataset = dataset.flat_map(flat_map_func)
  233. >>> # [[0, 1], [0, 1], [2, 3], [2, 3]]
  234. **异常:**
  235. - **TypeError** - `func` 不是函数。
  236. - **TypeError** - `func` 的返回值不是数据集对象。
  237. .. py:method:: get_batch_size()
  238. 获得批处理的大小,即一个批次中包含的数据条数。
  239. **返回:**
  240. int,一个批次中包含的数据条数。
  241. **样例:**
  242. >> # dataset是数据集类的实例化对象
  243. >> batch_size = dataset.get_batch_size()
  244. .. py:method:: get_class_indexing()
  245. 返回类别索引。
  246. **返回:**
  247. dict,描述类别名称到索引的键值对映射关系,通常为str-to-int格式。针对COCO数据集,类别名称到索引映射关系描述形式为str-to-list<int>格式,列表中的第二个数字表示超级类别。
  248. **样例:**
  249. >> # dataset是数据集类的实例化对象
  250. >> class_indexing = dataset.get_class_indexing()
  251. .. py:method:: get_col_names()
  252. 返回数据集对象中包含的列名。
  253. **返回:**
  254. list,数据集中所有列名组成列表。
  255. **样例:**
  256. >> # dataset是数据集类的实例化对象
  257. >> col_names = dataset.get_col_names()
  258. .. py:method:: get_dataset_size()
  259. 返回一个epoch中的batch数。
  260. **返回:**
  261. int,batch的数目。
  262. .. py:method:: get_repeat_count()
  263. 获取 `RepeatDataset` 中的repeat次数(默认为1)。
  264. **返回:**
  265. int,repeat次数。
  266. .. py:method:: input_indexs
  267. :property:
  268. 获取input index信息。
  269. **返回:**
  270. input index信息的元组。
  271. **样例:**
  272. >>> # dataset是Dataset对象的实例
  273. >>> # 设置input_indexs
  274. >>> dataset.input_indexs = 10
  275. >>> print(dataset.input_indexs)
  276. 10
  277. .. py:method:: map(operations, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None)
  278. 将operations列表中的每个operation作用于数据集。
  279. 作用的顺序由每个operation在operations参数中的位置决定。
  280. 将首先作用operation[0],然后operation[1],operation[2],以此类推。
  281. 每个operation将数据集中的一列或多列作为输入,并将输出零列或多列。
  282. 第一个operation将 `input_columns` 中指定的列作为输入。
  283. 如果operations列表中存在多个operation,则上一个operation的输出列将用作下一个operation的输入列。
  284. 最后一个operation输出列的列名由 `output_columns` 指定。
  285. 只有在 `column_order` 中指定的列才会传播到子节点,并且列的顺序将与 `column_order` 中指定的顺序相同。
  286. **参数:**
  287. - **operations** (Union[list[TensorOp], list[functions]]) - 要作用于数据集的operations列表。将按operations列表中显示的顺序作用在数据集。
  288. - **input_columns** (Union[str, list[str]], optional) - 第一个operation输入的列名列表。此列表的大小必须与第一个operation预期的输入列数相匹配。(默认为None,从第一列开始,无论多少列,都将传递给第一个operation)。
  289. - **output_columns** (Union[str, list[str]], optional) - 最后一个operation输出的列名列表。如果 `input_columns` 长度不等于 `output_columns` 长度,则此参数必选。此列表的大小必须与最后一个operation的输出列数相匹配(默认为None,输出列将与输入列具有相同的名称,例如,替换一些列)。
  290. - **column_order** (list[str], optional) - 指定整个数据集中所需的所有列的列表。当 `input_columns` 长度不等于 `output_columns` 长度时,则此参数必选。注意:这里的列表不仅仅是参数 `input_columns` 和 `output_columns` 中指定的列。
  291. - **num_parallel_workers** (int, optional) - 用于并行处理数据集的线程数(默认为None,将使用配置文件中的值)。
  292. - **python_multiprocessing** (bool, optional) - 将Python operations委托给多个工作进程进行并行处理。如果Python operations计算量很大,此选项可能会很有用(默认值为False)。
  293. - **cache** (DatasetCache, optional) - 使用Tensor缓存服务加快数据集处理速度(默认为None,即不使用缓存)。
  294. - **callbacks** (DSCallback, list[DSCallback], optional) - 要调用的Dataset回调函数列表(默认为None)。
  295. .. note::
  296. - `operations` 参数主要接收 `mindspore.dataset` 模块中c_transforms、py_transforms算子,以及用户定义的Python函数(PyFuncs)。
  297. - 不要将 `mindspore.nn` 和 `mindspore.ops` 或其他的网络计算算子添加到 `operations` 中。
  298. **返回:**
  299. MapDataset,map操作后的数据集。
  300. **样例:**
  301. >>> # dataset是Dataset的一个实例,它有2列,"image"和"label"。
  302. >>>
  303. >>> # 定义两个operation,每个operation接受1列输入,输出1列。
  304. >>> decode_op = c_vision.Decode(rgb=True)
  305. >>> random_jitter_op = c_vision.RandomColorAdjust(brightness=(0.8, 0.8), contrast=(1, 1),
  306. ... saturation=(1, 1), hue=(0, 0))
  307. >>>
  308. >>> # 1)简单的map示例。
  309. >>>
  310. >>> # 在列“image"上应用decode_op。此列将被
  311. >>> # decode_op的输出列替换。由于未指定column_order,因此两列“image"
  312. >>> # 和“label"将按其原始顺序传播到下一个节点。
  313. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"])
  314. >>>
  315. >>> # 解码列“image"并将其重命名为“decoded_image"。
  316. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], output_columns=["decoded_image"])
  317. >>>
  318. >>> # 指定输出列的顺序。
  319. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  320. ... output_columns=None, column_order=["label", "image"])
  321. >>>
  322. >>> # 将列“image"重命名为“decoded_image",并指定输出列的顺序。
  323. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  324. ... output_columns=["decoded_image"], column_order=["label", "decoded_image"])
  325. >>>
  326. >>> # 将列“image"重命名为“decoded_image",并只保留此列。
  327. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  328. ... output_columns=["decoded_image"], column_order=["decoded_image"])
  329. >>>
  330. >>> # 使用用户自定义Python函数的map简单示例。列重命名和指定列顺序
  331. >>> # 的方式同前面的示例相同。
  332. >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
  333. >>> dataset = dataset.map(operations=[(lambda x: x + 1)], input_columns=["data"])
  334. >>>
  335. >>> # 2)多个operation的map示例。
  336. >>>
  337. >>> # 创建一个数据集,图像被解码,并随机颜色抖动。
  338. >>> # decode_op以列“image"作为输入,并输出一列。将
  339. >>> # decode_op输出的列作为输入传递给random_jitter_op。
  340. >>> # random_jitter_op将输出一列。列“image"将替换为
  341. >>> # random_jitter_op(最后一个operation)输出的列。所有其他
  342. >>> # 列保持不变。由于未指定column_order,因此
  343. >>> # 列的顺序将保持不变。
  344. >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"])
  345. >>>
  346. >>> # 将random_jitter_op输出的列重命名为“image_mapped"。
  347. >>> # 指定列顺序的方式与1中的示例相同。
  348. >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"],
  349. ... output_columns=["image_mapped"])
  350. >>>
  351. >>> # 使用用户自定义Python函数的多个operation的map示例。列重命名和指定列顺序
  352. >>> # 的方式与1中的示例相同。
  353. >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
  354. >>> dataset = dataset.map(operations=[(lambda x: x * x), (lambda x: x - 1)], input_columns=["data"],
  355. ... output_columns=["data_mapped"])
  356. >>>
  357. >>> # 3)输入列数不等于输出列数的示例。
  358. >>>
  359. >>> # operation[0] 是一个 lambda,它以 2 列作为输入并输出 3 列。
  360. >>> # operations[1] 是一个 lambda,它以 3 列作为输入并输出 1 列。
  361. >>> # operations[2] 是一个 lambda,它以 1 列作为输入并输出 4 列。
  362. >>> #
  363. >>> # 注:operation[i]的输出列数必须等于
  364. >>> # operation[i+1]的输入列。否则,map算子会
  365. >>> # 出错。
  366. >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
  367. ... (lambda x, y, z: x * y * z),
  368. ... (lambda x: (x % 2, x % 3, x % 5, x % 7))]
  369. >>>
  370. >>> # 注:由于输入列数与
  371. >>> # 输出列数不相同,必须指定output_columns和column_order
  372. >>> # 参数。否则,此map算子也会出错。
  373. >>>
  374. >>> dataset = ds.NumpySlicesDataset(data=([[0, 1, 2]], [[3, 4, 5]]), column_names=["x", "y"])
  375. >>>
  376. >>> # 按以下顺序将所有列传播到子节点:
  377. >>> dataset = dataset.map(operations, input_columns=["x", "y"],
  378. ... output_columns=["mod2", "mod3", "mod5", "mod7"],
  379. ... column_order=["mod2", "mod3", "mod5", "mod7"])
  380. >>>
  381. >>> # 按以下顺序将某些列传播到子节点:
  382. >>> dataset = dataset.map(operations, input_columns=["x", "y"],
  383. ... output_columns=["mod2", "mod3", "mod5", "mod7"],
  384. ... column_order=["mod7", "mod3", "col2"])
  385. .. py:method:: num_classes()
  386. 获取数据集中的样本的class数目。
  387. **返回:**
  388. int,class数目。
  389. .. py:method:: output_shapes()
  390. 获取输出数据的shape。
  391. **返回:**
  392. list,每列shape的列表。
  393. .. py:method:: output_types()
  394. 获取输出数据类型。
  395. **返回:**
  396. list,每列类型的列表。
  397. .. py:method:: project(columns)
  398. 在输入数据集上投影某些列。
  399. 从数据集中选择列,并以指定的顺序传输到流水线中。
  400. 其他列将被丢弃。
  401. **参数:**
  402. - **columns** (Union[str, list[str]]) - 要投影列的列名列表。
  403. **返回:**
  404. ProjectDataset,投影后的数据集对象。
  405. **样例:**
  406. >>> # dataset是Dataset对象的实例
  407. >>> columns_to_project = ["column3", "column1", "column2"]
  408. >>>
  409. >>> # 创建一个数据集,无论列的原始顺序如何,依次包含column3, column1, column2。
  410. >>> dataset = dataset.project(columns=columns_to_project)
  411. .. py:method:: rename(input_columns, output_columns)
  412. 重命名输入数据集中的列。
  413. **参数:**
  414. - **input_columns** (Union[str, list[str]]) - 输入列的列名列表。
  415. - **output_columns** (Union[str, list[str]]) - 输出列的列名列表。
  416. **返回:**
  417. RenameDataset,重命名后数据集对象。
  418. **样例:**
  419. >>> # dataset是Dataset对象的实例
  420. >>> input_columns = ["input_col1", "input_col2", "input_col3"]
  421. >>> output_columns = ["output_col1", "output_col2", "output_col3"]
  422. >>>
  423. >>> # 创建一个数据集,其中input_col1重命名为output_col1,
  424. >>> # input_col2重命名为output_col2,input_col3重命名
  425. >>> # 为output_col3。
  426. >>> dataset = dataset.rename(input_columns=input_columns, output_columns=output_columns)
  427. .. py:method:: repeat(count=None)
  428. 重复此数据集 `count` 次。如果count为None或-1,则无限重复。
  429. .. note::
  430. repeat和batch的顺序反映了batch的数量。建议:repeat操作在batch操作之后使用。
  431. **参数:**
  432. - **count** (int) - 数据集重复的次数(默认为None)。
  433. **返回:**
  434. RepeatDataset,重复操作后的数据集对象。
  435. **样例:**
  436. >>> # dataset是Dataset对象的实例
  437. >>>
  438. >>> # 创建一个数据集,数据集重复50个epoch。
  439. >>> dataset = dataset.repeat(50)
  440. >>>
  441. >>> # 创建一个数据集,其中每个epoch都是单独打乱的。
  442. >>> dataset = dataset.shuffle(10)
  443. >>> dataset = dataset.repeat(50)
  444. >>>
  445. >>> # 创建一个数据集,打乱前先将数据集重复
  446. >>> # 50个epoch。shuffle算子将
  447. >>> # 整个50个epoch视作一个大数据集。
  448. >>> dataset = dataset.repeat(50)
  449. >>> dataset = dataset.shuffle(10)
  450. .. py:method:: reset()
  451. 重置下一个epoch的数据集。
  452. .. py:method:: save(file_name, num_files=1, file_type='mindrecord')
  453. 将流水线正在处理的数据保存为通用的数据集格式。支持的数据集格式:'mindrecord'。
  454. 将数据保存为'mindrecord'格式时存在隐式类型转换。转换表展示如何执行类型转换。
  455. .. list-table:: 保存为'mindrecord'格式时的隐式类型转换
  456. :widths: 25 25 50
  457. :header-rows: 1
  458. * - 'dataset'类型
  459. - 'mindrecord'类型
  460. - 详细
  461. * - bool
  462. - None
  463. - 不支持
  464. * - int8
  465. - int32
  466. -
  467. * - uint8
  468. - bytes(1D uint8)
  469. - Drop dimension
  470. * - int16
  471. - int32
  472. -
  473. * - uint16
  474. - int32
  475. -
  476. * - int32
  477. - int32
  478. -
  479. * - uint32
  480. - int64
  481. -
  482. * - int64
  483. - int64
  484. -
  485. * - uint64
  486. - None
  487. - 不支持
  488. * - float16
  489. - float32
  490. -
  491. * - float32
  492. - float32
  493. -
  494. * - float64
  495. - float64
  496. -
  497. * - string
  498. - string
  499. - 不支持多维字符串
  500. .. note::
  501. 1. 如需按顺序保存示例,请将数据集的shuffle设置为False,将 `num_files` 设置为1。
  502. 2. 在调用函数之前,不要使用batch算子、repeat算子或具有随机属性的数据增强的map算子。
  503. 3. 当数据的维度可变时,只支持1维数组或者在0维变化的多维数组。
  504. 4. 不支持DE_UINT64类型、多维的DE_UINT8类型、多维DE_STRING类型。
  505. **参数:**
  506. - **file_name** (str) - 数据集文件的路径。
  507. - **num_files** (int, optional) - 数据集文件的数量(默认为1)。
  508. - **file_type** (str, optional) - 数据集格式(默认为'mindrecord')。
  509. .. py:method:: set_dynamic_columns(columns=None)
  510. 设置源数据的动态shape信息,需要在定义数据处理流水线后设置。
  511. **参数:**
  512. - **columns** (dict) - 包含数据集中每列shape信息的字典。shape[i]为 `None` 表示shape[i]的数据长度是动态的。
  513. .. py:method:: shuffle(buffer_size)
  514. 使用以下策略随机打乱此数据集的行:
  515. 1. 生成一个shuffle缓冲区包含buffer_size条数据行。
  516. 2. 从shuffle缓冲区中随机选择一个元素,作为下一行传播到子节点。
  517. 3. 从父节点获取下一行(如果有的话),并将其放入shuffle缓冲区中。
  518. 4. 重复步骤2和3,直到打乱缓冲区中没有数据行为止。
  519. 可以提供随机种子,在第一个epoch中使用。在随后的每个epoch,种子都会被设置成一个新产生的随机值。
  520. **参数:**
  521. - **buffer_size** (int) - 用于shuffle的缓冲区大小(必须大于1)。将buffer_size设置为等于数据集大小将导致在全局shuffle。
  522. **返回:**
  523. ShuffleDataset,打乱后的数据集对象。
  524. **异常:**
  525. - **RuntimeError** - 打乱前存在同步操作。
  526. **样例:**
  527. >>> # dataset是Dataset对象的实例
  528. >>> # 可以选择设置第一个epoch的种子
  529. >>> ds.config.set_seed(58)
  530. >>> # 使用大小为4的shuffle缓冲区创建打乱后的数据集。
  531. >>> dataset = dataset.shuffle(4)
  532. .. py:method:: skip(count)
  533. 跳过此数据集的前N个元素。
  534. **参数:**
  535. - **count** (int) - 要跳过的数据集中的元素个数。
  536. **返回:**
  537. SkipDataset,减去跳过的行的数据集对象。
  538. **样例:**
  539. >>> # dataset是Dataset对象的实例
  540. >>> # 创建一个数据集,跳过前3个元素
  541. >>> dataset = dataset.skip(3)
  542. .. py:method:: split(sizes, randomize=True)
  543. 将数据集拆分为多个不重叠的数据集。
  544. 这是一个通用拆分函数,可以被数据处理流水线中的任何算子调用。
  545. 还有如果直接调用ds.split,其中 ds 是一个 MappableDataset,它将被自动调用。
  546. **参数:**
  547. - **sizes** (Union[list[int], list[float]]) - 如果指定了一列整数[s1, s2, …, sn],数据集将被拆分为n个大小为s1、s2、...、sn的数据集。如果所有输入大小的总和不等于原始数据集大小,则报错。如果指定了一列浮点数[f1, f2, …, fn],则所有浮点数必须介于0和1之间,并且总和必须为1,否则报错。数据集将被拆分为n个大小为round(f1*K)、round(f2*K)、...、round(fn*K)的数据集,其中K是原始数据集的大小。
  548. 如果舍入后:
  549. - 任何大小等于0,都将发生错误。
  550. - 如果拆分大小的总和<K,K - sigma(round(fi * k))的差值将添加到第一个子数据集。
  551. - 如果拆分大小的总和>K,sigma(round(fi * K)) - K的差值将从第一个足够大的拆分子集中删除,删除差值后至少有1行。
  552. - **randomize** (bool, optional) - 确定是否随机拆分数据(默认为True)。如果为True,则数据集将被随机拆分。否则,将使用数据集中的连续行创建每个拆分子集。
  553. .. note::
  554. 1. 如果要调用 split,则无法对数据集进行分片。
  555. 2. 强烈建议不要对数据集进行打乱,而是使用随机化(randomize=True)。对数据集进行打乱的结果具有不确定性,每个拆分子集中的数据在每个epoch可能都不同。
  556. **异常:**
  557. - **RuntimeError** - get_dataset_size返回None或此数据集不支持。
  558. - **RuntimeError** - sizes是整数列表,并且size中所有元素的总和不等于数据集大小。
  559. - **RuntimeError** - sizes是float列表,并且计算后存在大小为0的拆分子数据集。
  560. - **RuntimeError** - 数据集在调用拆分之前已进行分片。
  561. - **ValueError** - sizes是float列表,且并非所有float数都在0和1之间,或者float数的总和不等于1。
  562. **返回:**
  563. tuple(Dataset),拆分后子数据集对象的元组。
  564. **样例:**
  565. >>> # TextFileDataset不是可映射dataset,因此将调用通用拆分函数。
  566. >>> # 由于许多数据集默认都打开了shuffle,如需调用拆分函数,请将shuffle设置为False。
  567. >>> dataset = ds.TextFileDataset(text_file_dataset_dir, shuffle=False)
  568. >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
  569. .. py:method:: sync_update(condition_name, num_batch=None, data=None)
  570. 释放阻塞条件并使用给定数据触发回调函数。
  571. **参数:**
  572. - **condition_name** (str) - 用于切换发送下一行数据的条件名称。
  573. - **num_batch** (Union[int, None]) - 释放的batch(row)数。当 `num_batch` 为None时,将默认为 `sync_wait` 算子指定的值(默认为None)。
  574. - **data** (Any) - 用户自定义传递给回调函数的数据(默认为None)。
  575. .. py:method:: sync_wait(condition_name, num_batch=1, callback=None)
  576. 向输入数据集添加阻塞条件。 将应用同步操作。
  577. **参数:**
  578. - **condition_name** (str) - 用于切换发送下一行的条件名称。
  579. - **num_batch** (int) - 每个epoch开始时无阻塞的batch数。
  580. - **callback** (function) - `sync_update` 中将调用的回调函数。
  581. **返回:**
  582. SyncWaitDataset,添加了阻塞条件的数据集对象。
  583. **异常:**
  584. - **RuntimeError** - 条件名称已存在。
  585. **样例:**
  586. >>> import numpy as np
  587. >>> def gen():
  588. ... for i in range(100):
  589. ... yield (np.array(i),)
  590. >>>
  591. >>> class Augment:
  592. ... def __init__(self, loss):
  593. ... self.loss = loss
  594. ...
  595. ... def preprocess(self, input_):
  596. ... return input_
  597. ...
  598. ... def update(self, data):
  599. ... self.loss = data["loss"]
  600. >>>
  601. >>> batch_size = 4
  602. >>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
  603. >>>
  604. >>> aug = Augment(0)
  605. >>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
  606. >>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
  607. >>> dataset = dataset.batch(batch_size)
  608. >>> count = 0
  609. >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  610. ... assert data["input"][0] == count
  611. ... count += batch_size
  612. ... data = {"loss": count}
  613. ... dataset.sync_update(condition_name="policy", data=data)
  614. .. py:method:: take(count=-1)
  615. 从数据集中获取最多给定数量的元素。
  616. .. note::
  617. 1. 如果count大于数据集中的元素数或等于-1,则取数据集中的所有元素。
  618. 2. take和batch操作顺序很重要,如果take在batch操作之前,则取给定行数;否则取给定batch数。
  619. **参数:**
  620. - **count** (int, optional) - 要从数据集中获取的元素数(默认为-1)。
  621. **返回:**
  622. TakeDataset,取出指定数目的数据集对象。
  623. **样例:**
  624. >>> # dataset是Dataset对象的实例。
  625. >>> # 创建一个数据集,包含50个元素。
  626. >>> dataset = dataset.take(50)
  627. .. py:method:: to_device(send_epoch_end=True, create_data_info_queue=False)
  628. 将数据从CPU传输到GPU、Ascend或其他设备。
  629. **参数:**
  630. - **send_epoch_end** (bool, optional) - 是否将end of sequence发送到设备(默认为True)。
  631. - **create_data_info_queue** (bool, optional) - 是否创建存储数据类型和shape的队列(默认值为False)。
  632. .. note::
  633. 如果设备为Ascend,则逐个传输数据。每次传输的数据最大限制为256M。
  634. **返回:**
  635. TransferDataset,用于传输的数据集对象。
  636. **异常:**
  637. - **RuntimeError** - 如果提供了分布式训练的文件路径但读取失败。
  638. .. py:method:: to_json(filename='')
  639. 将数据处理流水线序列化为JSON字符串,如果提供了文件名,则转储到文件中。
  640. **参数:**
  641. - **filename** (str) - 另存为JSON格式的文件名。
  642. **返回:**
  643. str,流水线的JSON字符串。