You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.Dataset.rst 32 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711
  1. ..py:method:: build_vocab(columns, freq_range, top_k, special_tokens, special_first)
  2. 基于数据集对象创建词汇表。
  3. 用于收集数据集中所有的唯一单词,并返回 `top_k` 个最常见的单词组成的词汇表(如果指定了 `top_k` )。
  4. **参数:**
  5. **columns** (Union[str, list[str]]) :指定从数据集对象中哪一列中获取单词。
  6. **freq_range** (tuple[int]):由(min_frequency, max_frequency)组成的整数元组,在这个频率范围的词汇会被保存下来。
  7. 取值范围需满足:0 <= min_frequency <= max_frequency <= total_words,其中min_frequency、max_frequency的默认值分别设置为0、total_words。
  8. **top_k** (int):词汇表中包含的单词数,取`top_k`个最常见的单词。`top_k`优先级低于`freq_range`。如果`top_k`的值大于单词总数,则取所有单词。
  9. **special_tokens** (list[str]):字符串列表,每个字符串都是一个特殊的标记。
  10. **special_first** (bool):是否将 `special_tokens` 添加到词汇表首尾。如果指定了 `special_tokens` 且
  11. `special_first` 设置为默认值,则将`special_tokens`添加到词汇表最前面。
  12. **返回:**
  13. 从数据集对象中构建出的词汇表对象。
  14. **样例:**
  15. >>> def gen_corpus():
  16. ... # 键:单词,值:出现次数,键的取值采用字母表示有利于排序和显示。
  17. ... corpus = {"Z": 4, "Y": 4, "X": 4, "W": 3, "U": 3, "V": 2, "T": 1}
  18. ... for k, v in corpus.items():
  19. ... yield (np.array([k] * v, dtype='S'),)
  20. >>> column_names = ["column1", "column2", "column3"]
  21. >>> dataset = ds.GeneratorDataset(gen_corpus, column_names)
  22. >>> dataset = dataset.build_vocab(columns=["column3", "column1", "column2"],
  23. ... freq_range=(1, 10), top_k=5,
  24. ... special_tokens=["<pad>", "<unk>"],
  25. ... special_first=True,vocab='vocab')
  26. ..py:method:: device_que(send_epoch_end=True, create_data_info_queue=False)
  27. 返回一个能将数据传输到设备上的数据集对象。
  28. **参数:**
  29. **send_epoch_end** (bool, optional):数据发送完成后是否发送结束标识到设备上(默认值为True)。
  30. **create_data_info_queue** (bool, optional):是否创建一个队列,用于存储每条数据的type和shape(默认值为False)。
  31. .. note::
  32. 如果设备类型为Ascend,数据的每一列将被依次单独传输,每次传输的数据大小限制为256M。
  33. **返回:**
  34. TransferDataset,用于帮助发送数据到设备上的数据集对象。
  35. ..py:method:: dynamic_min_max_shapes()
  36. 获取数据集对象中单条数据的最小和最大shape,用于图编译过程。
  37. **返回:**
  38. 列表,原始数据集对象中单条数据的最小和最大shape分别以list形式返回。
  39. **样例:**
  40. >>> import numpy as np
  41. >>>
  42. >>> def generator1():
  43. >>> for i in range(1, 100):
  44. >>> yield np.ones((16, i, 83)), np.array(i)
  45. >>>
  46. >>> dataset = ds.GeneratorDataset(generator1, ["data1", "data2"])
  47. >>> dataset.set_dynamic_columns(columns={"data1": [16, None, 83], "data2": []})
  48. >>> min_shapes, max_shapes = dataset.dynamic_min_max_shapes()
  49. ..py:method:: filter(predicate, input_columns=None, num_parallel_workers=None)
  50. 通过判断条件对数据集对象中的数据进行过滤。
  51. .. note::
  52. 如果`input_columns`未指定或为空,则将使用所有列。
  53. **参数:**
  54. **predicate** (callable):Python可调用对象,返回值为Bool类型。如果为False,则过滤掉该条数据。
  55. **input_columns** (Union[str, list[str]], optional):输入列名组成的列表,当取默认值None时,`predicate` 将应用于数据集中的所有列。
  56. **num_parallel_workers** (int, optional):用于并行处理数据集的线程数(默认为None,将使用配置文件中的值)。
  57. **返回:**
  58. FilterDataset,执行给定筛选过滤操作的数据集对象。
  59. **样例:**
  60. >>> # 生成一个list,其取值范围为(0,63)
  61. >>> # 过滤掉数值大于或等于11的数据
  62. >>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"])
  63. ..py:method:: flat_map(func)
  64. 对数据集对象中每一条数据执行给定的`func`操作,并将结果展平。
  65. 指定的`func`是一个函数,输入必须为一个'ndarray',返回值是一个'Dataset'对象。
  66. **参数:**
  67. **func** (function):输入'ndarray'并返回一个'Dataset'对象的函数。
  68. **返回:**
  69. 执行给定操作的数据集对象。
  70. **样例:**
  71. >>> # 以NumpySlicesDataset为例
  72. >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
  73. >>>
  74. >>> def flat_map_func(array):
  75. ... # 使用数组创建NumpySlicesDataset
  76. ... dataset = ds.NumpySlicesDataset(array)
  77. ... # 将数据集对象中的数据重复两次
  78. ... dataset = dataset.repeat(2)
  79. ... return dataset
  80. >>>
  81. >>> dataset = dataset.flat_map(flat_map_func)
  82. >>> # [[0, 1], [0, 1], [2, 3], [2, 3]]
  83. **异常:**
  84. **TypeError** - `func` 不是函数。
  85. **TypeError** - `func` 的返回值不是数据集对象。
  86. ..py:method:: get_batch_size()
  87. 获得批处理的大小,即一个批次中包含的数据条数。
  88. **返回:**
  89. int,一个批次中包含的数据条数。
  90. **样例:**
  91. >> # dataset是数据集类的实例化对象
  92. >> batch_size = dataset.get_batch_size()
  93. ..py:method:: get_class_indexing()
  94. 返回类别索引。
  95. **返回:**
  96. dict,描述类别名称到索引的键值对映射关系,通常为str-to-int格式。针对COCO数据集,类别名称到索引映射关系描述形式为str-to-list<int>格式,列表中的第二个数字表示超级类别。
  97. **样例:**
  98. >> # dataset是数据集类的实例化对象
  99. >> class_indexing = dataset.get_class_indexing()
  100. ..py:method:: get_col_names()
  101. 返回数据集对象中包含的列名。
  102. **返回:**
  103. list,数据集中所有列名组成列表。
  104. **样例:**
  105. >> # dataset是数据集类的实例化对象
  106. >> col_names = dataset.get_col_names()
  107. .. py:method:: get_dataset_size()
  108. 返回一个epoch中的batch数。
  109. **返回:**
  110. int,batch的数目。
  111. .. py:method:: get_repeat_count()
  112. 获取 `RepeatDataset` 中的repeat次数(默认为1)。
  113. **返回:**
  114. int,repeat次数。
  115. .. py:method:: input_indexs
  116. :property:
  117. 获取input index信息。
  118. **返回:**
  119. input index信息的元组。
  120. **样例:**
  121. >>> # dataset是Dataset对象的实例
  122. >>> # 设置input_indexs
  123. >>> dataset.input_indexs = 10
  124. >>> print(dataset.input_indexs)
  125. 10
  126. .. py:method:: map(operations, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None)
  127. 将operations列表中的每个operation作用于数据集。
  128. 作用的顺序由每个operation在operations参数中的位置决定。
  129. 将首先作用operation[0],然后operation[1],operation[2],以此类推。
  130. 每个operation将数据集中的一列或多列作为输入,并将输出零列或多列。
  131. 第一个operation将 `input_columns` 中指定的列作为输入。
  132. 如果operations列表中存在多个operation,则上一个operation的输出列将用作下一个operation的输入列。
  133. 最后一个operation输出列的列名由 `output_columns` 指定。
  134. 只有在 `column_order` 中指定的列才会传播到子节点,并且列的顺序将与 `column_order` 中指定的顺序相同。
  135. **参数:**
  136. - **operations** (Union[list[TensorOp], list[functions]]) - 要作用于数据集的operations列表。将按operations列表中显示的顺序作用在数据集。
  137. - **input_columns** (Union[str, list[str]], optional) - 第一个operation输入的列名列表。此列表的大小必须与第一个operation预期的输入列数相匹配。(默认为None,从第一列开始,无论多少列,都将传递给第一个operation)。
  138. - **output_columns** (Union[str, list[str]], optional) - 最后一个operation输出的列名列表。如果 `input_columns` 长度不等于 `output_columns` 长度,则此参数必选。此列表的大小必须与最后一个operation的输出列数相匹配(默认为None,输出列将与输入列具有相同的名称,例如,替换一些列)。
  139. - **column_order** (list[str], optional) - 指定整个数据集中所需的所有列的列表。当 `input_columns` 长度不等于 `output_columns` 长度时,则此参数必选。注意:这里的列表不仅仅是参数 `input_columns` 和 `output_columns` 中指定的列。
  140. - **num_parallel_workers** (int, optional) - 用于并行处理数据集的线程数(默认为None,将使用配置文件中的值)。
  141. - **python_multiprocessing** (bool, optional) - 将Python operations委托给多个工作进程进行并行处理。如果Python operations计算量很大,此选项可能会很有用(默认值为False)。
  142. - **cache** (DatasetCache, optional) - 使用Tensor缓存服务加快数据集处理速度(默认为None,即不使用缓存)。
  143. - **callbacks** (DSCallback, list[DSCallback], optional) - 要调用的Dataset回调函数列表(默认为None)。
  144. **返回:**
  145. MapDataset,map操作后的数据集。
  146. **样例:**
  147. >>> # dataset是Dataset的一个实例,它有2列,"image"和"label"。
  148. >>>
  149. >>> # 定义两个operation,每个operation接受1列输入,输出1列。
  150. >>> decode_op = c_vision.Decode(rgb=True)
  151. >>> random_jitter_op = c_vision.RandomColorAdjust(brightness=(0.8, 0.8), contrast=(1, 1),
  152. ... saturation=(1, 1), hue=(0, 0))
  153. >>>
  154. >>> # 1)简单的map示例。
  155. >>>
  156. >>> # 在列“image"上应用decode_op。此列将被
  157. >>> # decode_op的输出列替换。由于未指定column_order,因此两列“image"
  158. >>> # 和“label"将按其原始顺序传播到下一个节点。
  159. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"])
  160. >>>
  161. >>> # 解码列“image"并将其重命名为“decoded_image"。
  162. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"], output_columns=["decoded_image"])
  163. >>>
  164. >>> # 指定输出列的顺序。
  165. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  166. ... output_columns=None, column_order=["label", "image"])
  167. >>>
  168. >>> # 将列“image"重命名为“decoded_image",并指定输出列的顺序。
  169. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  170. ... output_columns=["decoded_image"], column_order=["label", "decoded_image"])
  171. >>>
  172. >>> # 将列“image"重命名为“decoded_image",并只保留此列。
  173. >>> dataset = dataset.map(operations=[decode_op], input_columns=["image"],
  174. ... output_columns=["decoded_image"], column_order=["decoded_image"])
  175. >>>
  176. >>> # 使用用户自定义Python函数的map简单示例。列重命名和指定列顺序
  177. >>> # 的方式同前面的示例相同。
  178. >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
  179. >>> dataset = dataset.map(operations=[(lambda x: x + 1)], input_columns=["data"])
  180. >>>
  181. >>> # 2)多个operation的map示例。
  182. >>>
  183. >>> # 创建一个数据集,图像被解码,并随机颜色抖动。
  184. >>> # decode_op以列“image"作为输入,并输出一列。将
  185. >>> # decode_op输出的列作为输入传递给random_jitter_op。
  186. >>> # random_jitter_op将输出一列。列“image"将替换为
  187. >>> # random_jitter_op(最后一个operation)输出的列。所有其他
  188. >>> # 列保持不变。由于未指定column_order,因此
  189. >>> # 列的顺序将保持不变。
  190. >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"])
  191. >>>
  192. >>> # 将random_jitter_op输出的列重命名为“image_mapped"。
  193. >>> # 指定列顺序的方式与1中的示例相同。
  194. >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"],
  195. ... output_columns=["image_mapped"])
  196. >>>
  197. >>> # 使用用户自定义Python函数的多个operation的map示例。列重命名和指定列顺序
  198. >>> # 的方式与1中的示例相同。
  199. >>> dataset = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
  200. >>> dataset = dataset.map(operations=[(lambda x: x * x), (lambda x: x - 1)], input_columns=["data"],
  201. ... output_columns=["data_mapped"])
  202. >>>
  203. >>> # 3)输入列数不等于输出列数的示例。
  204. >>>
  205. >>> # operation[0] 是一个 lambda,它以 2 列作为输入并输出 3 列。
  206. >>> # operations[1] 是一个 lambda,它以 3 列作为输入并输出 1 列。
  207. >>> # operations[2] 是一个 lambda,它以 1 列作为输入并输出 4 列。
  208. >>> #
  209. >>> # 注:operation[i]的输出列数必须等于
  210. >>> # operation[i+1]的输入列。否则,map算子会
  211. >>> # 出错。
  212. >>> operations = [(lambda x, y: (x, x + y, x + y + 1)),
  213. ... (lambda x, y, z: x * y * z),
  214. ... (lambda x: (x % 2, x % 3, x % 5, x % 7))]
  215. >>>
  216. >>> # 注:由于输入列数与
  217. >>> # 输出列数不相同,必须指定output_columns和column_order
  218. >>> # 参数。否则,此map算子也会出错。
  219. >>>
  220. >>> dataset = ds.NumpySlicesDataset(data=([[0, 1, 2]], [[3, 4, 5]]), column_names=["x", "y"])
  221. >>>
  222. >>> # 按以下顺序将所有列传播到子节点:
  223. >>> dataset = dataset.map(operations, input_columns=["x", "y"],
  224. ... output_columns=["mod2", "mod3", "mod5", "mod7"],
  225. ... column_order=["mod2", "mod3", "mod5", "mod7"])
  226. >>>
  227. >>> # 按以下顺序将某些列传播到子节点:
  228. >>> dataset = dataset.map(operations, input_columns=["x", "y"],
  229. ... output_columns=["mod2", "mod3", "mod5", "mod7"],
  230. ... column_order=["mod7", "mod3", "col2"])
  231. .. py:method:: num_classes()
  232. 获取数据集中的样本的class数目。
  233. **返回:**
  234. int,class数目。
  235. .. py:method:: output_shapes()
  236. 获取输出数据的shape。
  237. **返回:**
  238. list,每列shape的列表。
  239. .. py:method:: output_types()
  240. 获取输出数据类型。
  241. **返回:**
  242. list,每列类型的列表。
  243. .. py:method:: project(columns)
  244. 在输入数据集上投影某些列。
  245. 从数据集中选择列,并以指定的顺序传输到流水线中。
  246. 其他列将被丢弃。
  247. **参数:**
  248. **columns** (Union[str, list[str]]) - 要投影列的列名列表。
  249. **返回:**
  250. ProjectDataset,投影后的数据集对象。
  251. **样例:**
  252. >>> # dataset是Dataset对象的实例
  253. >>> columns_to_project = ["column3", "column1", "column2"]
  254. >>>
  255. >>> # 创建一个数据集,无论列的原始顺序如何,依次包含column3, column1, column2。
  256. >>> dataset = dataset.project(columns=columns_to_project)
  257. .. py:method:: rename(input_columns, output_columns)
  258. 重命名输入数据集中的列。
  259. **参数:**
  260. - **input_columns** (Union[str, list[str]]) - 输入列的列名列表。
  261. - **output_columns** (Union[str, list[str]]) - 输出列的列名列表。
  262. **返回:**
  263. RenameDataset,重命名后数据集对象。
  264. **样例:**
  265. >>> # dataset是Dataset对象的实例
  266. >>> input_columns = ["input_col1", "input_col2", "input_col3"]
  267. >>> output_columns = ["output_col1", "output_col2", "output_col3"]
  268. >>>
  269. >>> # 创建一个数据集,其中input_col1重命名为output_col1,
  270. >>> # input_col2重命名为output_col2,input_col3重命名
  271. >>> # 为output_col3。
  272. >>> dataset = dataset.rename(input_columns=input_columns, output_columns=output_columns)
  273. .. py:method:: repeat(count=None)
  274. 重复此数据集 `count` 次。如果count为None或-1,则无限重复。
  275. .. note::
  276. repeat和batch的顺序反映了batch的数量。建议:repeat操作在batch操作之后使用。
  277. **参数:**
  278. **count** (int) - 数据集重复的次数(默认为None)。
  279. **返回:**
  280. RepeatDataset,重复操作后的数据集对象。
  281. **样例:**
  282. >>> # dataset是Dataset对象的实例
  283. >>>
  284. >>> # 创建一个数据集,数据集重复50个epoch。
  285. >>> dataset = dataset.repeat(50)
  286. >>>
  287. >>> # 创建一个数据集,其中每个epoch都是单独打乱的。
  288. >>> dataset = dataset.shuffle(10)
  289. >>> dataset = dataset.repeat(50)
  290. >>>
  291. >>> # 创建一个数据集,打乱前先将数据集重复
  292. >>> # 50个epoch。shuffle算子将
  293. >>> # 整个50个epoch视作一个大数据集。
  294. >>> dataset = dataset.repeat(50)
  295. >>> dataset = dataset.shuffle(10)
  296. ..py:method:: reset()
  297. 重置下一个epoch的数据集。
  298. ..py:method:: save(file_name, num_files=1, file_type='mindrecord')
  299. 将流水线正在处理的数据保存为通用的数据集格式。支持的数据集格式:'mindrecord'。
  300. 将数据保存为'mindrecord'格式时存在隐式类型转换。转换表展示如何执行类型转换。
  301. .. list-table:: 保存为'mindrecord'格式时的隐式类型转换
  302. :widths: 25 25 50
  303. :header-rows: 1
  304. * - 'dataset'类型
  305. - 'mindrecord'类型
  306. - 详细
  307. * - bool
  308. - None
  309. - 不支持
  310. * - int8
  311. - int32
  312. -
  313. * - uint8
  314. - bytes(1D uint8)
  315. - Drop dimension
  316. * - int16
  317. - int32
  318. -
  319. * - uint16
  320. - int32
  321. -
  322. * - int32
  323. - int32
  324. -
  325. * - uint32
  326. - int64
  327. -
  328. * - int64
  329. - int64
  330. -
  331. * - uint64
  332. - None
  333. - 不支持
  334. * - float16
  335. - float32
  336. -
  337. * - float32
  338. - float32
  339. -
  340. * - float64
  341. - float64
  342. -
  343. * - string
  344. - string
  345. - 不支持多维字符串
  346. .. note::
  347. 1. 如需按顺序保存示例,请将数据集的shuffle设置为False,将 `num_files` 设置为1。
  348. 2. 在调用函数之前,不要使用batch算子、repeat算子或具有随机属性的数据增强的map算子。
  349. 3. 当数据的维度可变时,只支持1维数组或者在0维变化的多维数组。
  350. 4. 不支持DE_UINT64类型、多维的DE_UINT8类型、多维DE_STRING类型。
  351. **参数:**
  352. - **file_name** (str) - 数据集文件的路径。
  353. - **num_files** (int, optional) - 数据集文件的数量(默认为1)。
  354. - **file_type** (str, optional) - 数据集格式(默认为'mindrecord')。
  355. ..py:method:: set_dynamic_columns(columns=None)
  356. 设置源数据的动态shape信息,需要在定义数据处理流水线后设置。
  357. **参数:**
  358. **columns** (dict) - 包含数据集中每列shape信息的字典。shape[i]为 `None` 表示shape[i]的数据长度是动态的。
  359. ..py:method:: shuffle(buffer_size)
  360. 使用以下策略随机打乱此数据集的行:
  361. 1. 生成一个shuffle缓冲区包含buffer_size条数据行。
  362. 2. 从shuffle缓冲区中随机选择一个元素,作为下一行传播到子节点。
  363. 3. 从父节点获取下一行(如果有的话),并将其放入shuffle缓冲区中。
  364. 4. 重复步骤2和3,直到打乱缓冲区中没有数据行为止。
  365. 可以提供随机种子,在第一个epoch中使用。在随后的每个epoch,种子都会被设置成一个新产生的随机值。
  366. **参数:**
  367. **buffer_size** (int) - 用于shuffle的缓冲区大小(必须大于1)。将buffer_size设置为等于数据集大小将导致在全局shuffle。
  368. **返回:**
  369. ShuffleDataset,打乱后的数据集对象。
  370. **异常:**
  371. **RuntimeError** - 打乱前存在同步操作。
  372. **样例:**
  373. >>> # dataset是Dataset对象的实例
  374. >>> # 可以选择设置第一个epoch的种子
  375. >>> ds.config.set_seed(58)
  376. >>> # 使用大小为4的shuffle缓冲区创建打乱后的数据集。
  377. >>> dataset = dataset.shuffle(4)
  378. ..py:method:: skip(count)
  379. 跳过此数据集的前N个元素。
  380. **参数:**
  381. **count** (int) - 要跳过的数据集中的元素个数。
  382. **返回:**
  383. SkipDataset,减去跳过的行的数据集对象。
  384. **样例:**
  385. >>> # dataset是Dataset对象的实例
  386. >>> # 创建一个数据集,跳过前3个元素
  387. >>> dataset = dataset.skip(3)
  388. ..py:method:: split(sizes, randomize=True)
  389. 将数据集拆分为多个不重叠的数据集。
  390. 这是一个通用拆分函数,可以被数据处理流水线中的任何算子调用。
  391. 还有如果直接调用ds.split,其中 ds 是一个 MappableDataset,它将被自动调用。
  392. **参数:**
  393. - **sizes** (Union[list[int], list[float]]) - 如果指定了一列整数[s1, s2, …, sn],数据集将被拆分为n个大小为s1、s2、...、sn的数据集。如果所有输入大小的总和不等于原始数据集大小,则报错。如果指定了一列浮点数[f1, f2, …, fn],则所有浮点数必须介于0和1之间,并且总和必须为1,否则报错。数据集将被拆分为n个大小为round(f1*K)、round(f2*K)、...、round(fn*K)的数据集,其中K是原始数据集的大小。
  394. 如果舍入后:
  395. - 任何大小等于0,都将发生错误。
  396. - 如果拆分大小的总和<K,K - sigma(round(fi * k))的差值将添加到第一个子数据集。
  397. - 如果拆分大小的总和>K,sigma(round(fi * K)) - K的差值将从第一个足够大的拆分子集中删除,删除差值后至少有1行。
  398. - **randomize** (bool, optional):确定是否随机拆分数据(默认为True)。如果为True,则数据集将被随机拆分。否则,将使用数据集中的连续行创建每个拆分子集。
  399. .. note::
  400. 1. 如果要调用 split,则无法对数据集进行分片。
  401. 2. 强烈建议不要对数据集进行打乱,而是使用随机化(randomize=True)。对数据集进行打乱的结果具有不确定性,每个拆分子集中的数据在每个epoch可能都不同。
  402. **异常:**
  403. - **RuntimeError** - get_dataset_size返回None或此数据集不支持。
  404. - **RuntimeError** - sizes是整数列表,并且size中所有元素的总和不等于数据集大小。
  405. - **RuntimeError** - sizes是float列表,并且计算后存在大小为0的拆分子数据集。
  406. - **RuntimeError** - 数据集在调用拆分之前已进行分片。
  407. - **ValueError** - sizes是float列表,且并非所有float数都在0和1之间,或者float数的总和不等于1。
  408. **返回:**
  409. tuple(Dataset),拆分后子数据集对象的元组。
  410. **样例:**
  411. >>> # TextFileDataset不是可映射dataset,因此将调用通用拆分函数。
  412. >>> # 由于许多数据集默认都打开了shuffle,如需调用拆分函数,请将shuffle设置为False。
  413. >>> dataset = ds.TextFileDataset(text_file_dataset_dir, shuffle=False)
  414. >>> train_dataset, test_dataset = dataset.split([0.9, 0.1])
  415. ..py:method:: sync_update(condition_name, num_batch=None, data=None)
  416. 释放阻塞条件并使用给定数据触发回调函数。
  417. **参数:**
  418. - **condition_name** (str) - 用于切换发送下一行数据的条件名称。
  419. - **num_batch** (Union[int, None]) - 释放的batch(row)数。当 `num_batch` 为None时,将默认为 `sync_wait` 算子指定的值(默认为None)。
  420. - **data** (Any) - 用户自定义传递给回调函数的数据(默认为None)。
  421. ..py:method:: sync_wait(condition_name, num_batch=1, callback=None)
  422. 向输入数据集添加阻塞条件。 将应用同步操作。
  423. **参数:**
  424. - **condition_name** (str) - 用于切换发送下一行的条件名称。
  425. - **num_batch** (int) - 每个epoch开始时无阻塞的batch数。
  426. - **callback** (function) - `sync_update` 中将调用的回调函数。
  427. **返回:**
  428. SyncWaitDataset,添加了阻塞条件的数据集对象。
  429. **异常:**
  430. **RuntimeError** - 条件名称已存在。
  431. **样例:**
  432. >>> import numpy as np
  433. >>> def gen():
  434. ... for i in range(100):
  435. ... yield (np.array(i),)
  436. >>>
  437. >>> class Augment:
  438. ... def __init__(self, loss):
  439. ... self.loss = loss
  440. ...
  441. ... def preprocess(self, input_):
  442. ... return input_
  443. ...
  444. ... def update(self, data):
  445. ... self.loss = data["loss"]
  446. >>>
  447. >>> batch_size = 4
  448. >>> dataset = ds.GeneratorDataset(gen, column_names=["input"])
  449. >>>
  450. >>> aug = Augment(0)
  451. >>> dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
  452. >>> dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
  453. >>> dataset = dataset.batch(batch_size)
  454. >>> count = 0
  455. >>> for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  456. ... assert data["input"][0] == count
  457. ... count += batch_size
  458. ... data = {"loss": count}
  459. ... dataset.sync_update(condition_name="policy", data=data)
  460. ..py:method:: take(count=-1)
  461. 从数据集中获取最多给定数量的元素。
  462. .. note::
  463. 1. 如果count大于数据集中的元素数或等于-1,则取数据集中的所有元素。
  464. 2. take和batch操作顺序很重要,如果take在batch操作之前,则取给定行数;否则取给定batch数。
  465. **参数:**
  466. **count** (int, optional) - 要从数据集中获取的元素数(默认为-1)。
  467. **返回:**
  468. TakeDataset,取出指定数目的数据集对象。
  469. **样例:**
  470. >>> # dataset是Dataset对象的实例。
  471. >>> # 创建一个数据集,包含50个元素。
  472. >>> dataset = dataset.take(50)
  473. ..py:method:: to_device(send_epoch_end=True, create_data_info_queue=False)
  474. 将数据从CPU传输到GPU、Ascend或其他设备。
  475. **参数:**
  476. - **send_epoch_end** (bool, optional) - 是否将end of sequence发送到设备(默认为True)。
  477. - **create_data_info_queue** (bool, optional) - 是否创建存储数据类型和shape的队列(默认值为False)。
  478. .. note::
  479. 如果设备为Ascend,则逐个传输数据。每次传输的数据最大限制为256M。
  480. **返回:**
  481. TransferDataset,用于传输的数据集对象。
  482. **异常:**
  483. **RuntimeError** - 如果提供了分布式训练的文件路径但读取失败。
  484. ..py:method:: to_json(filename='')
  485. 将数据处理流水线序列化为JSON字符串,如果提供了文件名,则转储到文件中。
  486. **参数:**
  487. **filename** (str) - 另存为JSON格式的文件名。
  488. **返回:**
  489. str,流水线的JSON字符串。
  490. ..py:method:: zip(datasets)
  491. 将数据集和输入的数据集或者数据集元组按列进行合并压缩。输入数据集中的列名必须不同。
  492. **参数:**
  493. **datasets** (Union[tuple, class Dataset]) - 数据集对象的元组或单个数据集对象与当前数据集一起合并压缩。
  494. **返回:**
  495. ZipDataset,合并压缩后的数据集对象。
  496. **样例:**
  497. >>> # 创建一个数据集,它将dataset和dataset_1进行合并
  498. >>> dataset = dataset.zip(dataset_1)