
MatchingDataLoader.py 20 kB

  1. """
  2. 这个文件的内容已合并到fastNLP.io.data_loader里,这个文件的内容不再更新
  3. """
  4. import os
  5. from typing import Union, Dict
  6. from fastNLP.core.const import Const
  7. from fastNLP.core.vocabulary import Vocabulary
  8. from fastNLP.io.base_loader import DataBundle, DataSetLoader
  9. from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
  10. from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
  11. from fastNLP.modules.encoder._bert import BertTokenizer


class MatchingLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`

    Reads data sets for matching tasks.

    :param dict paths: keys are data set names (e.g. train, dev, test); values are the corresponding file names
    """

    def __init__(self, paths: dict = None):
        self.paths = paths

    def _load(self, path):
        """
        :param str path: path of the data set to read
        :return: fastNLP.DataSet ds: a DataSet object that must contain three fields: two hold the raw
            string text of the two sentences, and the third holds the label
        """
        raise NotImplementedError
    def process(self, paths: Union[str, Dict[str, str]], dataset_name: str = None,
                to_lower=False, seq_len_type: str = None, bert_tokenizer: str = None,
                cut_text: int = None, get_index=True, auto_pad_length: int = None,
                auto_pad_token: str = '<pad>', set_input: Union[list, str, bool] = True,
                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool] = None) -> DataBundle:
        """
        :param paths: str or Dict[str, str]. If a str, it is either the directory containing the data sets
            or a full file path: for a directory, the data set names and file names are looked up in
            self.paths; if a Dict, keys are data set names (e.g. train, dev, test) and values are full
            file paths.
        :param str dataset_name: if paths is a single full file path, dataset_name names that data set;
            defaults to train if not given.
        :param bool to_lower: whether to lowercase the text automatically. Defaults to False.
        :param str seq_len_type: the kind of seq_len information to provide. Supports ``seq_len``: a single
            number giving the sentence length; ``mask``: a 0/1 mask matrix standing in for sentence length;
            ``bert``: segment type ids (0 for the first sentence, 1 for the second) plus an attention mask
            (a 0/1 mask matrix). Defaults to None, i.e. no seq_len is provided.
        :param str bert_tokenizer: path of the directory holding the vocabulary used by the BERT tokenizer
        :param int cut_text: truncate content longer than cut_text. Defaults to None, i.e. no truncation.
        :param bool get_index: whether to convert text to indices using the vocabulary
        :param int auto_pad_length: pad the text to this length automatically (longer text is truncated);
            by default no automatic padding is done
        :param str auto_pad_token: the token used for automatic padding
        :param set_input: if True, fields whose names contain Const.INPUT are automatically set as input;
            if False, no field is set as input. If a str or List[str] is given, the named fields are set
            as input and all other fields are not. Defaults to True.
        :param set_target: controls which fields may be set as target; used the same way as set_input.
            Defaults to True.
        :param concat: whether to concatenate the two sentences. If False, no concatenation. If True, a
            <sep> token is inserted between the two sentences. If a list of length 4 is given, its elements
            are the markers placed before the first sentence, after the first sentence, before the second
            sentence, and after the second sentence, respectively. If the string ``bert`` is given, the
            BERT concatenation scheme is used, equivalent to ['[CLS]', '[SEP]', '', '[SEP]'].
        :return:
        """
        if isinstance(set_input, str):
            set_input = [set_input]
        if isinstance(set_target, str):
            set_target = [set_target]
        if isinstance(set_input, bool):
            auto_set_input = set_input
        else:
            auto_set_input = False
        if isinstance(set_target, bool):
            auto_set_target = set_target
        else:
            auto_set_target = False

        if isinstance(paths, str):
            if os.path.isdir(paths):
                path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
            else:
                path = {dataset_name if dataset_name is not None else 'train': paths}
        else:
            path = paths

        data_info = DataBundle()
        for data_name in path.keys():
            data_info.datasets[data_name] = self._load(path[data_name])

        for data_name, data_set in data_info.datasets.items():
            if auto_set_input:
                data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
            if auto_set_target:
                if Const.TARGET in data_set.get_field_names():
                    data_set.set_target(Const.TARGET)

        if to_lower:
            for data_name, data_set in data_info.datasets.items():
                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
                               is_input=auto_set_input)
                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
                               is_input=auto_set_input)

        if bert_tokenizer is not None:
            if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
                PRETRAIN_URL = _get_base_url('bert')
                model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer.lower()]  # use the lowercased key checked above
                model_url = PRETRAIN_URL + model_name
                model_dir = cached_path(model_url)
                # check whether a local copy exists (cached_path downloads it if missing)
            elif os.path.isdir(bert_tokenizer):
                model_dir = bert_tokenizer
            else:
                raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")

            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
            with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
                lines = f.readlines()
            lines = [line.strip() for line in lines]
            words_vocab.add_word_lst(lines)
            words_vocab.build_vocab()
            tokenizer = BertTokenizer.from_pretrained(model_dir)
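            # Re-tokenize every input field with the BERT wordpiece tokenizer; the token
            # list is joined back into a string first because tokenize() expects raw text.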
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
                                       is_input=auto_set_input)
        if isinstance(concat, bool):
            concat = 'default' if concat else None
        if concat is not None:
            if isinstance(concat, str):
                CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
                              'default': ['', '<sep>', '', '']}
                if concat.lower() in CONCAT_MAP:
                    concat = CONCAT_MAP[concat.lower()]
                else:
                    concat = 4 * [concat]
            assert len(concat) == 4, \
                f'Please provide a list of 4 symbols: one for the beginning of the first sentence, ' \
                f'one for the end of the first sentence, one for the beginning of the second sentence, ' \
                f'and one for the end of the second sentence. Your input is {concat}'

            for data_name, data_set in data_info.datasets.items():
                data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
                               x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
                data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
                               is_input=auto_set_input)
        if seq_len_type is not None:
            if seq_len_type == 'seq_len':
                for data_name, data_set in data_info.datasets.items():
                    for fields in data_set.get_field_names():
                        if Const.INPUT in fields:
                            data_set.apply(lambda x: len(x[fields]),
                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
                                           is_input=auto_set_input)
            elif seq_len_type == 'mask':
                for data_name, data_set in data_info.datasets.items():
                    for fields in data_set.get_field_names():
                        if Const.INPUT in fields:
                            data_set.apply(lambda x: [1] * len(x[fields]),
                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
                                           is_input=auto_set_input)
            elif seq_len_type == 'bert':
                for data_name, data_set in data_info.datasets.items():
                    if Const.INPUT not in data_set.get_field_names():
                        raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
                                       f'got {data_set.get_field_names()}')
                    data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
                                   new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
                    data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
                                   new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
        if auto_pad_length is not None:
            cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)

        if cut_text is not None:
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
                        data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields,
                                       is_input=auto_set_input)

        data_set_list = [d for n, d in data_info.datasets.items()]
        assert len(data_set_list) > 0, 'There are NO data sets in data info!'
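        # Vocabularies are built from the training split(s) only; non-train splits are
        # passed as no_create_entry_dataset so that their words are indexed without
        # creating dedicated entries, and the target vocabulary has no padding/unknown.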
        if bert_tokenizer is None:
            words_vocab = Vocabulary(padding=auto_pad_token)
            words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                                   field_name=[n for n in data_set_list[0].get_field_names()
                                                               if (Const.INPUT in n)],
                                                   no_create_entry_dataset=[d for n, d in data_info.datasets.items()
                                                                            if 'train' not in n])
        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                                 field_name=Const.TARGET)
        data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
        if get_index:
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
                                       is_input=auto_set_input)

                if Const.TARGET in data_set.get_field_names():
                    data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
                                   is_input=auto_set_input, is_target=auto_set_target)

        if auto_pad_length is not None:
            if seq_len_type == 'seq_len':
                raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
                                   f'so the seq_len_type cannot be `{seq_len_type}`!')
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
                                       (auto_pad_length - len(x[fields])), new_field_name=fields,
                                       is_input=auto_set_input)
                    elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
                        data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
                                       new_field_name=fields, is_input=auto_set_input)

        for data_name, data_set in data_info.datasets.items():
            if isinstance(set_input, list):
                data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
            if isinstance(set_target, list):
                data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])

        return data_info


class SNLILoader(MatchingLoader, JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

    Reads the SNLI data set. The resulting DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self, paths: dict = None):
        fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        paths = paths if paths is not None else {
            'train': 'snli_1.0_train.jsonl',
            'dev': 'snli_1.0_dev.jsonl',
            'test': 'snli_1.0_test.jsonl'}
        MatchingLoader.__init__(self, paths=paths)
        JsonLoader.__init__(self, fields=fields)

    def _load(self, path):
        ds = JsonLoader._load(self, path)

        parentheses_table = str.maketrans({'(': None, ')': None})

        # strip the binary-parse parentheses and split into tokens
        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        ds.drop(lambda x: x[Const.TARGET] == '-')  # drop examples without a gold label
        return ds


class RTELoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`

    Reads the RTE data set. The resulting DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'sentence1': Const.INPUTS(0),
            'sentence2': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():  # check the original column name before renaming
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds


class QNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`

    Reads the QNLI data set. The resulting DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'question': Const.INPUTS(0),
            'sentence': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():  # check the original column name before renaming
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds


class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`

    Reads the MNLI data set. The resulting DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev_matched': 'dev_matched.tsv',
            'dev_mismatched': 'dev_mismatched.tsv',
            'test_matched': 'test_matched.tsv',
            'test_mismatched': 'test_mismatched.tsv',
            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
            # test_0.9_matched and test_0.9_mismatched belong to MNLI version 0.9 (data source: Kaggle)
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)

        # test files carry a 'hidden' placeholder label; drop the field entirely
        if Const.TARGET in ds.get_field_names():
            if ds[0][Const.TARGET] == 'hidden':
                ds.delete_field(Const.TARGET)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        if Const.TARGET in ds.get_field_names():
            ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds


class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`

    Reads the Quora question-pair data set. The resulting DataSet contains the fields::

        words1: list(str), first sentence
        words2: list(str), second sentence
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv',
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))

    def _load(self, path):
        ds = CSVLoader._load(self, path)
        return ds
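

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the directory './data/snli' below
# is an assumption, not part of this module. process() returns a DataBundle
# whose .datasets and .vocabs dicts can be fed to downstream fastNLP models.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    loader = SNLILoader()
    data_info = loader.process('./data/snli',  # hypothetical directory holding snli_1.0_{train,dev,test}.jsonl
                               to_lower=True,
                               seq_len_type='seq_len',  # adds a seq_len* field next to each words* field
                               get_index=True,
                               concat=False)  # keep words1/words2 as separate fields
    print(data_info.datasets['train'][0])
    print(data_info.vocabs[Const.INPUT])

    # A BERT-style variant (assumes a local directory containing vocab.txt; illustrative only):
    # data_info = loader.process('./data/snli', bert_tokenizer='./bert-base-uncased',
    #                            seq_len_type='bert', concat='bert')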