@@ -63,7 +63,7 @@ class Inference(object):
        """
        Perform inference.
        :param network:
        :param data: multi-level lists of strings
        :param data: two-level lists of strings
        :return result: the model outputs
        """
        # transform strings into indices
@@ -97,7 +97,7 @@ class Inference(object):
    def prepare_input(self, data):
        """
        Transform three-level list of strings into that of index.
        Transform two-level list of strings into that of index.
        :param data:
            [
                [word_11, word_12, ...],
@@ -140,7 +140,7 @@ class SeqLabelInfer(Inference):
        mask = mask.byte().view(batch_size, max_len)
        y = network(x)
        prediction = network.prediction(y, mask)
        return torch.Tensor(prediction, required_grad=False)
        return torch.Tensor(prediction)

    def make_batch(self, iterator, data, use_cuda):
        return make_batch(iterator, data, use_cuda, output_length=True)
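Note on the `torch.Tensor` change above: the removed call passed `required_grad=False`, which is both a misspelling of `requires_grad` and, either way, not a keyword the `torch.Tensor` constructor accepts, so the call raised a TypeError. A minimal standalone sketch (not fastNLP code) of the corrected behaviour:

    import torch

    prediction = [[1, 2, 0], [3, 1, 4]]  # hypothetical tag indices from the decoder
    result = torch.Tensor(prediction)    # fine: new tensors default to requires_grad=False
    assert result.requires_grad is False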
@@ -149,7 +149,7 @@ class SeqLabelInfer(Inference):
        """
        Transform list of batch outputs into strings.
        :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
        :return:
        :return results: 2-D list of strings
        """
        results = []
        for batch in batch_outputs:
@@ -178,7 +178,7 @@ class ClassificationInfer(Inference):
        """
        Transform list of batch outputs into strings.
        :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
        :return:
        :return results: list of strings
        """
        results = []
        for batch_out in batch_outputs:
@@ -37,10 +37,6 @@ class BaseTester(object):
        else:
            self.model = network
        # no backward setting for model
        for param in network.parameters():
            param.requires_grad = False
        # turn on the testing mode; clean up the history
        self.mode(network, test=True)
        self.eval_history.clear()
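With the loop above removed, the tester no longer flips `requires_grad` off on every parameter (a change that would also have had to be undone before any further training). A common equivalent, assuming the intent is simply gradient-free evaluation (this sketch is not fastNLP code), is `eval()` plus `torch.no_grad()`:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)          # stand-in for the network under test
    model.eval()                     # fixes dropout/batch-norm behaviour
    with torch.no_grad():            # stops autograd from recording the forward pass
        output = model(torch.randn(8, 4))
    assert output.requires_grad is False
    assert all(p.requires_grad for p in model.parameters())  # params untouched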
@@ -112,6 +108,7 @@ class SeqLabelTester(BaseTester):
        super(SeqLabelTester, self).__init__(test_args)
        self.max_len = None
        self.mask = None
        self.seq_len = None
        self.batch_result = None

    def data_forward(self, network, inputs):
@@ -125,7 +122,7 @@ class SeqLabelTester(BaseTester):
        if torch.cuda.is_available() and self.use_cuda:
            mask = mask.cuda()
        self.mask = mask
        self.seq_len = seq_len
        y = network(x)
        return y
@@ -315,14 +315,8 @@ class ClassificationTrainer(BaseTrainer):
    def __init__(self, train_args):
        super(ClassificationTrainer, self).__init__(train_args)
        if "learn_rate" in train_args:
            self.learn_rate = train_args["learn_rate"]
        else:
            self.learn_rate = 1e-3
        if "momentum" in train_args:
            self.momentum = train_args["momentum"]
        else:
            self.momentum = 0.9
        self.learn_rate = train_args["learn_rate"]
        self.momentum = train_args["momentum"]
        self.iterator = None
        self.loss_func = None
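After this simplification, `learn_rate` and `momentum` are required keys: a `train_args` dict that omits them now raises KeyError instead of silently falling back to `1e-3` and `0.9`. Callers therefore need to supply both, as the updated test config in this PR does:

    train_args = {
        "epochs": 1,
        "batch_size": 10,
        "learn_rate": 1e-3,   # previously an implicit default
        "momentum": 0.9,      # previously an implicit default
    }

If optional keys were still wanted, `train_args.get("learn_rate", 1e-3)` would be the idiomatic one-liner rather than the removed if/else blocks.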
@@ -1,4 +1,4 @@
from fastNLP.core.inference import Inference
from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
@@ -10,14 +10,28 @@ Example:
        "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
"""
FastNLP_MODEL_COLLECTION = {
    "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
    "seq_label_model": {
        "url": "www.fudan.edu.cn",
        "class": "sequence_modeling.SeqLabeling",
        "pickle": "seq_label_model.pkl",
        "type": "seq_label"
    },
    "text_class_model": {
        "url": "www.fudan.edu.cn",
        "class": "cnn_text_classification.CNNText",
        "pickle": "text_class_model.pkl",
        "type": "text_class"
    }
}

CONFIG_FILE_NAME = "config"
SECTION_NAME = "text_class_model"


class FastNLP(object):
    """
    High-level interface for direct model inference.

    Usage:
    Example Usage:
        fastnlp = FastNLP()
        fastnlp.load("seq_label_model")
        text = "这是最好的基于深度学习的中文分词系统。"
@@ -35,6 +49,7 @@ class FastNLP(object):
        """
        self.model_dir = model_dir
        self.model = None
        self.infer_type = None  # "seq_label"/"text_class"

    def load(self, model_name):
        """
@@ -46,21 +61,21 @@ class FastNLP(object):
            raise ValueError("No FastNLP model named {}.".format(model_name))
        if not self.model_exist(model_dir=self.model_dir):
            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0])
            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])
        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1])
        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
        model_args = ConfigSection()
        # To do: customized config file for model init parameters
        ConfigLoader.load_config(self.model_dir + "config", {"POS_infer": model_args})
        ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})
        # Construct the model
        model = model_class(model_args)
        # To do: framework independent
        ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name][2])
        ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"])
        self.model = model
        self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]
        print("Model loaded. ")
@@ -71,12 +86,16 @@ class FastNLP(object):
        :return results:
        """
        infer = Inference(self.model_dir)
        infer = self._create_inference(self.model_dir)
        # string ---> 2-D list of string
        infer_input = self.string_to_list(raw_input)
        # 2-D list of string ---> list of strings
        results = infer.predict(self.model, infer_input)
        outputs = self.make_output(results)
        # list of strings ---> final answers
        outputs = self._make_output(results, infer_input)
        return outputs

    @staticmethod
@@ -95,6 +114,14 @@ class FastNLP(object):
            module = getattr(module, sub)
        return module

    def _create_inference(self, model_dir):
        if self.infer_type == "seq_label":
            return SeqLabelInfer(model_dir)
        elif self.infer_type == "text_class":
            return ClassificationInfer(model_dir)
        else:
            raise ValueError("fail to create inference instance")

    def _load(self, model_dir, model_name):
        # To do
        return 0
@@ -117,7 +144,6 @@ class FastNLP(object):
    def string_to_list(self, text, delimiter="\n"):
        """
        For word seg only, currently.
        This function is used to transform raw input to lists, which is done by DatasetLoader in training.
        Split text string into two-level lists.
        [
@@ -127,7 +153,7 @@ class FastNLP(object):
        ]
        :param text: string
        :param delimiter: str, character used to split text into sentences.
        :return data: three-level lists
        :return data: two-level lists
        """
        data = []
        sents = text.strip().split(delimiter)
@@ -136,38 +162,61 @@ class FastNLP(object):
            for ch in sent:
                characters.append(ch)
            data.append(characters)
        # To refactor: this is used in make_output
        self.data = data
        return data

    def make_output(self, results):
        """
        Transform model output into user-friendly contents.
        Example: In CWS, convert <BMES> labeling into segmented text.
        :param results:
        :return:
        """
        outputs = []
        for sent_char, sent_label in zip(self.data, results):
            words = []
            word = ""
            for char, label in zip(sent_char, sent_label):
                if label[0] == "B":
                    if word != "":
                        words.append(word)
                    word = char
                elif label[0] == "M":
                    word += char
                elif label[0] == "E":
                    word += char
                    words.append(word)
                    word = ""
                elif label[0] == "S":
                    if word != "":
                        words.append(word)
                        word = ""
                    words.append(char)
                else:
                    raise ValueError("invalid label")
            outputs.append(" ".join(words))

    def _make_output(self, results, infer_input):
        if self.infer_type == "seq_label":
            outputs = make_seq_label_output(results, infer_input)
        elif self.infer_type == "text_class":
            outputs = make_class_output(results, infer_input)
        else:
            raise ValueError("fail to make outputs with infer type {}".format(self.infer_type))
        return outputs


def make_seq_label_output(result, infer_input):
    """
    Transform model output into user-friendly contents.
    :param result: 1-D list of strings. (model output)
    :param infer_input: 2-D list of string (model input)
    :return outputs:
    """
    return result


def make_class_output(result, infer_input):
    return result


def interpret_word_seg_results(infer_input, results):
    """
    Transform model output into user-friendly contents.
    Example: In CWS, convert <BMES> labeling into segmented text.
    :param infer_input: 2-D list of string (model input)
    :param results: list of strings (model output)
    :return outputs: list of strings
    """
    outputs = []
    for sent_char, sent_label in zip(infer_input, results):
        words = []
        word = ""
        for char, label in zip(sent_char, sent_label):
            if label[0] == "B":
                if word != "":
                    words.append(word)
                word = char
            elif label[0] == "M":
                word += char
            elif label[0] == "E":
                word += char
                words.append(word)
                word = ""
            elif label[0] == "S":
                if word != "":
                    words.append(word)
                    word = ""
                words.append(char)
            else:
                raise ValueError("invalid label")
        outputs.append(" ".join(words))
    return outputs
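A worked example of `interpret_word_seg_results` with hypothetical model output:

    chars = [["这", "是", "中", "文", "分", "词"]]
    labels = [["S", "S", "B", "E", "B", "E"]]   # hypothetical BMES tags
    print(interpret_word_seg_results(chars, labels))
    # ['这 是 中文 分词']   ("S" emits a single-character word; "B".."E" closes a span)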
@@ -15,12 +15,17 @@ class CNNText(torch.nn.Module):
    Classification.'
    """

    def __init__(self, class_num=9,
                 kernel_nums=[100, 100, 100], kernel_sizes=[3, 4, 5],
                 embed_num=1000, embed_dim=300, pretrained_embed=None,
                 drop_prob=0.5):
    def __init__(self, args):
        super(CNNText, self).__init__()
        class_num = args["num_classes"]
        kernel_nums = [100, 100, 100]
        kernel_sizes = [3, 4, 5]
        embed_num = args["vocab_size"]
        embed_dim = 300
        pretrained_embed = None
        drop_prob = 0.5
        # no support for pre-trained embedding currently
        self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
        self.conv_pool = ConvMaxpool(
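Construction of `CNNText` changes accordingly: a single `args` mapping (e.g. a loaded ConfigSection) replaces the keyword arguments. The illustrative values below match the `[text_class_model]` config section added in this PR:

    model_args = {"num_classes": 18, "vocab_size": 867}
    cnn = CNNText(model_args)
    # The old style, cnn = CNNText(class_num=18, embed_num=867), is now a TypeError.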
@@ -56,3 +56,49 @@ class SeqLabeling(BaseModel):
        """
        tag_seq = self.Crf.viterbi_decode(x, mask)
        return tag_seq


class AdvSeqLabel(SeqLabeling):
    """
    Advanced Sequence Labeling Model
    """

    def __init__(self, args, emb=None):
        super(AdvSeqLabel, self).__init__(args)
        vocab_size = args["vocab_size"]
        word_emb_dim = args["word_emb_dim"]
        hidden_dim = args["rnn_hidden_units"]
        num_classes = args["num_classes"]
        self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True)
        self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
        self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
        self.relu = torch.nn.ReLU()
        self.drop = torch.nn.Dropout(0.3)
        self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)
        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)

    def forward(self, x):
        """
        :param x: LongTensor, [batch_size, max_len]
        :return y: [batch_size, max_len, tag_size]
        """
        batch_size = x.size(0)
        max_len = x.size(1)
        x = self.Embedding(x)
        # [batch_size, max_len, word_emb_dim]
        x = self.Rnn(x)
        # [batch_size, max_len, hidden_size * direction]
        x = x.contiguous()
        x = x.view(batch_size * max_len, -1)
        x = self.Linear1(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        x = self.drop(x)
        x = self.Linear2(x)
        x = x.view(batch_size, max_len, -1)
        # [batch_size, max_len, num_classes]
        return x
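For orientation, a shape walkthrough of `AdvSeqLabel.forward` under illustrative sizes taken from the test configs in this PR (word_emb_dim=50, rnn_hidden_units=100, num_classes=27):

    # x:               [2, 5]        LongTensor of word indices (batch_size=2, max_len=5)
    # after Embedding: [2, 5, 50]
    # after bi-LSTM:   [2, 5, 200]   hidden_dim * 2 directions
    # after view:      [10, 200]     flattened so BatchNorm1d gets 2-D input
    # after Linear1:   [10, 66]      (100 * 2) // 3
    # after Linear2:   [10, 27]      per-tag scores
    # final view:      [2, 5, 27]    ready for the CRF layer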
@@ -89,5 +89,20 @@ rnn_hidden_units = 100
rnn_layers = 1
rnn_bi_direction = true
word_emb_dim = 100
vocab_size = 52
num_classes = 22
vocab_size = 53
num_classes = 27

[text_class]
epochs = 1
batch_size = 10
pickle_path = "./data_for_tests/"
validate = false
save_best_dev = false
model_saved_path = "./data_for_tests/"
use_cuda = true
learn_rate = 1e-3
momentum = 0.9

[text_class_model]
vocab_size = 867
num_classes = 18
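These two new sections are consumed through `ConfigLoader`, as the updated `train()` in the text-classification test below does:

    from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

    train_args, model_args = ConfigSection(), ConfigSection()
    ConfigLoader.load_config("data_for_tests/config",
                             {"text_class": train_args, "text_class_model": model_args})
    # train_args now carries learn_rate/momentum, which ClassificationTrainer
    # requires since its hard-coded defaults were removed.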
@@ -123,6 +123,160 @@
张 S-q
) S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w
中 B-nt
共 M-nt
中 M-nt
央 E-nt
总 B-n
书 M-n
记 E-n
、 S-w
国 B-n
家 E-n
主 B-n
席 E-n
江 B-nr
泽 M-nr
民 E-nr
( S-w
一 B-t
九 M-t
九 M-t
七 M-t
年 E-t
十 B-t
二 M-t
月 E-t
三 B-t
十 M-t
一 M-t
日 E-t
) S-w
1 B-t
2 M-t
月 E-t
3 B-t
1 M-t
日 E-t
, S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w
迈 B-v
向 E-v
充 B-v
@@ -0,0 +1,137 @@
import _pickle
import os

import numpy as np
import torch

from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.models.sequence_modeling import AdvSeqLabel


class MyNERTrainer(SeqLabelTrainer):
    def __init__(self, train_args):
        super(MyNERTrainer, self).__init__(train_args)
        self.scheduler = None

    def define_optimizer(self):
        """
        override
        :return:
        """
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=3000, gamma=0.5)

    def update(self):
        """
        override
        :return:
        """
        self.optimizer.step()
        self.scheduler.step()

    def _create_validator(self, valid_args):
        return MyNERTester(valid_args)

    def best_eval_result(self, validator):
        accuracy = validator.metrics()
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            return True
        else:
            return False


class MyNERTester(SeqLabelTester):
    def __init__(self, test_args):
        super(MyNERTester, self).__init__(test_args)

    def _evaluate(self, prediction, batch_y, seq_len):
        """
        :param prediction: [batch_size, seq_len, num_classes]
        :param batch_y: [batch_size, seq_len]
        :param seq_len: [batch_size]
        :return:
        """
        summ = 0
        correct = 0
        _, indices = torch.max(prediction, 2)
        for p, y, l in zip(indices, batch_y, seq_len):
            summ += l
            correct += np.sum(p[:l].cpu().numpy() == y[:l].cpu().numpy())
        return float(correct / summ)

    def evaluate(self, predict, truth):
        return self._evaluate(predict, truth, self.seq_len)

    def metrics(self):
        return np.mean(self.eval_history)

    def show_matrices(self):
        return "dev accuracy={:.2f}".format(float(self.metrics()))


def embedding_process(emb_file, word_dict, emb_dim, emb_pkl):
    if os.path.exists(emb_pkl):
        with open(emb_pkl, "rb") as f:
            embedding_np = _pickle.load(f)
        return embedding_np
    with open(emb_file, "r", encoding="utf-8") as f:
        embedding_np = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim))
        for line in f:
            line = line.strip().split()
            if len(line) != emb_dim + 1:
                continue
            if line[0] in word_dict:
                embedding_np[word_dict[line[0]]] = [float(i) for i in line[1:]]
    with open(emb_pkl, "wb") as f:
        _pickle.dump(embedding_np, f)
    return embedding_np


def data_load(data_file):
    with open(data_file, "r", encoding="utf-8") as f:
        all_data = []
        sent = []
        label = []
        for line in f:
            line = line.strip().split()
            if len(line) > 1:
                sent.append(line[0])
                label.append(line[1])
            else:
                all_data.append([sent, label])
                sent = []
                label = []
    return all_data


data_path = "data_for_tests/people.txt"
pick_path = "data_for_tests/"
emb_path = "data_for_tests/emb50.txt"
save_path = "data_for_tests/"

if __name__ == "__main__":
    data = data_load(data_path)
    p = POSPreprocess(data, pickle_path=pick_path, train_dev_split=0.3)
    # emb = embedding_process(emb_path, p.word2index, 50, os.path.join(pick_path, "embedding.pkl"))
    emb = None
    args = {"epochs": 20,
            "batch_size": 1,
            "pickle_path": pick_path,
            "validate": True,
            "save_best_dev": True,
            "model_saved_path": save_path,
            "use_cuda": True,
            "vocab_size": p.vocab_size,
            "num_classes": p.num_classes,
            "word_emb_dim": 50,
            "rnn_hidden_units": 100
            }
    # emb = torch.Tensor(emb).float().cuda()
    networks = AdvSeqLabel(args, emb)
    trainer = MyNERTrainer(args)
    trainer.train(network=networks)
    print("Training finished!")
@@ -0,0 +1,129 @@
import _pickle
import os

import torch

from fastNLP.core.inference import SeqLabelInfer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import AdvSeqLabel


class Decode(SeqLabelTrainer):
    def __init__(self, args):
        super(Decode, self).__init__(args)

    def decoder(self, network, sents, model_path):
        self.model = network
        self.model.load_state_dict(torch.load(model_path))
        out_put = []
        self.mode(network, test=True)
        for batch_x in sents:
            prediction = self.data_forward(self.model, batch_x)
            seq_tag = self.model.prediction(prediction, batch_x[1])
            out_put.append(list(seq_tag)[0])
        return out_put


def process_sent(sents, word2id):
    sents_num = []
    for s in sents:
        sent_num = []
        for c in s:
            if c in word2id:
                sent_num.append(word2id[c])
            else:
                sent_num.append(word2id["<unk>"])
        sents_num.append(([sent_num], [len(sent_num)]))  # batch_size is 1
    return sents_num


def process_tag(sents, tags, id2class):
    Tags = []
    for ttt in tags:
        Tags.append([id2class[t] for t in ttt])
    Segs = []
    PosNers = []
    for sent, tag in zip(sents, tags):
        word__ = []
        lll__ = []
        word_1 = ""  # word accumulator; initialized so an "E"/"M" tag at sentence start cannot hit an undefined name
        for c, t in zip(sent, tag):
            t = id2class[t]
            l = t.split("-")
            split_ = l[0]
            pn = l[1]
            if split_ == "S":
                word__.append(c)
                lll__.append(pn)
                word_1 = ""
            elif split_ == "E":
                word_1 += c
                word__.append(word_1)
                lll__.append(pn)
                word_1 = ""
            elif split_ == "B":
                word_1 = ""
                word_1 += c
            else:
                word_1 += c
        Segs.append(word__)
        PosNers.append(lll__)
    return Segs, PosNers
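A worked example of `process_tag` (the id-to-tag mapping below is hypothetical; real mappings come from id2class.pkl):

    id2class_demo = {0: "B-nt", 1: "E-nt", 2: "S-w"}
    segs, pos_ners = process_tag(["中央、"], [[0, 1, 2]], id2class_demo)
    # segs     -> [['中央', '、']]
    # pos_ners -> [['nt', 'w']]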
pickle_path = "data_for_tests/"
model_path = "data_for_tests/model_best_dev.pkl"

if __name__ == "__main__":
    with open(os.path.join(pickle_path, "id2word.pkl"), "rb") as f:
        id2word = _pickle.load(f)
    with open(os.path.join(pickle_path, "word2id.pkl"), "rb") as f:
        word2id = _pickle.load(f)
    with open(os.path.join(pickle_path, "id2class.pkl"), "rb") as f:
        id2class = _pickle.load(f)
    sent = ["中共中央总书记、国家主席江泽民",
            "逆向处理输入序列并返回逆序后的序列"]  # here is input
    args = {"epochs": 1,
            "batch_size": 1,
            "pickle_path": "data_for_tests/",
            "validate": True,
            "save_best_dev": True,
            "model_saved_path": "data_for_tests/",
            "use_cuda": False,
            "vocab_size": len(word2id),
            "num_classes": len(id2class),
            "word_emb_dim": 50,
            "rnn_hidden_units": 100,
            }
    """
    network = AdvSeqLabel(args, None)
    decoder_ = Decode(args)
    tags_num = decoder_.decoder(network, process_sent(sent, word2id), model_path=model_path)
    output_seg, output_pn = process_tag(sent, tags_num, id2class)  # here is output
    print(output_seg)
    print(output_pn)
    """
    # Define the same model
    model = AdvSeqLabel(args, None)
    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, "./data_for_tests/model_best_dev.pkl")
    print("model loaded!")
    # Inference interface
    infer = SeqLabelInfer(pickle_path)
    sent = [[ch for ch in s] for s in sent]
    results = infer.predict(model, sent)
    for res in results:
        print(res)
    print("Inference finished!")
@@ -112,5 +112,5 @@ def train_and_test():
if __name__ == "__main__":
    # train_and_test()
    infer()
    train_and_test()
    # infer()
@@ -1,9 +1,18 @@
from fastNLP.fastnlp import FastNLP


def foo():
def word_seg():
    nlp = FastNLP("./data_for_tests/")
    nlp.load("zh_pos_tag_model")
    nlp.load("seq_label_model")
    text = "这是最好的基于深度学习的中文分词系统。"
    result = nlp.run(text)
    print(result)
    print("FastNLP finished!")


def text_class():
    nlp = FastNLP("./data_for_tests/")
    nlp.load("text_class_model")
    text = "这是最好的基于深度学习的中文分词系统。"
    result = nlp.run(text)
    print(result)
@@ -11,4 +20,4 @@ def foo():
if __name__ == "__main__":
    foo()
    text_class()
@@ -5,6 +5,7 @@ import os
from fastNLP.core.inference import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.preprocess import ClassPreprocess
@@ -29,9 +30,13 @@ def infer():
    print("vocabulary size:", vocab_size)
    print("number of classes:", n_classes)
    model_args = ConfigSection()
    ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args})
    # construct model
    print("Building model...")
    cnn = CNNText(class_num=n_classes, embed_num=vocab_size)
    cnn = CNNText(model_args)
    # Dump trained parameters into the model
    ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
    print("model loaded!")
@@ -42,6 +47,9 @@ def infer():
def train():
    train_args, model_args = ConfigSection(), ConfigSection()
    ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args})
    # load dataset
    print("Loading data...")
    ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
@@ -56,19 +64,11 @@ def train():
    # construct model
    print("Building model...")
    cnn = CNNText(class_num=n_classes, embed_num=vocab_size)
    cnn = CNNText(model_args)
    # train
    print("Training...")
    train_args = {
        "epochs": 1,
        "batch_size": 10,
        "pickle_path": data_dir,
        "validate": False,
        "save_best_dev": False,
        "model_saved_path": "./data_for_tests/",
        "use_cuda": True
    }
    trainer = ClassificationTrainer(train_args)
    trainer.train(cnn)