
api.py

import warnings
import torch

warnings.filterwarnings('ignore')
import os

from fastNLP.core.dataset import DataSet
from .utils import load_url
from .processor import ModelProcessor
from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader
from fastNLP.core.instance import Instance
from ..api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric
from .processor import IndexerProcessor

# TODO add pretrain urls
model_urls = {
    "cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl",
    "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
    "parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl"
}


class ConllCWSReader(object):
    """Deprecated. Use ConllLoader for all types of conll-format files."""

    def __init__(self):
        pass

    def load(self, path, cut_long_sent=False):
        """
        The returned DataSet contains only a ``raw_sentence`` field, whose value is a str.
        The input is assumed to be in conll format: sentences are separated by blank lines,
        and every non-empty line has 7 columns, e.g.
        ::

            1 编者按 编者按 NN O 11 nmod:topic
            2 : : PU O 11 punct
            3 7月 7月 NT DATE 4 compound:nn
            4 12日 12日 NT DATE 11 nmod:tmod
            5 , , PU O 11 punct

            1 这 这 DT O 3 det
            2 款 款 M O 1 mark:clf
            3 飞行 飞行 NN O 8 nsubj
            4 从 从 P O 5 case
            5 外型 外型 NN O 8 nmod:prep
        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.strip().split())
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            # print(sample)
            res = self.get_char_lst(sample)
            if res is None:
                continue
            line = ' '.join(res)
            if cut_long_sent:
                sents = _cut_long_sentence(line)
            else:
                sents = [line]
            for raw_sentence in sents:
                ds.append(Instance(raw_sentence=raw_sentence))
        return ds

    def get_char_lst(self, sample):
        if len(sample) == 0:
            return None
        text = []
        for w in sample:
            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
            if t3 == '_':
                return None
            text.append(t1)
        return text


class ConllxDataLoader(ConllLoader):
    """Return word-level information: words, POS tags, (syntactic) heads and (syntactic)
    dependency labels. Completely different from ``ZhConllPOSReader``.

    Deprecated. Use ConllLoader for all types of conll-format files.
    """

    def __init__(self):
        headers = [
            'words', 'pos_tags', 'heads', 'labels',
        ]
        indexs = [
            1, 3, 6, 7,
        ]
        super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs)


class API:
    def __init__(self):
        self.pipeline = None
        self._dict = None

    def predict(self, *args, **kwargs):
        """Do prediction for the given input.
        """
        raise NotImplementedError

    def test(self, file_path):
        """Test performance over the given data set.

        :param str file_path:
        :return: a dictionary of metric values
        """
        raise NotImplementedError

    def load(self, path, device):
        if os.path.exists(os.path.expanduser(path)):
            _dict = torch.load(path, map_location='cpu')
        else:
            _dict = load_url(path, map_location='cpu')
        self._dict = _dict
        self.pipeline = _dict['pipeline']
        for processor in self.pipeline.pipeline:
            if isinstance(processor, ModelProcessor):
                processor.set_model_device(device)
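
# ``load`` accepts either a local checkpoint path or a URL such as the entries in
# ``model_urls``; remote checkpoints are fetched through ``load_url``. A hypothetical
# direct call (the subclasses below normally do this for you in ``__init__``):
#
#     api = API()
#     api.load(model_urls['pos'], device='cpu')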


class POS(API):
    """FastNLP API for Part-Of-Speech tagging.

    :param str model_path: the path to the model.
    :param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch.
    """

    def __init__(self, model_path=None, device='cpu'):
        super(POS, self).__init__()
        if model_path is None:
            model_path = model_urls['pos']

        self.load(model_path, device)

    def predict(self, content):
        """Predict POS tags for tokenized sentences.

        :param content: list of list of str. Each string is a token(word).
        :return answer: list of list of str. Each string is a tag.
        """
        if not hasattr(self, "pipeline"):
            raise ValueError("You have to load model first.")

        sentence_list = content
        # 1. check the type of each sentence
        for sentence in sentence_list:
            if not all((type(obj) == str for obj in sentence)):
                raise ValueError("Input must be list of list of string.")

        # 2. build the dataset
        dataset = DataSet()
        dataset.add_field("words", sentence_list)

        # 3. run the pipeline
        self.pipeline(dataset)

        def merge_tag(words_list, tags_list):
            rtn = []
            for words, tags in zip(words_list, tags_list):
                rtn.append([w + "/" + t for w, t in zip(words, tags)])
            return rtn

        output = dataset.field_arrays["tag"].content
        if isinstance(content, str):
            return output[0]
        elif isinstance(content, list):
            return merge_tag(content, output)

    def test(self, file_path):
        test_data = ConllxDataLoader().load(file_path)

        save_dict = self._dict
        tag_vocab = save_dict["tag_vocab"]
        pipeline = save_dict["pipeline"]
        index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
        pipeline.pipeline = [index_tag] + pipeline.pipeline

        test_data.rename_field("pos_tags", "tag")
        pipeline(test_data)
        test_data.set_target("truth")
        prediction = test_data.field_arrays["predict"].content
        truth = test_data.field_arrays["truth"].content
        seq_len = test_data.field_arrays["word_seq_origin_len"].content

        # padding by hand
        max_length = max([len(seq) for seq in prediction])
        for idx in range(len(prediction)):
            prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx])))
            truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx])))
        evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth",
                                      seq_len="word_seq_origin_len")
        evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)},
                  {"truth": torch.Tensor(truth)})
        test_result = evaluator.get_metric()
        f1 = round(test_result['f'] * 100, 2)
        pre = round(test_result['pre'] * 100, 2)
        rec = round(test_result['rec'] * 100, 2)

        return {"F1": f1, "precision": pre, "recall": rec}
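
# Usage sketch for POS (requires downloading the pretrained model from ``model_urls``;
# the example tokens and tags below are illustrative, not actual model output):
#
#     pos = POS(device='cpu')
#     print(pos.predict([['这里', '是', '分词', '结果']]))
#     # e.g. [['这里/NN', '是/VB', '分词/NN', '结果/NN']]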


class CWS(API):
    """
    High-level API for Chinese word segmentation.

    :param model_path: if None, the model at the default location is used; if it is not
        found there, the model is downloaded automatically.
    :param str device: 'cpu', 'cuda', 'cuda:0', etc. The model is loaded onto this device for inference.
    """

    def __init__(self, model_path=None, device='cpu'):
        super(CWS, self).__init__()
        if model_path is None:
            model_path = model_urls['cws']

        self.load(model_path, device)

    def predict(self, content):
        """
        Word segmentation interface.

        :param content: str or List[str]. For example, "中文分词很重要!" returns "中文 分词 很 重要 !".
            If a List[str] such as ["中文分词很重要!", ...] is passed in, the result is
            ["中文 分词 很 重要 !", ...].
        :return: str or List[str], depending on the type of the input.
        """
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first.")

        sentence_list = []
        # 1. check the type of content
        if isinstance(content, str):
            sentence_list.append(content)
        elif isinstance(content, list):
            sentence_list = content

        # 2. build the dataset
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. run the pipeline
        self.pipeline(dataset)
        output = dataset.get_field('output').content
        if isinstance(content, str):
            return output[0]
        elif isinstance(content, list):
            return output

    def test(self, filepath):
        """
        Given the path of a segmentation file, compute the segmentation F1, precision and recall on that dataset.
        The file should look like::

            1 编者按 编者按 NN O 11 nmod:topic
            2 : : PU O 11 punct
            3 7月 7月 NT DATE 4 compound:nn
            4 12日 12日 NT DATE 11 nmod:tmod
            5 , , PU O 11 punct

            1 这 这 DT O 3 det
            2 款 款 M O 1 mark:clf
            3 飞行 飞行 NN O 8 nsubj
            4 从 从 P O 5 case
            5 外型 外型 NN O 8 nmod:prep

        Sentences are separated by blank lines; every non-empty line has 7 columns.

        :param filepath: str, path to the file.
        :return: dict with keys "F1", "precision" and "recall".
        """
        tag_proc = self._dict['tag_proc']
        cws_model = self.pipeline.pipeline[-2].model
        pipeline = self.pipeline.pipeline[:-2]
        pipeline.insert(1, tag_proc)
        pp = Pipeline(pipeline)

        reader = ConllCWSReader()
        # te_filename = '/home/hyan/ctb3/test.conllx'
        te_dataset = reader.load(filepath)
        pp(te_dataset)

        from ..core.tester import Tester
        from ..core.metrics import SpanFPreRecMetric

        tester = Tester(data=te_dataset, model=cws_model, metrics=SpanFPreRecMetric(tag_proc.get_vocab()),
                        batch_size=64, verbose=0)
        eval_res = tester.test()

        f1 = eval_res['SpanFPreRecMetric']['f']
        pre = eval_res['SpanFPreRecMetric']['pre']
        rec = eval_res['SpanFPreRecMetric']['rec']
        # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

        return {"F1": f1, "precision": pre, "recall": rec}
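
# Usage sketch for CWS (the pretrained segmentation model is downloaded on first use;
# the input/output pair follows the docstring of ``CWS.predict``):
#
#     cws = CWS(device='cpu')
#     print(cws.predict('中文分词很重要!'))       # -> '中文 分词 很 重要 !'
#     print(cws.predict(['中文分词很重要!']))     # -> ['中文 分词 很 重要 !']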


class Parser(API):
    def __init__(self, model_path=None, device='cpu'):
        super(Parser, self).__init__()
        if model_path is None:
            model_path = model_urls['parser']

        self.pos_tagger = POS(device=device)
        self.load(model_path, device)

    def predict(self, content):
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first.")

        # 1. use the POS tagger to get words and POS tags
        pos_out = self.pos_tagger.predict(content)
        # pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()]

        # 2. build the dataset
        dataset = DataSet()
        dataset.add_field('wp', pos_out)
        dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
        dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
        dataset.rename_field("words", "raw_words")

        # 3. run the pipeline
        self.pipeline(dataset)
        dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred')
        dataset.apply(lambda x: [arc + '/' + label for arc, label in
                                 zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output')
        # output like: [['2/top', '0/root', '4/nn', '2/dep']]
        return dataset.field_arrays['output'].content

    def load_test_file(self, path):
        def get_one(sample):
            sample = list(map(list, zip(*sample)))
            if len(sample) == 0:
                return None
            for w in sample[7]:
                if w == '_':
                    print('Error Sample {}'.format(sample))
                    return None
            # return word_seq, pos_seq, head_seq, head_tag_seq
            return sample[1], sample[3], list(map(int, sample[6])), sample[7]

        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        data = [get_one(sample) for sample in datalist]
        data_list = list(filter(lambda x: x is not None, data))
        return data_list

    def test(self, filepath):
        data = self.load_test_file(filepath)

        def convert(data):
            BOS = '<BOS>'
            dataset = DataSet()
            for sample in data:
                word_seq = [BOS] + sample[0]
                pos_seq = [BOS] + sample[1]
                heads = [0] + sample[2]
                head_tags = [BOS] + sample[3]
                dataset.append(Instance(raw_words=word_seq,
                                        pos=pos_seq,
                                        gold_heads=heads,
                                        arc_true=heads,
                                        tags=head_tags))
            return dataset

        ds = convert(data)
        pp = self.pipeline
        for p in pp:
            if p.field_name == 'word_list':
                p.field_name = 'gold_words'
            elif p.field_name == 'pos_list':
                p.field_name = 'gold_pos'
        # ds.rename_field("words", "raw_words")
        # ds.rename_field("tag", "pos")
        pp(ds)
        head_cor, label_cor, total = 0, 0, 0
        for ins in ds:
            head_gold = ins['gold_heads']
            head_pred = ins['arc_pred']
            length = len(head_gold)
            total += length
            for i in range(length):
                head_cor += 1 if head_pred[i] == head_gold[i] else 0
        uas = head_cor / total
        # print('uas:{:.2f}'.format(uas))

        for p in pp:
            if p.field_name == 'gold_words':
                p.field_name = 'word_list'
            elif p.field_name == 'gold_pos':
                p.field_name = 'pos_list'

        return {"UAS": round(uas, 5)}
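
# Usage sketch for Parser (it runs the POS tagger first, so two pretrained models are
# loaded; the input tokens are illustrative and the output format follows the comment
# in ``Parser.predict``):
#
#     parser = Parser(device='cpu')
#     print(parser.predict([['这里', '是', '分词', '结果']]))
#     # output like: [['2/top', '0/root', '4/nn', '2/dep']]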


class Analyzer:
    def __init__(self, device='cpu'):
        self.cws = CWS(device=device)
        self.pos = POS(device=device)
        self.parser = Parser(device=device)

    def predict(self, content, seg=False, pos=False, parser=False):
        if seg is False and pos is False and parser is False:
            seg = True
        output_dict = {}
        if seg:
            seg_output = self.cws.predict(content)
            output_dict['seg'] = seg_output
        if pos:
            pos_output = self.pos.predict(content)
            output_dict['pos'] = pos_output
        if parser:
            parser_output = self.parser.predict(content)
            output_dict['parser'] = parser_output

        return output_dict

    def test(self, filepath):
        output_dict = {}
        if self.cws:
            seg_output = self.cws.test(filepath)
            output_dict['seg'] = seg_output
        if self.pos:
            pos_output = self.pos.test(filepath)
            output_dict['pos'] = pos_output
        if self.parser:
            parser_output = self.parser.test(filepath)
            output_dict['parser'] = parser_output

        return output_dict
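
# A minimal end-to-end sketch, assuming this module lives at ``fastNLP.api`` and the
# three pretrained models in ``model_urls`` are still downloadable; the sentence and
# the printed result are only illustrative:
#
#     from fastNLP.api.api import Analyzer
#     analyzer = Analyzer(device='cpu')
#     print(analyzer.predict('中文分词很重要!'))   # segmentation only by default
#     # -> {'seg': '中文 分词 很 重要 !'}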