
processor.py

import re
from collections import defaultdict

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.utils import _build_args
from fastNLP.core.vocabulary import Vocabulary


class Processor(object):
    def __init__(self, field_name, new_added_field_name):
        """
        :param field_name: the field to process
        :param new_added_field_name: if None, field_name is used, i.e. the original field is overwritten
        """
        self.field_name = field_name
        if new_added_field_name is None:
            self.new_added_field_name = field_name
        else:
            self.new_added_field_name = new_added_field_name

    def process(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.process(*args, **kwargs)
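

# A minimal hedged sketch of subclassing: only process() needs to be implemented.
# LowerCaseProcessor is hypothetical and not part of this module.
class LowerCaseProcessor(Processor):
    """Hypothetical example: lower-case a str field."""

    def process(self, dataset):
        dataset.apply(lambda ins: ins[self.field_name].lower(),
                      new_field_name=self.new_added_field_name)
        return dataset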


class FullSpaceToHalfSpaceProcessor(Processor):
    """Convert full-width characters to half-width, processing one character at a time."""

    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
                 change_space=True):
        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
        self.change_alpha = change_alpha
        self.change_digit = change_digit
        self.change_punctuation = change_punctuation
        self.change_space = change_space
        FH_SPACE = [(u"　", u" ")]
        FH_NUM = [
            (u"０", u"0"), (u"１", u"1"), (u"２", u"2"), (u"３", u"3"), (u"４", u"4"),
            (u"５", u"5"), (u"６", u"6"), (u"７", u"7"), (u"８", u"8"), (u"９", u"9")]
        FH_ALPHA = [
            (u"ａ", u"a"), (u"ｂ", u"b"), (u"ｃ", u"c"), (u"ｄ", u"d"), (u"ｅ", u"e"),
            (u"ｆ", u"f"), (u"ｇ", u"g"), (u"ｈ", u"h"), (u"ｉ", u"i"), (u"ｊ", u"j"),
            (u"ｋ", u"k"), (u"ｌ", u"l"), (u"ｍ", u"m"), (u"ｎ", u"n"), (u"ｏ", u"o"),
            (u"ｐ", u"p"), (u"ｑ", u"q"), (u"ｒ", u"r"), (u"ｓ", u"s"), (u"ｔ", u"t"),
            (u"ｕ", u"u"), (u"ｖ", u"v"), (u"ｗ", u"w"), (u"ｘ", u"x"), (u"ｙ", u"y"),
            (u"ｚ", u"z"),
            (u"Ａ", u"A"), (u"Ｂ", u"B"), (u"Ｃ", u"C"), (u"Ｄ", u"D"), (u"Ｅ", u"E"),
            (u"Ｆ", u"F"), (u"Ｇ", u"G"), (u"Ｈ", u"H"), (u"Ｉ", u"I"), (u"Ｊ", u"J"),
            (u"Ｋ", u"K"), (u"Ｌ", u"L"), (u"Ｍ", u"M"), (u"Ｎ", u"N"), (u"Ｏ", u"O"),
            (u"Ｐ", u"P"), (u"Ｑ", u"Q"), (u"Ｒ", u"R"), (u"Ｓ", u"S"), (u"Ｔ", u"T"),
            (u"Ｕ", u"U"), (u"Ｖ", u"V"), (u"Ｗ", u"W"), (u"Ｘ", u"X"), (u"Ｙ", u"Y"),
            (u"Ｚ", u"Z")]
        # Use punctuation conversion with caution: e.g. "5．12特大地震" may come out as "5.12特大地震".
        FH_PUNCTUATION = [
            (u'％', u'%'), (u'！', u'!'), (u'＂', u'\"'), (u'＇', u'\''), (u'＃', u'#'),
            (u'￥', u'$'), (u'＆', u'&'), (u'（', u'('), (u'）', u')'), (u'＊', u'*'),
            (u'＋', u'+'), (u'，', u','), (u'－', u'-'), (u'．', u'.'), (u'／', u'/'),
            (u'：', u':'), (u'；', u';'), (u'＜', u'<'), (u'＝', u'='), (u'＞', u'>'),
            (u'？', u'?'), (u'＠', u'@'), (u'［', u'['), (u'］', u']'), (u'＼', u'\\'),
            (u'＾', u'^'), (u'＿', u'_'), (u'｀', u'`'), (u'～', u'~'), (u'｛', u'{'),
            (u'｝', u'}'), (u'｜', u'|')]
        FHs = []
        if self.change_alpha:
            FHs = FH_ALPHA
        if self.change_digit:
            FHs += FH_NUM
        if self.change_punctuation:
            FHs += FH_PUNCTUATION
        if self.change_space:
            FHs += FH_SPACE
        self.convert_map = {k: v for k, v in FHs}

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))

        def inner_proc(ins):
            sentence = ins[self.field_name]
            new_sentence = [""] * len(sentence)
            for idx, char in enumerate(sentence):
                if char in self.convert_map:
                    char = self.convert_map[char]
                new_sentence[idx] = char
            return "".join(new_sentence)

        dataset.apply(inner_proc, new_field_name=self.field_name)
        return dataset
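

# A hedged usage sketch (the field name and data are made up for illustration):
def _example_full_to_half():
    ds = DataSet({"raw_sentence": [u"ＦａｓｔＮＬＰ　１２３！"]})
    ds = FullSpaceToHalfSpaceProcessor("raw_sentence")(ds)
    print(ds[0]["raw_sentence"])  # -> "FastNLP 123!"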


class PreAppendProcessor(Processor):
    """
    Prepend `data` (expected to be a str) to the start of a field. The field must be a list;
    the new field is
        [data] + instance[field_name]
    """

    def __init__(self, data, field_name, new_added_field_name=None):
        super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
        self.data = data

    def process(self, dataset):
        dataset.apply(lambda ins: [self.data] + ins[self.field_name], new_field_name=self.new_added_field_name)
        return dataset


class SliceProcessor(Processor):
    """
    Keep only part of a field, equivalent to instance[field_name][start:end:step]
    """

    def __init__(self, start, end, step, field_name, new_added_field_name=None):
        super(SliceProcessor, self).__init__(field_name, new_added_field_name)
        for o in (start, end, step):
            assert isinstance(o, int) or o is None
        self.slice = slice(start, end, step)

    def process(self, dataset):
        dataset.apply(lambda ins: ins[self.field_name][self.slice], new_field_name=self.new_added_field_name)
        return dataset
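

# A hedged sketch of chaining PreAppendProcessor and SliceProcessor (made-up data):
def _example_prepend_and_slice():
    ds = DataSet({"words": [["hello", "world", "!"]]})
    ds = PreAppendProcessor("<s>", "words")(ds)   # -> ["<s>", "hello", "world", "!"]
    ds = SliceProcessor(0, 3, None, "words")(ds)  # -> ["<s>", "hello", "world"]
    return ds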


class Num2TagProcessor(Processor):
    """
    Replace the numbers in a sentence with a given tag.
    """

    def __init__(self, tag, field_name, new_added_field_name=None):
        """
        :param tag: str, the tag that numbers are replaced with
        :param field_name:
        :param new_added_field_name:
        """
        super(Num2TagProcessor, self).__init__(field_name, new_added_field_name)
        self.tag = tag
        self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'

    def process(self, dataset):
        def inner_proc(ins):
            s = ins[self.field_name]
            new_s = [None] * len(s)
            for i, w in enumerate(s):
                if re.search(self.pattern, w) is not None:
                    w = self.tag
                new_s[i] = w
            return new_s

        dataset.apply(inner_proc, new_field_name=self.new_added_field_name)
        return dataset
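

# A hedged sketch of Num2TagProcessor (made-up data); tokens matching the number
# pattern above are replaced with the tag:
def _example_num2tag():
    ds = DataSet({"words": [["price", "rose", "3.5%", "to", "12.4"]]})
    ds = Num2TagProcessor("<num>", "words")(ds)
    print(ds[0]["words"])  # -> ['price', 'rose', '<num>', 'to', '<num>']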


class IndexerProcessor(Processor):
    """
    Given a vocabulary, convert the given field to indices. The field must be a one-dimensional
    list, e.g.
        ['我', '是', xxx]
    """

    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):
        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
        super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
        self.vocab = vocab
        self.delete_old_field = delete_old_field
        self.is_input = is_input

    def set_vocab(self, vocab):
        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
        self.vocab = vocab

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
                      new_field_name=self.new_added_field_name)
        if self.is_input:
            dataset.set_input(self.new_added_field_name)
        if self.delete_old_field:
            dataset.delete_field(self.field_name)
        return dataset


class VocabProcessor(Processor):
    """
    Build a vocabulary from one or more DataSets.
    """

    def __init__(self, field_name, min_freq=1, max_size=None):
        super(VocabProcessor, self).__init__(field_name, None)
        self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size)

    def process(self, *datasets):
        for dataset in datasets:
            assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
            dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))

    def get_vocab(self):
        self.vocab.build_vocab()
        return self.vocab
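

# A hedged sketch of the usual VocabProcessor + IndexerProcessor pairing (made-up data):
def _example_vocab_and_indexer():
    ds = DataSet({"words": [["我", "是", "学生"], ["你", "是", "老师"]]})
    vocab_proc = VocabProcessor("words")
    vocab_proc(ds)                  # count tokens over the dataset(s)
    vocab = vocab_proc.get_vocab()  # build and fetch the Vocabulary
    ds = IndexerProcessor(vocab, "words", "word_ids")(ds)
    return ds                       # ds now has an input field "word_ids"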


class SeqLenProcessor(Processor):
    """
    Add a sequence-length field derived from another field, taking that field's first dimension.
    """

    def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
        super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
        self.is_input = is_input

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        dataset.apply(lambda ins: len(ins[self.field_name]), new_field_name=self.new_added_field_name)
        if self.is_input:
            dataset.set_input(self.new_added_field_name)
        return dataset
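

# A hedged one-step sketch: adds a "seq_lens" input field holding len(ins["word_ids"]):
def _example_seq_len(ds):
    return SeqLenProcessor("word_ids")(ds)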


class ModelProcessor(Processor):
    def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
        """
        Pass in a model; when process() is called with a dataset, this processor feeds the DataSet
        contents to model.forward via Batch. The model outputs are added to the dataset as new fields,
        with field names determined by the model output. If an output's dimensions are not
        (batch_size,) or (batch_size, 1), the sequence-length field is used to unpad it.

        TODO this class should drop its dependency on seq_lens.

        :param seq_len_field_name:
        :param batch_size:
        """
        super(ModelProcessor, self).__init__(None, None)
        self.batch_size = batch_size
        self.seq_len_field_name = seq_len_field_name
        self.model = model

    def process(self, dataset):
        self.model.eval()
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler())

        batch_output = defaultdict(list)
        predict_func = self.model.forward
        with torch.no_grad():
            for batch_x, _ in data_iterator:
                refined_batch_x = _build_args(predict_func, **batch_x)
                prediction = predict_func(**refined_batch_x)
                seq_lens = batch_x[self.seq_len_field_name].tolist()

                for key, value in prediction.items():
                    tmp_batch = []
                    value = value.cpu().numpy()
                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
                        batch_output[key].extend(value.tolist())
                    else:
                        for idx, seq_len in enumerate(seq_lens):
                            tmp_batch.append(value[idx, :seq_len])
                        batch_output[key].extend(tmp_batch)
                if self.seq_len_field_name not in prediction:
                    batch_output[self.seq_len_field_name].extend(seq_lens)

        # TODO with the current implementation, downstream processors need to know the keys of the model output
        for field_name, fields in batch_output.items():
            dataset.add_field(field_name, fields, is_input=True, is_target=False)
        return dataset

    def set_model(self, model):
        self.model = model

    def set_model_device(self, device):
        device = torch.device(device)
        self.model.to(device)
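

# A hedged end-to-end sketch. _DummyTagger is hypothetical: any nn.Module whose forward()
# keyword arguments match the dataset's input field names and which returns a dict of
# tensors should be usable here in the same way.
class _DummyTagger(torch.nn.Module):
    """Hypothetical model emitting one value per token."""

    def forward(self, word_ids, seq_lens):
        # (batch_size, max_len) zeros standing in for real per-token predictions
        return {"pred": torch.zeros_like(word_ids)}


def _example_model_processor(ds):
    # requires "word_ids" and "seq_lens" to be input fields of ds;
    # ds gains a "pred" field, unpadded back to each instance's seq_len
    return ModelProcessor(_DummyTagger())(ds)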


class Index2WordProcessor(Processor):
    """
    Convert an index-valued field of a DataSet back to str via a vocab.
    """

    def __init__(self, vocab, field_name, new_added_field_name):
        super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
        self.vocab = vocab

    def process(self, dataset):
        dataset.apply(lambda ins: [self.vocab.to_word(w) for w in ins[self.field_name]],
                      new_field_name=self.new_added_field_name)
        return dataset


class SetTargetProcessor(Processor):
    def __init__(self, *fields, flag=True):
        super(SetTargetProcessor, self).__init__(None, None)
        self.fields = fields
        self.flag = flag

    def process(self, dataset):
        dataset.set_target(*self.fields, flag=self.flag)
        return dataset


class SetInputProcessor(Processor):
    def __init__(self, *fields, flag=True):
        super(SetInputProcessor, self).__init__(None, None)
        self.fields = fields
        self.flag = flag

    def process(self, dataset):
        dataset.set_input(*self.fields, flag=self.flag)
        return dataset


class VocabIndexerProcessor(Processor):
    """
    Build a Vocabulary from a DataSet and use it to index the field. The generated index field is
    stored under new_added_field_name; if new_added_field_name is not given, the original
    field_name is overwritten.
    """

    def __init__(self, field_name, new_added_field_name=None, min_freq=1, max_size=None,
                 verbose=0, is_input=True):
        """
        :param field_name: the field to build the vocabulary from, and the field to index
        :param new_added_field_name: name of the generated index field; if not given, field_name is overwritten
        :param min_freq: minimum number of occurrences a word needs to enter the Vocabulary
        :param max_size: maximum size of the Vocabulary
        :param verbose: 0: print nothing; 1: print info
        :param bool is_input:
        """
        super(VocabIndexerProcessor, self).__init__(field_name, new_added_field_name)
        self.min_freq = min_freq
        self.max_size = max_size
        self.verbose = verbose
        self.is_input = is_input

    def construct_vocab(self, *datasets):
        """
        Build the vocabulary from the given DataSets.

        :param datasets: DataSet objects used to build the vocabulary
        :return:
        """
        self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size)
        for dataset in datasets:
            assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
            dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))
        self.vocab.build_vocab()
        if self.verbose:
            print("Vocabulary Constructed, has {} items.".format(len(self.vocab)))

    def process(self, *datasets, only_index_dataset=None):
        """
        If no Vocabulary has been built yet, build one from the DataSets in datasets; otherwise use
        the existing vocabulary. The vocabulary is then used to index both datasets and
        only_index_dataset.

        :param datasets: DataSet objects
        :param only_index_dataset: DataSet, or list of DataSet. These are only indexed, never used
            to build the vocabulary.
        :return:
        """
        if len(datasets) == 0 and not hasattr(self, 'vocab'):
            raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.")
        if not hasattr(self, 'vocab'):
            self.construct_vocab(*datasets)
        else:
            if self.verbose:
                print("Using constructed vocabulary with {} items.".format(len(self.vocab)))
        to_index_datasets = []
        if len(datasets) != 0:
            for dataset in datasets:
                assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
                to_index_datasets.append(dataset)
        if only_index_dataset is not None:
            if isinstance(only_index_dataset, list):
                for dataset in only_index_dataset:
                    assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
                    to_index_datasets.append(dataset)
            elif isinstance(only_index_dataset, DataSet):
                to_index_datasets.append(only_index_dataset)
            else:
                raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset)))
        for dataset in to_index_datasets:
            assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
            dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
                          new_field_name=self.new_added_field_name, is_input=self.is_input)
        # return a single DataSet, to stay consistent with other processors at inference time
        if len(to_index_datasets) == 1:
            return to_index_datasets[0]

    def set_vocab(self, vocab):
        assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab))
        self.vocab = vocab

    def delete_vocab(self):
        del self.vocab

    def get_vocab_size(self):
        return len(self.vocab)

    def set_verbose(self, verbose):
        """
        Set the processor's verbosity.

        :param verbose: int, 0: print nothing; 1: print vocab info.
        :return:
        """
        self.verbose = verbose
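

# A hedged usage sketch of VocabIndexerProcessor (made-up data): the vocabulary is built
# from the positional datasets only; only_index_dataset is indexed but never extends it.
def _example_vocab_indexer():
    train_ds = DataSet({"words": [["我", "是", "学生"]]})
    test_ds = DataSet({"words": [["我", "是", "老师"]]})
    proc = VocabIndexerProcessor("words", "word_ids", verbose=1)
    proc.process(train_ds, only_index_dataset=test_ds)  # indexes both in place
    return train_ds, test_ds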