You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

processor.py 17 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. import re
  2. from collections import defaultdict
  3. import torch
  4. from fastNLP.core.batch import Batch
  5. from fastNLP.core.dataset import DataSet
  6. from fastNLP.core.sampler import SequentialSampler
  7. from fastNLP.core.vocabulary import Vocabulary
  8. class Processor(object):
  9. def __init__(self, field_name, new_added_field_name):
  10. """
  11. :param field_name: 处理哪个field
  12. :param new_added_field_name: 如果为None,则认为是field_name,即覆盖原有的field
  13. """
  14. self.field_name = field_name
  15. if new_added_field_name is None:
  16. self.new_added_field_name = field_name
  17. else:
  18. self.new_added_field_name = new_added_field_name
  19. def process(self, *args, **kwargs):
  20. raise NotImplementedError
  21. def __call__(self, *args, **kwargs):
  22. return self.process(*args, **kwargs)
  23. class FullSpaceToHalfSpaceProcessor(Processor):
  24. """全角转半角,以字符为处理单元
  25. """
  26. def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
  27. change_space=True):
  28. super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
  29. self.change_alpha = change_alpha
  30. self.change_digit = change_digit
  31. self.change_punctuation = change_punctuation
  32. self.change_space = change_space
  33. FH_SPACE = [(u" ", u" ")]
  34. FH_NUM = [
  35. (u"0", u"0"), (u"1", u"1"), (u"2", u"2"), (u"3", u"3"), (u"4", u"4"),
  36. (u"5", u"5"), (u"6", u"6"), (u"7", u"7"), (u"8", u"8"), (u"9", u"9")]
  37. FH_ALPHA = [
  38. (u"a", u"a"), (u"b", u"b"), (u"c", u"c"), (u"d", u"d"), (u"e", u"e"),
  39. (u"f", u"f"), (u"g", u"g"), (u"h", u"h"), (u"i", u"i"), (u"j", u"j"),
  40. (u"k", u"k"), (u"l", u"l"), (u"m", u"m"), (u"n", u"n"), (u"o", u"o"),
  41. (u"p", u"p"), (u"q", u"q"), (u"r", u"r"), (u"s", u"s"), (u"t", u"t"),
  42. (u"u", u"u"), (u"v", u"v"), (u"w", u"w"), (u"x", u"x"), (u"y", u"y"),
  43. (u"z", u"z"),
  44. (u"A", u"A"), (u"B", u"B"), (u"C", u"C"), (u"D", u"D"), (u"E", u"E"),
  45. (u"F", u"F"), (u"G", u"G"), (u"H", u"H"), (u"I", u"I"), (u"J", u"J"),
  46. (u"K", u"K"), (u"L", u"L"), (u"M", u"M"), (u"N", u"N"), (u"O", u"O"),
  47. (u"P", u"P"), (u"Q", u"Q"), (u"R", u"R"), (u"S", u"S"), (u"T", u"T"),
  48. (u"U", u"U"), (u"V", u"V"), (u"W", u"W"), (u"X", u"X"), (u"Y", u"Y"),
  49. (u"Z", u"Z")]
  50. # 谨慎使用标点符号转换, 因为"5.12特大地震"转换后可能就成了"5.12特大地震"
  51. FH_PUNCTUATION = [
  52. (u'%', u'%'), (u'!', u'!'), (u'"', u'\"'), (u''', u'\''), (u'#', u'#'),
  53. (u'¥', u'$'), (u'&', u'&'), (u'(', u'('), (u')', u')'), (u'*', u'*'),
  54. (u'+', u'+'), (u',', u','), (u'-', u'-'), (u'.', u'.'), (u'/', u'/'),
  55. (u':', u':'), (u';', u';'), (u'<', u'<'), (u'=', u'='), (u'>', u'>'),
  56. (u'?', u'?'), (u'@', u'@'), (u'[', u'['), (u']', u']'), (u'\', u'\\'),
  57. (u'^', u'^'), (u'_', u'_'), (u'`', u'`'), (u'~', u'~'), (u'{', u'{'),
  58. (u'}', u'}'), (u'|', u'|')]
  59. FHs = []
  60. if self.change_alpha:
  61. FHs = FH_ALPHA
  62. if self.change_digit:
  63. FHs += FH_NUM
  64. if self.change_punctuation:
  65. FHs += FH_PUNCTUATION
  66. if self.change_space:
  67. FHs += FH_SPACE
  68. self.convert_map = {k: v for k, v in FHs}
  69. def process(self, dataset):
  70. assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
  71. def inner_proc(ins):
  72. sentence = ins[self.field_name]
  73. new_sentence = [""] * len(sentence)
  74. for idx, char in enumerate(sentence):
  75. if char in self.convert_map:
  76. char = self.convert_map[char]
  77. new_sentence[idx] = char
  78. return "".join(new_sentence)
  79. dataset.apply(inner_proc, new_field_name=self.field_name)
  80. return dataset
  81. class PreAppendProcessor(Processor):
  82. """
  83. 向某个field的起始增加data(应该为str类型)。该field需要为list类型。即新增的field为
  84. [data] + instance[field_name]
  85. """
  86. def __init__(self, data, field_name, new_added_field_name=None):
  87. super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
  88. self.data = data
  89. def process(self, dataset):
  90. dataset.apply(lambda ins: [self.data] + ins[self.field_name], new_field_name=self.new_added_field_name)
  91. return dataset
  92. class SliceProcessor(Processor):
  93. """
  94. 从某个field中只取部分内容。等价于instance[field_name][start:end:step]
  95. """
  96. def __init__(self, start, end, step, field_name, new_added_field_name=None):
  97. super(SliceProcessor, self).__init__(field_name, new_added_field_name)
  98. for o in (start, end, step):
  99. assert isinstance(o, int) or o is None
  100. self.slice = slice(start, end, step)
  101. def process(self, dataset):
  102. dataset.apply(lambda ins: ins[self.field_name][self.slice], new_field_name=self.new_added_field_name)
  103. return dataset
  104. class Num2TagProcessor(Processor):
  105. """
  106. 将一句话中的数字转换为某个tag。
  107. """
  108. def __init__(self, tag, field_name, new_added_field_name=None):
  109. """
  110. :param tag: str, 将数字转换为该tag
  111. :param field_name:
  112. :param new_added_field_name:
  113. """
  114. super(Num2TagProcessor, self).__init__(field_name, new_added_field_name)
  115. self.tag = tag
  116. self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'
  117. def process(self, dataset):
  118. def inner_proc(ins):
  119. s = ins[self.field_name]
  120. new_s = [None] * len(s)
  121. for i, w in enumerate(s):
  122. if re.search(self.pattern, w) is not None:
  123. w = self.tag
  124. new_s[i] = w
  125. return new_s
  126. dataset.apply(inner_proc, new_field_name=self.new_added_field_name)
  127. return dataset
  128. class IndexerProcessor(Processor):
  129. """
  130. 给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如
  131. ['我', '是', xxx]
  132. """
  133. def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):
  134. assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
  135. super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
  136. self.vocab = vocab
  137. self.delete_old_field = delete_old_field
  138. self.is_input = is_input
  139. def set_vocab(self, vocab):
  140. assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
  141. self.vocab = vocab
  142. def process(self, dataset):
  143. assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
  144. dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
  145. new_field_name=self.new_added_field_name)
  146. if self.is_input:
  147. dataset.set_input(self.new_added_field_name)
  148. if self.delete_old_field:
  149. dataset.delete_field(self.field_name)
  150. return dataset
  151. class VocabProcessor(Processor):
  152. """
  153. 传入若干个DataSet以建立vocabulary。
  154. """
  155. def __init__(self, field_name, min_freq=1, max_size=None):
  156. super(VocabProcessor, self).__init__(field_name, None)
  157. self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size)
  158. def process(self, *datasets):
  159. for dataset in datasets:
  160. assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
  161. dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))
  162. def get_vocab(self):
  163. self.vocab.build_vocab()
  164. return self.vocab
  165. class SeqLenProcessor(Processor):
  166. """
  167. 根据某个field新增一个sequence length的field。取该field的第一维
  168. """
  169. def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
  170. super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
  171. self.is_input = is_input
  172. def process(self, dataset):
  173. assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
  174. dataset.apply(lambda ins: len(ins[self.field_name]), new_field_name=self.new_added_field_name)
  175. if self.is_input:
  176. dataset.set_input(self.new_added_field_name)
  177. return dataset
  178. from fastNLP.core.utils import _build_args
  179. class ModelProcessor(Processor):
  180. def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
  181. """
  182. 传入一个model,在process()时传入一个dataset,该processor会通过Batch将DataSet的内容输出给model.predict或者model.forward.
  183. model输出的内容会被增加到dataset中,field_name由model输出决定。如果生成的内容维度不是(Batch_size, )与
  184. (Batch_size, 1),则使用seqence length这个field进行unpad
  185. TODO 这个类需要删除对seq_lens的依赖。
  186. :param seq_len_field_name:
  187. :param batch_size:
  188. """
  189. super(ModelProcessor, self).__init__(None, None)
  190. self.batch_size = batch_size
  191. self.seq_len_field_name = seq_len_field_name
  192. self.model = model
  193. def process(self, dataset):
  194. self.model.eval()
  195. assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
  196. data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler())
  197. batch_output = defaultdict(list)
  198. predict_func = self.model.forward
  199. with torch.no_grad():
  200. for batch_x, _ in data_iterator:
  201. refined_batch_x = _build_args(predict_func, **batch_x)
  202. prediction = predict_func(**refined_batch_x)
  203. seq_lens = batch_x[self.seq_len_field_name].tolist()
  204. for key, value in prediction.items():
  205. tmp_batch = []
  206. value = value.cpu().numpy()
  207. if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
  208. batch_output[key].extend(value.tolist())
  209. else:
  210. for idx, seq_len in enumerate(seq_lens):
  211. tmp_batch.append(value[idx, :seq_len])
  212. batch_output[key].extend(tmp_batch)
  213. if not self.seq_len_field_name in prediction:
  214. batch_output[self.seq_len_field_name].extend(seq_lens)
  215. # TODO 当前的实现会导致之后的processor需要知道model输出的output的key是什么
  216. for field_name, fields in batch_output.items():
  217. dataset.add_field(field_name, fields, is_input=True, is_target=False)
  218. return dataset
  219. def set_model(self, model):
  220. self.model = model
  221. def set_model_device(self, device):
  222. device = torch.device(device)
  223. self.model.to(device)
  224. class Index2WordProcessor(Processor):
  225. """
  226. 将DataSet中某个为index的field根据vocab转换为str
  227. """
  228. def __init__(self, vocab, field_name, new_added_field_name):
  229. super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
  230. self.vocab = vocab
  231. def process(self, dataset):
  232. dataset.apply(lambda ins: [self.vocab.to_word(w) for w in ins[self.field_name]],
  233. new_field_name=self.new_added_field_name)
  234. return dataset
  235. class SetTargetProcessor(Processor):
  236. def __init__(self, *fields, flag=True):
  237. super(SetTargetProcessor, self).__init__(None, None)
  238. self.fields = fields
  239. self.flag = flag
  240. def process(self, dataset):
  241. dataset.set_target(*self.fields, flag=self.flag)
  242. return dataset
  243. class SetInputProcessor(Processor):
  244. def __init__(self, *fields, flag=True):
  245. super(SetInputProcessor, self).__init__(None, None)
  246. self.fields = fields
  247. self.flag = flag
  248. def process(self, dataset):
  249. dataset.set_input(*self.fields, flag=self.flag)
  250. return dataset
  251. class VocabIndexerProcessor(Processor):
  252. """
  253. 根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供
  254. new_added_field_name, 则覆盖原有的field_name.
  255. """
  256. def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None,
  257. verbose=0, is_input=True):
  258. """
  259. :param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作
  260. :param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name.
  261. :param min_freq: 创建的Vocabulary允许的单词最少出现次数.
  262. :param max_size: 创建的Vocabulary允许的最大的单词数量
  263. :param verbose: 0, 不输出任何信息;1,输出信息
  264. :param bool is_input:
  265. """
  266. super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name)
  267. self.min_freq = min_freq
  268. self.max_size = max_size
  269. self.verbose = verbose
  270. self.is_input = is_input
  271. def construct_vocab(self, *datasets):
  272. """
  273. 使用传入的DataSet创建vocabulary
  274. :param datasets: DataSet类型的数据,用于构建vocabulary
  275. :return:
  276. """
  277. self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size)
  278. for dataset in datasets:
  279. assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
  280. dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))
  281. self.vocab.build_vocab()
  282. if self.verbose:
  283. print("Vocabulary Constructed, has {} items.".format(len(self.vocab)))
  284. def process(self, *datasets, only_index_dataset=None):
  285. """
  286. 若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary
  287. 后,则会index datasets与only_index_dataset。
  288. :param datasets: DataSet类型的数据
  289. :param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。
  290. :return:
  291. """
  292. if len(datasets) == 0 and not hasattr(self, 'vocab'):
  293. raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.")
  294. if not hasattr(self, 'vocab'):
  295. self.construct_vocab(*datasets)
  296. else:
  297. if self.verbose:
  298. print("Using constructed vocabulary with {} items.".format(len(self.vocab)))
  299. to_index_datasets = []
  300. if len(datasets) != 0:
  301. for dataset in datasets:
  302. assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
  303. to_index_datasets.append(dataset)
  304. if not (only_index_dataset is None):
  305. if isinstance(only_index_dataset, list):
  306. for dataset in only_index_dataset:
  307. assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
  308. to_index_datasets.append(dataset)
  309. elif isinstance(only_index_dataset, DataSet):
  310. to_index_datasets.append(only_index_dataset)
  311. else:
  312. raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset)))
  313. for dataset in to_index_datasets:
  314. assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
  315. dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
  316. new_field_name=self.new_added_field_name, is_input=self.is_input)
  317. # 只返回一个,infer时为了跟其他processor保持一致
  318. if len(to_index_datasets) == 1:
  319. return to_index_datasets[0]
  320. def set_vocab(self, vocab):
  321. assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab))
  322. self.vocab = vocab
  323. def delete_vocab(self):
  324. del self.vocab
  325. def get_vocab_size(self):
  326. return len(self.vocab)
  327. def set_verbose(self, verbose):
  328. """
  329. 设置processor verbose状态。
  330. :param verbose: int, 0,不输出任何信息;1,输出vocab 信息。
  331. :return:
  332. """
  333. self.verbose = verbose