| @@ -6,6 +6,7 @@ warnings.filterwarnings('ignore') | |||
| import os | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.api.model_zoo import load_url | |||
| from fastNLP.api.processor import ModelProcessor | |||
| from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader | |||
| @@ -120,7 +121,7 @@ class POS(API): | |||
| f1 = round(test_result['F'] * 100, 2) | |||
| pre = round(test_result['P'] * 100, 2) | |||
| rec = round(test_result['R'] * 100, 2) | |||
| print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||
| # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||
| return f1, pre, rec | |||
| @@ -179,7 +180,7 @@ class CWS(API): | |||
| f1 = round(f1 * 100, 2) | |||
| pre = round(pre * 100, 2) | |||
| rec = round(rec * 100, 2) | |||
| print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||
| # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||
| return f1, pre, rec | |||
| @@ -251,30 +252,23 @@ class Parser(API): | |||
| class Analyzer: | |||
| def __init__(self, seg=True, pos=True, parser=True, device='cpu'): | |||
| self.seg = seg | |||
| self.pos = pos | |||
| self.parser = parser | |||
| def __init__(self, device='cpu'): | |||
| if self.seg: | |||
| self.cws = CWS(device=device) | |||
| if self.pos: | |||
| self.pos = POS(device=device) | |||
| if parser: | |||
| self.parser = None | |||
| self.cws = CWS(device=device) | |||
| self.pos = POS(device=device) | |||
| self.parser = Parser(device=device) | |||
| def predict(self, content, seg=False, pos=False, parser=False): | |||
| if seg is False and pos is False and parser is False: | |||
| seg = True | |||
| output_dict = {} | |||
| if self.seg: | |||
| if seg: | |||
| seg_output = self.cws.predict(content) | |||
| output_dict['seg'] = seg_output | |||
| if self.pos: | |||
| if pos: | |||
| pos_output = self.pos.predict(content) | |||
| output_dict['pos'] = pos_output | |||
| if self.parser: | |||
| if parser: | |||
| parser_output = self.parser.predict(content) | |||
| output_dict['parser'] = parser_output | |||
| @@ -301,7 +295,7 @@ if __name__ == "__main__": | |||
| # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , | |||
| # '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
| # '那么这款无人机到底有多厉害?'] | |||
| # print(pos.test('/Users/yh/Desktop/test_data/small_test.conll')) | |||
| # print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | |||
| # print(pos.predict(s)) | |||
| # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | |||
| @@ -317,7 +311,4 @@ if __name__ == "__main__": | |||
| s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
| '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
| '那么这款无人机到底有多厉害?'] | |||
| print(cws.test('/Users/yh/Desktop/test_data/small_test.conll')) | |||
| print(cws.predict(s)) | |||
| print(parser.predict(s)) | |||
| @@ -313,9 +313,14 @@ class DataSet(object): | |||
| for col in headers: | |||
| _dict[col] = [] | |||
| for line_idx, line in enumerate(f, start_idx): | |||
| contents = line.rstrip('\r\n').split(sep) | |||
| assert len(contents)==len(headers), "Line {} has {} parts, while header has {}."\ | |||
| .format(line_idx, len(contents), len(headers)) | |||
| contents = line.split(sep) | |||
| if len(contents)!=len(headers): | |||
| if dropna: | |||
| continue | |||
| else: | |||
| #TODO change error type | |||
| raise ValueError("Line {} has {} parts, while header has {} parts."\ | |||
| .format(line_idx, len(contents), len(headers))) | |||
| for header, content in zip(headers, contents): | |||
| _dict[header].append(content) | |||
| return cls(_dict) | |||
| @@ -38,15 +38,15 @@ class SeqLabelEvaluator(Evaluator): | |||
| def __call__(self, predict, truth, **_): | |||
| """ | |||
| :param predict: list of dict, the network outputs from all batches. | |||
| :param predict: list of List, the network outputs from all batches. | |||
| :param truth: list of dict, the ground truths from all batch_y. | |||
| :return accuracy: | |||
| """ | |||
| total_correct, total_count = 0., 0. | |||
| total_correct, total_count = 0., 0. | |||
| for x, y in zip(predict, truth): | |||
| # x = torch.tensor(x) | |||
| x = torch.tensor(x) | |||
| y = y.to(x) # make sure they are in the same device | |||
| mask = (y > 0) | |||
| mask = (y > 0) | |||
| correct = torch.sum(((x == y) * mask).long()) | |||
| total_correct += float(correct) | |||
| total_count += float(torch.sum(mask.long())) | |||
| @@ -4,6 +4,7 @@ from datetime import datetime | |||
| import warnings | |||
| from collections import defaultdict | |||
| import os | |||
| import itertools | |||
| import shutil | |||
| from tensorboardX import SummaryWriter | |||
| @@ -121,10 +122,7 @@ class Trainer(object): | |||
| for batch_x, batch_y in data_iterator: | |||
| prediction = self.data_forward(model, batch_x) | |||
| # TODO: refactor self.get_loss | |||
| loss = prediction["loss"] if "loss" in prediction else self.get_loss(prediction, batch_y) | |||
| # acc = self._evaluator([{"predict": prediction["predict"]}], [{"truth": batch_x["truth"]}]) | |||
| loss = self.get_loss(prediction, batch_y) | |||
| self.grad_backward(loss) | |||
| self.update() | |||
| self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||
| @@ -133,7 +131,7 @@ class Trainer(object): | |||
| self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||
| # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||
| # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||
| if n_print > 0 and self.step % n_print == 0: | |||
| if self.print_every > 0 and self.step % self.print_every == 0: | |||
| end = time.time() | |||
| diff = timedelta(seconds=round(end - start)) | |||
| print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||
| @@ -241,7 +239,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No | |||
| batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
| for batch_count, (batch_x, batch_y) in enumerate(batch): | |||
| _syn_model_data(model, batch_x, batch_y) | |||
| _syn_model_data(model, batch_x, batch_y) | |||
| # forward check | |||
| if batch_count==0: | |||
| _check_forward_error(model_func=model.forward, check_level=check_level, | |||
| @@ -269,7 +267,8 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No | |||
| model_name, loss.size() | |||
| )) | |||
| loss.backward() | |||
| if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE: | |||
| model.zero_grad() | |||
| if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: | |||
| break | |||
| if dev_data is not None: | |||
| @@ -1,4 +1,3 @@ | |||
| import numpy as np | |||
| import torch | |||
| import numpy as np | |||
| @@ -141,7 +140,6 @@ class AdvSeqLabel(SeqLabeling): | |||
| idx_sort = idx_sort.cuda() | |||
| idx_unsort = idx_unsort.cuda() | |||
| self.mask = self.mask.cuda() | |||
| truth = truth.cuda() if truth is not None else None | |||
| x = self.Embedding(word_seq) | |||
| x = self.norm1(x) | |||
| @@ -36,4 +36,4 @@ pickle_path = "./save/" | |||
| use_crf = true | |||
| use_cuda = true | |||
| rnn_hidden_units = 100 | |||
| word_emb_dim = 100 | |||
| word_emb_dim = 100 | |||
| @@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f: | |||
| setup( | |||
| name='fastNLP', | |||
| version='0.1.0', | |||
| version='0.1.1', | |||
| description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | |||
| long_description=readme, | |||
| license=license, | |||