@@ -0,0 +1,32 @@
+from fastNLP.api.api import API
+from fastNLP.core.dataset import DataSet
+
+
+class CWS(API):
+    def __init__(self, model_path='xxx'):
+        super(CWS, self).__init__()
+        self.load(model_path)
+
+    def predict(self, sentence, pretrain=False):
+        if not (hasattr(self, 'model') and hasattr(self, 'pipeline')):
+            raise ValueError("You have to load a model first, or specify pretrain=True.")
+
+        sentence_list = []
+        # 1. Check the type of `sentence`
+        if isinstance(sentence, str):
+            sentence_list.append(sentence)
+        elif isinstance(sentence, list):
+            sentence_list = sentence
+        else:
+            raise TypeError("sentence must be a str or a list of str.")
+
+        # 2. Build the dataset
+        dataset = DataSet()
+        dataset.add_field('raw_sentence', sentence_list)
+
+        # 3. Run the pipeline
+        self.pipeline(dataset)
+
+        # 4. TODO: hand the prediction itself off to something like an iterator
+        # 5. TODO: collect the results; decide whether they need to be mapped back, and what post_process should do
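For context, here is a minimal usage sketch of the new CWS API. The module path and model path are assumptions for illustration; `predict` accepts either a single sentence or a list of sentences:

```python
# Hypothetical usage of the CWS API above; the module path and checkpoint
# path are placeholders, not confirmed by this diff.
from fastNLP.api.cws import CWS  # assumed location of the new file

cws = CWS(model_path='/path/to/cws_model.pkl')  # hypothetical checkpoint path
cws.predict('这是一个句子。')           # a single sentence
cws.predict(['第一句。', '第二句。'])   # or a batch of sentences
```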
@@ -13,7 +13,7 @@ class Pipeline:
     def process(self, dataset):
         assert len(self.pipeline)!=0, "You need to add some processor first."
-        for proc_name, proc in self.pipeline:
+        for proc in self.pipeline:
             dataset = proc(dataset)
         return dataset
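The CWS API above invokes `self.pipeline(dataset)` while this hunk patches `Pipeline.process`; the two are consistent if `Pipeline.__call__` delegates to `process`. A minimal sketch of that shape, as an assumption rather than the actual fastNLP source:

```python
# Sketch of the Pipeline interface this patch implies; the real fastNLP
# class may differ in details.
class Pipeline:
    def __init__(self):
        self.pipeline = []  # processors, applied in insertion order

    def add_processor(self, processor):
        self.pipeline.append(processor)

    def process(self, dataset):
        assert len(self.pipeline) != 0, "You need to add some processor first."
        for proc in self.pipeline:
            dataset = proc(dataset)  # each processor maps a DataSet to a DataSet
        return dataset

    def __call__(self, dataset):
        # Assumed: calling a Pipeline is shorthand for process(), which is
        # why the CWS API can write `self.pipeline(dataset)`.
        return self.process(dataset)
```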
@@ -223,8 +223,21 @@ pp = Pipeline()
 pp.add_processor(fs2hs_proc)
 pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
+pp.add_processor(tag_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)
 pp.add_processor(bigram_index_proc)
 pp.add_processor(seq_len_proc)
+
+te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_test.txt'
+te_dataset = reader.load(te_filename)
+pp(te_dataset)
+
+batch_size = 64
+te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
+pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher)
+print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
+                                                 pre * 100,
+                                                 rec * 100))
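`calculate_pre_rec_f1` itself is not shown in this diff. For readers unfamiliar with CWS evaluation, here is a hedged sketch of span-based precision/recall/F1 assuming BMES-style tag sequences; the repo's actual implementation may differ:

```python
# Illustrative span-based P/R/F1 for word segmentation, assuming BMES tags;
# this is NOT the repo's calculate_pre_rec_f1.
def bmes_to_spans(tags):
    """Convert a BMES tag sequence into a set of (start, end) word spans."""
    spans, start = set(), 0
    for i, tag in enumerate(tags):
        if tag in ('B', 'S'):
            start = i
        if tag in ('E', 'S'):
            spans.add((start, i + 1))
    return spans

def pre_rec_f1(gold_tags, pred_tags):
    gold, pred = bmes_to_spans(gold_tags), bmes_to_spans(pred_tags)
    correct = len(gold & pred)
    pre = correct / len(pred) if pred else 0.0
    rec = correct / len(gold) if gold else 0.0
    f1 = 2 * pre * rec / (pre + rec) if pre + rec else 0.0
    return pre, rec, f1

# Gold segmentation "中国 / 人" vs. predicted "中 / 国 / 人":
print(pre_rec_f1(['B', 'E', 'S'], ['S', 'S', 'S']))  # (0.333..., 0.5, 0.4)
```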