- Create converter.py under api
- Add an initializer to Pipeline so that processors can be added in one call
- Delete pos_tagger.py
- Clean up overall code style
fastNLP/api/converter.py (new file):
@@ -0,0 +1,182 @@
import re


class SpanConverter:
    def __init__(self, replace_tag, pattern):
        super(SpanConverter, self).__init__()
        self.replace_tag = replace_tag
        self.pattern = pattern

    def find_certain_span_and_replace(self, sentence):
        replaced_sentence = ''
        prev_end = 0
        for match in re.finditer(self.pattern, sentence):
            start, end = match.span()
            span = sentence[start:end]
            replaced_sentence += sentence[prev_end:start] + \
                                 self.span_to_special_tag(span)
            prev_end = end
        replaced_sentence += sentence[prev_end:]
        return replaced_sentence

    def span_to_special_tag(self, span):
        return self.replace_tag

    def find_certain_span(self, sentence):
        spans = []
        for match in re.finditer(self.pattern, sentence):
            spans.append(match.span())
        return spans
class AlphaSpanConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<ALPHA>'
        # Ideally this matches only purely alphabetic spans, but it must not match
        # <[a-zA-Z]+> (those are special tags).
        pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])'
        super(AlphaSpanConverter, self).__init__(replace_tag, pattern)
class DigitSpanConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<NUM>'
        pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])'
        super(DigitSpanConverter, self).__init__(replace_tag, pattern)

    def span_to_special_tag(self, span):
        if span[0] == '0' and len(span) > 2:
            return '<NUM>'
        decimal_point_count = 0  # a span may contain more than one decimal point
        for char in span:
            if char == '.' or char == '﹒' or char == '·':
                decimal_point_count += 1
        if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·':
            # a trailing decimal point means this span is not a number
            if decimal_point_count == 1:
                return span
            else:
                return '<UNKDGT>'
        if decimal_point_count == 1:
            return '<DEC>'
        elif decimal_point_count > 1:
            return '<UNKDGT>'
        else:
            return '<NUM>'
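For reference, this is how span_to_special_tag classifies digit spans under the branch logic above (a hedged sketch; outputs traced by hand, not run against this revision):

conv = DigitSpanConverter()
conv.span_to_special_tag('2018')   # '<NUM>': no decimal point
conv.span_to_special_tag('3.14')   # '<DEC>': exactly one decimal point
conv.span_to_special_tag('1.2.3')  # '<UNKDGT>': more than one decimal point
conv.span_to_special_tag('12.')    # '12.': one trailing point, span kept verbatim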
class TimeConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<TOC>'
        pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])'
        super().__init__(replace_tag, pattern)


class MixNumAlphaConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<MIX>'
        pattern = None
        super().__init__(replace_tag, pattern)

    def find_certain_span_and_replace(self, sentence):
        replaced_sentence = ''
        start = 0
        matching_flag = False
        number_flag = False
        alpha_flag = False
        link_flag = False
        slash_flag = False
        bracket_flag = False
        for idx in range(len(sentence)):
            if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]):
                if not matching_flag:
                    replaced_sentence += sentence[start:idx]
                    start = idx
                if re.match('[0-9]', sentence[idx]):
                    number_flag = True
                elif re.match('[\'′&\\-]', sentence[idx]):
                    link_flag = True
                elif re.match('/', sentence[idx]):
                    slash_flag = True
                elif re.match('[\\(\\)]', sentence[idx]):
                    bracket_flag = True
                else:
                    alpha_flag = True
                matching_flag = True
            elif re.match('[\\.]', sentence[idx]):
                pass
            else:
                if matching_flag:
                    if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
                            or (slash_flag and alpha_flag) or (link_flag and number_flag) \
                            or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
                        span = sentence[start:idx]
                        start = idx
                        replaced_sentence += self.span_to_special_tag(span)
                matching_flag = False
                number_flag = False
                alpha_flag = False
                link_flag = False
                slash_flag = False
                bracket_flag = False
        replaced_sentence += sentence[start:]
        return replaced_sentence
    def find_certain_span(self, sentence):
        spans = []
        start = 0
        matching_flag = False
        number_flag = False
        alpha_flag = False
        link_flag = False
        slash_flag = False
        bracket_flag = False
        for idx in range(len(sentence)):
            if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]):
                if not matching_flag:
                    start = idx
                if re.match('[0-9]', sentence[idx]):
                    number_flag = True
                elif re.match('[\'′&\\-]', sentence[idx]):
                    link_flag = True
                elif re.match('/', sentence[idx]):
                    slash_flag = True
                elif re.match('[\\(\\)]', sentence[idx]):
                    bracket_flag = True
                else:
                    alpha_flag = True
                matching_flag = True
            elif re.match('[\\.]', sentence[idx]):
                pass
            else:
                if matching_flag:
                    if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
                            or (slash_flag and alpha_flag) or (link_flag and number_flag) \
                            or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
                        spans.append((start, idx))
                        start = idx
                matching_flag = False
                number_flag = False
                alpha_flag = False
                link_flag = False
                slash_flag = False
                bracket_flag = False
        return spans
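A small usage sketch of MixNumAlphaConverter, hand-traced from the flag logic above (a span is replaced only if its flags satisfy one of the mixed-type conditions, e.g. digits + letters, and it is terminated by a non-matching character):

conv = MixNumAlphaConverter()
conv.find_certain_span_and_replace('我买了iphone6s。')  # -> '我买了<MIX>。'
conv.find_certain_span('型号A123,价格未知')             # -> [(2, 6)], the span of 'A123'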
class EmailConverter(SpanConverter):
    def __init__(self):
        replaced_tag = "<EML>"
        pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])'
        super(EmailConverter, self).__init__(replaced_tag, pattern)
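Taken together, these converters are meant to be chained over raw text so that open-vocabulary spans collapse into a small tag set before tokenization. A minimal sketch (the converter list and its order are illustrative, not part of this commit):

converters = [EmailConverter(), TimeConverter(), AlphaSpanConverter(), DigitSpanConverter()]
sentence = '联系ab12@fudan.edu.cn 或在8:30前来电'
for conv in converters:
    sentence = conv.find_certain_span_and_replace(sentence)
# -> '联系<EML> 或在<TOC>前来电'

Note that the lookahead classes deliberately exclude '>', so spans inside already-inserted tags such as <EML> are not matched again.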
fastNLP/api/pipeline.py:
@@ -1,17 +1,25 @@
from fastNLP.api.processor import Processor


class Pipeline:
    """
    Pipeline takes a DataSet object as input, runs multiple processors sequentially, and
    outputs a DataSet object.
    """

    def __init__(self, processors=None):
        self.pipeline = []
        if isinstance(processors, list):
            for proc in processors:
                assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(proc))
            self.pipeline = processors

    def add_processor(self, processor):
        assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
        self.pipeline.append(processor)

    def process(self, dataset):
        assert len(self.pipeline) != 0, "You need to add some processor first."
        for proc in self.pipeline:
            dataset = proc(dataset)
@@ -19,4 +27,4 @@ class Pipeline:
        return dataset

    def __call__(self, *args, **kwargs):
        return self.process(*args, **kwargs)
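With the new initializer, a pipeline can be assembled in one call instead of repeated add_processor() calls. A hedged sketch (processor arguments are illustrative; the training script at the end of this commit shows a real instance):

pp = Pipeline([VocabProcessor('words'), SeqLenProcessor('word_seq', 'word_seq_origin_len')])
# equivalent to:
# pp = Pipeline()
# pp.add_processor(VocabProcessor('words'))
# pp.add_processor(SeqLenProcessor('word_seq', 'word_seq_origin_len'))
# dataset = pp(dataset)   # __call__ forwards to process()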
pos_tagger.py (deleted):
@@ -1,44 +0,0 @@
import pickle

import numpy as np

from fastNLP.core.dataset import DataSet
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.predictor import Predictor


class POS_tagger:
    def __init__(self):
        pass

    def predict(self, query):
        """
        :param query: List[str]
        :return answer: List[str]
        """
        # TODO: build a DataSet from the query
        pos_dataset = DataSet()
        pos_dataset["text_field"] = np.array(query)

        # load the pipeline and the model
        pipeline = self.load_pipeline("./xxxx")

        # run the pipeline on the DataSet
        pos_dataset = pipeline(pos_dataset)

        # load the model
        model = ModelLoader().load_pytorch("./xxx")

        # call the predictor
        predictor = Predictor()
        output = predictor.predict(model, pos_dataset)

        # TODO: convert to the final output
        return None

    @staticmethod
    def load_pipeline(path):
        with open(path, "r") as fp:
            pipeline = pickle.load(fp)
        return pipeline
fastNLP/api/processor.py:
@@ -1,7 +1,7 @@
from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary


class Processor:
    def __init__(self, field_name, new_added_field_name):
        self.field_name = field_name
@@ -10,15 +10,18 @@ class Processor:
        else:
            self.new_added_field_name = new_added_field_name

    def process(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return self.process(*args, **kwargs)


class FullSpaceToHalfSpaceProcessor(Processor):
    """Convert full-width characters to half-width ones, one character at a time.
    """

    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
                 change_space=True):
        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
@@ -64,11 +67,12 @@ class FullSpaceToHalfSpaceProcessor(Processor):
        if self.change_space:
            FHs += FH_SPACE
        self.convert_map = {k: v for k, v in FHs}

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name]
            new_sentence = [None] * len(sentence)
            for idx, char in enumerate(sentence):
                if char in self.convert_map:
                    char = self.convert_map[char]
@@ -98,7 +102,7 @@ class IndexerProcessor(Processor):
            index = [self.vocab.to_index(token) for token in tokens]
            ins[self.new_added_field_name] = index

        dataset.set_need_tensor(**{self.new_added_field_name: True})
        if self.delete_old_field:
            dataset.delete_field(self.field_name)
@@ -122,3 +126,16 @@ class VocabProcessor(Processor):
    def get_vocab(self):
        self.vocab.build_vocab()
        return self.vocab


class SeqLenProcessor(Processor):
    def __init__(self, field_name, new_added_field_name='seq_lens'):
        super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            length = len(ins[self.field_name])
            ins[self.new_added_field_name] = length
        dataset.set_need_tensor(**{self.new_added_field_name: True})
        return dataset
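The processors are designed to be applied to a DataSet in sequence, each one adding or replacing a field. A short sketch mirroring the order used in the training script below (assuming a dataset whose instances carry a tokenized 'words' field):

vocab_proc = VocabProcessor('words')
vocab_proc(dataset)                       # builds a Vocabulary over the 'words' field
indexer = IndexerProcessor(vocab_proc.get_vocab(), 'words', 'word_seq', delete_old_field=True)
indexer(dataset)                          # adds 'word_seq' with token indices
seq_len_proc = SeqLenProcessor('word_seq', 'word_seq_origin_len')
seq_len_proc(dataset)                     # adds per-instance lengths, marked as tensor fields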
fastNLP/core/batch.py:
@@ -1,5 +1,3 @@
from collections import defaultdict

import torch
@@ -68,4 +66,3 @@ class Batch(object):
        self.curidx = endidx
        return batch_x, batch_y
fastNLP/core/dataset.py:
@@ -1,23 +1,27 @@
import random
import sys, os

sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path

from collections import defaultdict
from copy import deepcopy

import numpy as np

from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.fieldarray import FieldArray

_READERS = {}


def construct_dataset(sentences):
    """Construct a data set from a list of sentences.

    :param sentences: list of str
    :return dataset: a DataSet object
    """
    dataset = DataSet()
    for sentence in sentences:
        instance = Instance()
        instance['raw_sentence'] = sentence
        dataset.append(instance)
    return dataset
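For example, construct_dataset turns raw strings into instances with a single 'raw_sentence' field (iteration follows the DataSetIter protocol defined below):

ds = construct_dataset(['第一 句 话', 'the second sentence'])
for ins in ds:
    print(ins['raw_sentence'])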
class DataSet(object):
    """A DataSet object is a list of Instance objects.
    """

    class DataSetIter(object):
        def __init__(self, dataset):
            self.dataset = dataset
@@ -34,13 +38,12 @@ class DataSet(object):
        def __setitem__(self, name, val):
            if name not in self.dataset:
                new_fields = [None] * len(self.dataset)
                self.dataset.add_field(name, new_fields)
            self.dataset[name][self.idx] = val

        def __repr__(self):
            return " ".join([repr(self.dataset[name][self.idx]) for name in self.dataset])

    def __init__(self, instance=None):
        self.field_arrays = {}
@@ -72,7 +75,7 @@ class DataSet(object):
            self.field_arrays[name].append(field)

    def add_field(self, name, fields):
        if len(self.field_arrays) != 0:
            assert len(self) == len(fields)
        self.field_arrays[name] = FieldArray(name, fields)
@@ -90,27 +93,10 @@ class DataSet(object):
        return len(field)

    def get_length(self):
        """The same as __len__
        """
        return len(self)

    def rename_field(self, old_name, new_name):
        """rename a field
@@ -118,7 +104,7 @@ class DataSet(object):
        if old_name in self.field_arrays:
            self.field_arrays[new_name] = self.field_arrays.pop(old_name)
        else:
            raise KeyError("{} is not a valid name. ".format(old_name))
        return self

    def set_is_target(self, **fields):
@@ -150,6 +136,7 @@ class DataSet(object):
                data = _READERS[name]().load(*args, **kwargs)
                self.extend(data)
                return self

            return _read
        else:
            return object.__getattribute__(self, name)
@@ -159,18 +146,21 @@ class DataSet(object):
        """decorator to add dataloader support
        """
        assert isinstance(method_name, str)

        def wrapper(read_cls):
            _READERS[method_name] = read_cls
            return read_cls

        return wrapper
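A hedged sketch of the reader registration above (the reader class and method name are hypothetical; the only requirements visible in this hunk are that set_reader takes a string and that the registered class's load() returns something DataSet.extend() accepts):

@DataSet.set_reader('read_lines')
class LineReader:
    def load(self, path):
        with open(path, encoding='utf-8') as f:
            return construct_dataset([line.strip() for line in f])

ds = DataSet()
ds.read_lines('corpus.txt')   # resolved via __getattribute__ and _READERS['read_lines']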
if __name__ == '__main__':
    from fastNLP.core.instance import Instance

    ins = Instance(test='test0')
    dataset = DataSet([ins])
    for _iter in dataset:
        print(_iter['test'])
        _iter['test'] = 'abc'
        print(_iter['test'])
    print(dataset.field_arrays)
fastNLP/core/instance.py:
@@ -1,4 +1,4 @@
import torch


class Instance(object):
    """An instance which consists of Fields is an example in the DataSet.
    """
@@ -35,4 +35,4 @@ class Instance(object):
        return self.add_field(name, field)

    def __repr__(self):
        return self.fields.__repr__()
fastNLP/loader/dataset_loader.py:
@@ -1,9 +1,9 @@
import os

from fastNLP.core.dataset import DataSet
from fastNLP.core.field import *
from fastNLP.core.instance import Instance
from fastNLP.loader.base_loader import BaseLoader


def convert_seq_dataset(data):
@@ -393,6 +393,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
                sent_words.append(token)
            pos_tag_examples.append([sent_words, sent_pos_tag])
            ner_examples.append([sent_words, sent_ner])
        # returns List[List[List[str], List[str]]]: each example is [words, tags]
        return pos_tag_examples, ner_examples

    def convert(self, data):
fastNLP/models/sequence_modeling.py:
@@ -44,6 +44,9 @@ class SeqLabeling(BaseModel):
        :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
                   If truth is not None, return loss, a scalar. Used in training.
        """
        assert word_seq.shape[0] == word_seq_origin_len.shape[0]
        if truth is not None:
            assert truth.shape == word_seq.shape

        self.mask = self.make_mask(word_seq, word_seq_origin_len)
        x = self.Embedding(word_seq)
@@ -80,7 +83,7 @@ class SeqLabeling(BaseModel):
        batch_size, max_len = x.size(0), x.size(1)
        mask = seq_mask(seq_len, max_len)
        mask = mask.byte().view(batch_size, max_len)
        mask = mask.to(x).float()
        return mask
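make_mask therefore yields a float mask of shape (batch_size, max_len) that is 1 on real tokens and 0 on padding. The same mask can be built with plain torch, assuming right-padded sequences (seq_mask itself is defined elsewhere in fastNLP.modules):

seq_len = torch.tensor([3, 1])
max_len = 4
mask = (torch.arange(max_len).unsqueeze(0) < seq_len.unsqueeze(1)).float()
# tensor([[1., 1., 1., 0.],
#         [1., 0., 0., 0.]])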
    def decode(self, x, pad=True):
@@ -130,6 +133,9 @@ class AdvSeqLabel(SeqLabeling):
        :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
                   If truth is not None, return loss, a scalar. Used in training.
        """
        word_seq = word_seq.long()
        word_seq_origin_len = word_seq_origin_len.long()
        truth = truth.long() if truth is not None else None  # truth may be None at inference time

        self.mask = self.make_mask(word_seq, word_seq_origin_len)
        batch_size = word_seq.size(0)
fastNLP/modules/decoder/CRF.py:
@@ -3,6 +3,7 @@ from torch import nn

from fastNLP.modules.utils import initial_parameter


def log_sum_exp(x, dim=-1):
    max_value, _ = x.max(dim=dim, keepdim=True)
    res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
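This is the standard max-shift identity logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), which keeps exp() from overflowing; a quick sanity check against torch's built-in (assuming the function returns res, possibly squeezed):

x = torch.tensor([[1000.0, 1000.0]])
log_sum_exp(x)                   # ~1000.6931, finite
torch.logsumexp(x, dim=-1)       # same value
torch.exp(x).sum(-1).log()       # inf: the naive form overflows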
@@ -20,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens):


class ConditionalRandomField(nn.Module):
    def __init__(self, tag_size, include_start_end_trans=False, initial_method=None):
        """
        :param tag_size: int, num of tags
        :param include_start_end_trans: bool, whether to include start/end tag
@@ -38,6 +39,7 @@ class ConditionalRandomField(nn.Module):
        # self.reset_parameter()
        initial_parameter(self, initial_method)

    def reset_parameter(self):
        nn.init.xavier_normal_(self.trans_m)
        if self.include_start_end_trans:
@@ -81,15 +83,15 @@ class ConditionalRandomField(nn.Module):
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

        # trans_score [L-1, B]
        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :]
        # emit_score [L, B]
        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask
        # score [L-1, B]
        score = trans_score + emit_score[:seq_len - 1, :]
        score = score.sum(0) + emit_score[-1]
        if self.include_start_end_trans:
            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
            last_idx = mask.long().sum(0)
            ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
            score += st_scores + ed_scores
        # return [B,]
@@ -120,14 +122,14 @@ class ConditionalRandomField(nn.Module):
        :return: scores, paths
        """
        batch_size, seq_len, n_tags = data.size()
        data = data.transpose(0, 1).data  # L, B, H
        mask = mask.transpose(0, 1).data.float()  # L, B

        # dp
        vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
        vscore = data[0]
        if self.include_start_end_trans:
            vscore += self.start_scores.view(1, -1)
        for i in range(1, seq_len):
            prev_score = vscore.view(batch_size, n_tags, 1)
            cur_score = data[i].view(batch_size, 1, n_tags)
@@ -145,15 +147,15 @@ class ConditionalRandomField(nn.Module):
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
        lens = (mask.long().sum(0) - 1)
        # idxes [L, B], batched idx from seq_len-1 to 0
        idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len

        ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
        ans_score, last_tags = vscore.max(1)
        ans[idxes[0], batch_idx] = last_tags
        for i in range(seq_len - 1):
            last_tags = vpath[idxes[i], batch_idx, last_tags]
            ans[idxes[i + 1], batch_idx] = last_tags
        if get_score:
            return ans_score, ans.transpose(0, 1)
        return ans.transpose(0, 1)
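For intuition, the gold-path score in the hunk above decomposes into pairwise transition scores plus per-step emission scores, both masked to each sequence's true length. A standalone re-computation under assumed shapes (tags and mask are [L, B], logits [L, B, n_tags]; illustrative only, not this module's API):

L, B, T = 4, 2, 5
logits = torch.randn(L, B, T)
tags = torch.randint(T, (L, B))
mask = torch.ones(L, B)
trans_m = torch.randn(T, T)
batch_idx = torch.arange(B)
seq_idx = torch.arange(L)

trans_score = trans_m[tags[:L - 1], tags[1:]] * mask[1:, :]                    # [L-1, B]
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask   # [L, B]
score = (trans_score + emit_score[:L - 1, :]).sum(0) + emit_score[-1]          # [B]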
pos_tag.cfg:
@@ -1,10 +1,12 @@
[train]
epochs = 5
batch_size = 2
pickle_path = "./save/"
validate = false
save_best_dev = true
model_saved_path = "./save/"

[model]
rnn_hidden_units = 100
word_emb_dim = 100
use_crf = true
train_pos_tag.py:
@@ -1,130 +1,88 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

import torch

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.trainer import Trainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader
from fastNLP.models.sequence_modeling import AdvSeqLabel

# cd into the file's dir if we are not already there
if len(os.path.dirname(__file__)) != 0:
    os.chdir(os.path.dirname(__file__))

cfgfile = './pos_tag.cfg'
datadir = "/home/zyfeng/fastnlp_0.2.0/test/data_for_tests/"
data_name = "people_daily_raw.txt"

pos_tag_data_path = os.path.join(datadir, data_name)
pickle_path = "save"


def train():
    # load config
    train_param = ConfigSection()
    model_param = ConfigSection()
    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
    print("config loaded")

    # Data Loader
    loader = PeopleDailyCorpusLoader()
    train_data, _ = loader.load(os.path.join(datadir, data_name))
    print("data loaded")

    dataset = DataSet()
    for data in train_data:
        instance = Instance()
        instance["words"] = data[0]
        instance["tag"] = data[1]
        dataset.append(instance)
    print("dataset transformed")

    # processor_1 = FullSpaceToHalfSpaceProcessor('words')
    # processor_1(dataset)

    word_vocab_proc = VocabProcessor('words')
    tag_vocab_proc = VocabProcessor("tag")
    word_vocab_proc(dataset)
    tag_vocab_proc(dataset)

    word_indexer = IndexerProcessor(word_vocab_proc.get_vocab(), 'words', 'word_seq', delete_old_field=True)
    word_indexer(dataset)
    tag_indexer = IndexerProcessor(tag_vocab_proc.get_vocab(), 'tag', 'truth', delete_old_field=True)
    tag_indexer(dataset)
    seq_len_proc = SeqLenProcessor("word_seq", "word_seq_origin_len")
    seq_len_proc(dataset)
    print("processors defined")

    # dataset.set_is_target(tag_ids=True)
    model_param["vocab_size"] = len(word_vocab_proc.get_vocab())
    model_param["num_classes"] = len(tag_vocab_proc.get_vocab())
    print("vocab_size={} num_classes={}".format(len(word_vocab_proc.get_vocab()), len(tag_vocab_proc.get_vocab())))

    # define a model
    model = AdvSeqLabel(model_param)

    # call trainer to train
    trainer = Trainer(**train_param.data)
    trainer.train(model, dataset)

    # save model & pipeline
    pp = Pipeline([word_vocab_proc, word_indexer, seq_len_proc])
    save_dict = {"pipeline": pp, "model": model}
    torch.save(save_dict, "model_pp.pkl")
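The saved dict is intended to be reloaded for inference; a hedged sketch (untested against this dev revision; field names follow the processors above):

save_dict = torch.load('model_pp.pkl')
pp, model = save_dict['pipeline'], save_dict['model']
# build a DataSet with a 'words' field as in train(), then:
# infer_data = pp(infer_data)   # adds 'word_seq' and 'word_seq_origin_len' for the model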
def test():
    pass


def infer():
    pass


if __name__ == "__main__":
    train()

"""
import argparse
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
@@ -139,3 +97,5 @@ if __name__ == "__main__":
else:
    print('no mode specified for model!')
    parser.print_help()
"""