You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

data_bundle.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. r"""
  2. .. todo::
  3. doc
  4. """
  5. __all__ = [
  6. 'DataBundle',
  7. ]
  8. from typing import Union, List, Callable
  9. from ..core.dataset import DataSet
  10. from fastNLP.core.vocabulary import Vocabulary
  11. from fastNLP.core import logger
  12. class DataBundle:
  13. r"""
  14. 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种
  15. Loader的load函数生成,可以通过以下的方法获取里面的内容
  16. Example::
  17. data_bundle = YelpLoader().load({'train':'/path/to/train', 'dev': '/path/to/dev'})
  18. train_vocabs = data_bundle.vocabs['train']
  19. train_data = data_bundle.datasets['train']
  20. dev_data = data_bundle.datasets['train']
  21. """
  22. def __init__(self, vocabs=None, datasets=None):
  23. r"""
  24. :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
  25. :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict。建议不要将相同的DataSet对象重复传入,可能会在
  26. 使用Pipe处理数据的时候遇到问题,若多个数据集确需一致,请手动deepcopy后传入。
  27. """
  28. self.vocabs = vocabs or {}
  29. self.datasets = datasets or {}
  30. def set_vocab(self, vocab: Vocabulary, field_name: str):
  31. r"""
  32. 向DataBunlde中增加vocab
  33. :param ~fastNLP.Vocabulary vocab: 词表
  34. :param str field_name: 这个vocab对应的field名称
  35. :return: self
  36. """
  37. assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports."
  38. self.vocabs[field_name] = vocab
  39. return self
  40. def set_dataset(self, dataset: DataSet, name: str):
  41. r"""
  42. :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet
  43. :param str name: dataset的名称
  44. :return: self
  45. """
  46. assert isinstance(dataset, DataSet), "Only fastNLP.DataSet supports."
  47. self.datasets[name] = dataset
  48. return self
  49. def get_dataset(self, name: str) -> DataSet:
  50. r"""
  51. 获取名为name的dataset
  52. :param str name: dataset的名称,一般为'train', 'dev', 'test'
  53. :return: DataSet
  54. """
  55. if name in self.datasets.keys():
  56. return self.datasets[name]
  57. else:
  58. error_msg = f'DataBundle do NOT have DataSet named {name}. ' \
  59. f'It should be one of {self.datasets.keys()}.'
  60. logger.error(error_msg)
  61. raise KeyError(error_msg)
  62. def delete_dataset(self, name: str):
  63. r"""
  64. 删除名为name的DataSet
  65. :param str name:
  66. :return: self
  67. """
  68. self.datasets.pop(name, None)
  69. return self
  70. def get_vocab(self, field_name: str) -> Vocabulary:
  71. r"""
  72. 获取field名为field_name对应的vocab
  73. :param str field_name: 名称
  74. :return: Vocabulary
  75. """
  76. if field_name in self.vocabs.keys():
  77. return self.vocabs[field_name]
  78. else:
  79. error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \
  80. f'It should be one of {self.vocabs.keys()}.'
  81. logger.error(error_msg)
  82. raise KeyError(error_msg)
  83. def delete_vocab(self, field_name: str):
  84. r"""
  85. 删除vocab
  86. :param str field_name:
  87. :return: self
  88. """
  89. self.vocabs.pop(field_name, None)
  90. return self
  91. @property
  92. def num_dataset(self):
  93. return len(self.datasets)
  94. @property
  95. def num_vocab(self):
  96. return len(self.vocabs)
  97. def copy_field(self, field_name: str, new_field_name: str, ignore_miss_dataset=True):
  98. r"""
  99. 将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name.
  100. :param str field_name:
  101. :param str new_field_name:
  102. :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
  103. 如果为False,则报错
  104. :return: self
  105. """
  106. for name, dataset in self.datasets.items():
  107. if dataset.has_field(field_name=field_name):
  108. dataset.copy_field(field_name=field_name, new_field_name=new_field_name)
  109. elif not ignore_miss_dataset:
  110. raise KeyError(f"{field_name} not found DataSet:{name}.")
  111. return self
  112. def rename_field(self, field_name: str, new_field_name: str, ignore_miss_dataset=True, rename_vocab=True):
  113. r"""
  114. 将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name.
  115. :param str field_name:
  116. :param str new_field_name:
  117. :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
  118. 如果为False,则报错
  119. :param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改
  120. :return: self
  121. """
  122. for name, dataset in self.datasets.items():
  123. if dataset.has_field(field_name=field_name):
  124. dataset.rename_field(field_name=field_name, new_field_name=new_field_name)
  125. elif not ignore_miss_dataset:
  126. raise KeyError(f"{field_name} not found DataSet:{name}.")
  127. if rename_vocab:
  128. if field_name in self.vocabs:
  129. self.vocabs[new_field_name] = self.vocabs.pop(field_name)
  130. return self
  131. def delete_field(self, field_name: str, ignore_miss_dataset=True, delete_vocab=True):
  132. r"""
  133. 将DataBundle中所有DataSet中名为field_name的field删除掉.
  134. :param str field_name:
  135. :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
  136. 如果为False,则报错
  137. :param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除
  138. :return: self
  139. """
  140. for name, dataset in self.datasets.items():
  141. if dataset.has_field(field_name=field_name):
  142. dataset.delete_field(field_name=field_name)
  143. elif not ignore_miss_dataset:
  144. raise KeyError(f"{field_name} not found DataSet:{name}.")
  145. if delete_vocab:
  146. if field_name in self.vocabs:
  147. self.vocabs.pop(field_name)
  148. return self
  149. def iter_datasets(self) -> Union[str, DataSet]:
  150. r"""
  151. 迭代data_bundle中的DataSet
  152. Example::
  153. for name, dataset in data_bundle.iter_datasets():
  154. pass
  155. :return:
  156. """
  157. for name, dataset in self.datasets.items():
  158. yield name, dataset
  159. def get_dataset_names(self) -> List[str]:
  160. r"""
  161. 返回DataBundle中DataSet的名称
  162. :return:
  163. """
  164. return list(self.datasets.keys())
  165. def get_vocab_names(self) -> List[str]:
  166. r"""
  167. 返回DataBundle中Vocabulary的名称
  168. :return:
  169. """
  170. return list(self.vocabs.keys())
  171. def iter_vocabs(self):
  172. r"""
  173. 迭代data_bundle中的DataSet
  174. Example:
  175. for field_name, vocab in data_bundle.iter_vocabs():
  176. pass
  177. :return:
  178. """
  179. for field_name, vocab in self.vocabs.items():
  180. yield field_name, vocab
  181. def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0,
  182. ignore_miss_dataset: bool = True, progress_desc: str = '', progress_bar: str = 'rich'):
  183. r"""
  184. 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法
  185. :param callable func: input是instance中名为 `field_name` 的field的内容。
  186. :param str field_name: 传入func的是哪个field。
  187. :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
  188. 盖之前的field。如果为None则不创建新的field。
  189. :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
  190. 如果为False,则报错
  191. :param num_proc: 使用进程的数量。
  192. .. note::
  193. 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
  194. ``func`` 函数中的打印将不会输出。
  195. :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。
  196. :param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称
  197. :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
  198. """
  199. _progress_desc = progress_desc
  200. for name, dataset in self.datasets.items():
  201. if _progress_desc:
  202. progress_desc = _progress_desc + f' for `{name}`'
  203. if dataset.has_field(field_name=field_name):
  204. dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc,
  205. progress_desc=progress_desc, progress_bar=progress_bar)
  206. elif not ignore_miss_dataset:
  207. raise KeyError(f"{field_name} not found DataSet:{name}.")
  208. return self
  209. def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
  210. ignore_miss_dataset=True, progress_bar: str = 'rich', progress_desc: str = ''):
  211. r"""
  212. 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法
  213. .. note::
  214. ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
  215. ``apply`` 区别的介绍。
  216. :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
  217. :param str field_name: 传入func的是哪个field。
  218. :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
  219. :param num_proc: 使用进程的数量。
  220. .. note::
  221. 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
  222. ``func`` 函数中的打印将不会输出。
  223. :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
  224. 如果为False,则报错
  225. :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
  226. :param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。
  227. :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
  228. """
  229. res = {}
  230. _progress_desc = progress_desc
  231. for name, dataset in self.datasets.items():
  232. if _progress_desc:
  233. progress_desc = _progress_desc + f' for `{name}`'
  234. if dataset.has_field(field_name=field_name):
  235. res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc,
  236. modify_fields=modify_fields,
  237. progress_bar=progress_bar, progress_desc=progress_desc)
  238. elif not ignore_miss_dataset:
  239. raise KeyError(f"{field_name} not found DataSet:{name} .")
  240. return res
  241. def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
  242. progress_desc: str = '', progress_bar: bool = True):
  243. r"""
  244. 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法
  245. 对DataBundle中所有的dataset使用apply方法
  246. :param callable func: input是instance中名为 `field_name` 的field的内容。
  247. :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
  248. 盖之前的field。如果为None则不创建新的field。
  249. :param num_proc: 使用进程的数量。
  250. .. note::
  251. 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
  252. ``func`` 函数中的打印将不会输出。
  253. :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
  254. :param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称
  255. """
  256. _progress_desc = progress_desc
  257. for name, dataset in self.datasets.items():
  258. if _progress_desc:
  259. progress_desc = _progress_desc + f' for `{name}`'
  260. dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar,
  261. progress_desc=progress_desc)
  262. return self
  263. def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0,
  264. progress_desc: str = '', progress_bar: str = 'rich'):
  265. r"""
  266. 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法
  267. .. note::
  268. ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
  269. ``apply`` 区别的介绍。
  270. :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
  271. :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
  272. :param num_proc: 使用进程的数量。
  273. .. note::
  274. 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
  275. ``func`` 函数中的打印将不会输出。
  276. :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
  277. :param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称
  278. :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
  279. """
  280. res = {}
  281. _progress_desc = progress_desc
  282. for name, dataset in self.datasets.items():
  283. if _progress_desc:
  284. progress_desc = _progress_desc + f' for `{name}`'
  285. res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc,
  286. progress_bar=progress_bar, progress_desc=progress_desc)
  287. return res
  288. def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle":
  289. """
  290. 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
  291. :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
  292. field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
  293. 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
  294. 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
  295. :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
  296. field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
  297. 无意义。
  298. :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
  299. :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
  300. torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
  301. :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
  302. batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
  303. 形式,输出将被直接作为结果输出。
  304. :return: self
  305. """
  306. for _, ds in self.iter_datasets():
  307. ds.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, backend=backend,
  308. pad_fn=pad_fn)
  309. return self
  310. def set_ignore(self, *field_names) -> "DataBundle":
  311. """
  312. 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
  313. Example::
  314. collator.set_ignore('field1', 'field2')
  315. :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
  316. field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果
  317. __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
  318. :return: self
  319. """
  320. for _, ds in self.iter_datasets():
  321. ds.collator.set_ignore(*field_names)
  322. return self
  323. def __repr__(self) -> str:
  324. _str = ''
  325. if len(self.datasets):
  326. _str += 'In total {} datasets:\n'.format(self.num_dataset)
  327. for name, dataset in self.datasets.items():
  328. _str += '\t{} has {} instances.\n'.format(name, len(dataset))
  329. if len(self.vocabs):
  330. _str += 'In total {} vocabs:\n'.format(self.num_vocab)
  331. for name, vocab in self.vocabs.items():
  332. _str += '\t{} has {} entries.\n'.format(name, len(vocab))
  333. return _str