- Action collects shared operations: data_forward, mode, pad, make_batch - Trainer and Tester receives Action as a parameter - seq_labeling works in such settingtags/v0.1.0
| @@ -1,16 +1,129 @@ | |||
| """ | |||
| This file defines Action(s) and sample methods. | |||
| """ | |||
| from collections import Counter | |||
| import torch | |||
| import numpy as np | |||
| import _pickle | |||
| class Action(object): | |||
| """ | |||
| base class for Trainer and Tester | |||
| Operations shared by Trainer, Tester, and Inference. | |||
| This is designed for reducing replicate codes. | |||
| - prepare_input: data preparation before a forward pass. | |||
| - make_batch: produce a min-batch of data. @staticmethod | |||
| - pad: padding method used in sequence modeling. @staticmethod | |||
| - mode: change network mode for either train or test. (for PyTorch) @staticmethod | |||
| - data_forward: a forward pass of the network. | |||
| The base Action shall define operations shared by as much task-specific Actions as possible. | |||
| """ | |||
| def __init__(self): | |||
| super(Action, self).__init__() | |||
| @staticmethod | |||
| def make_batch(iterator, data, output_length=True): | |||
| """ | |||
| 1. Perform batching from data and produce a batch of training data. | |||
| 2. Add padding. | |||
| :param iterator: an iterator, (object that implements __next__ method) which returns the next sample. | |||
| :param data: list. Each entry is a sample, which is also a list of features and label(s). | |||
| E.g. | |||
| [ | |||
| [[word_11, word_12, word_13], [label_11. label_12]], # sample 1 | |||
| [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | |||
| ... | |||
| ] | |||
| :param output_length: whether to output the original length of the sequence before padding. | |||
| :return (batch_x, seq_len): tuple of two elements, if output_length is true. | |||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||
| seq_len: list. The length of the pre-padded sequence, if output_length is True. | |||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||
| return batch_x and batch_y, if output_length is False | |||
| """ | |||
| indices = next(iterator) | |||
| batch = [data[idx] for idx in indices] | |||
| batch_x = [sample[0] for sample in batch] | |||
| batch_y = [sample[1] for sample in batch] | |||
| batch_x_pad = Action.pad(batch_x) | |||
| batch_y_pad = Action.pad(batch_y) | |||
| if output_length: | |||
| seq_len = [len(x) for x in batch_x] | |||
| return (batch_x_pad, seq_len), batch_y_pad | |||
| else: | |||
| return batch_x_pad, batch_y_pad | |||
| @staticmethod | |||
| def pad(batch, fill=0): | |||
| """ | |||
| Pad a batch of samples to maximum length of this batch. | |||
| :param batch: list of list | |||
| :param fill: word index to pad, default 0. | |||
| :return: a padded batch | |||
| """ | |||
| max_length = max([len(x) for x in batch]) | |||
| for idx, sample in enumerate(batch): | |||
| if len(sample) < max_length: | |||
| batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||
| return batch | |||
| @staticmethod | |||
| def mode(model, test=False): | |||
| """ | |||
| Train mode or Test mode. This is for PyTorch currently. | |||
| :param model: | |||
| :param test: | |||
| """ | |||
| if test: | |||
| model.eval() | |||
| else: | |||
| model.train() | |||
| def data_forward(self, network, x): | |||
| """ | |||
| Forward pass of the data. | |||
| :param network: a model | |||
| :param x: input feature matrix and label vector | |||
| :return: output by the models | |||
| For PyTorch, just do "network(*x)" | |||
| """ | |||
| raise NotImplementedError | |||
| class SeqLabelAction(Action): | |||
| def __init__(self, action_args): | |||
| """ | |||
| Define task-specific member variables. | |||
| :param action_args: | |||
| """ | |||
| super(SeqLabelAction, self).__init__() | |||
| self.max_len = None | |||
| self.mask = None | |||
| self.best_accuracy = 0.0 | |||
| self.use_cuda = action_args["use_cuda"] | |||
| self.seq_len = None | |||
| self.batch_size = None | |||
| def data_forward(self, network, inputs): | |||
| # unpack the returned value from make_batch | |||
| if isinstance(inputs, tuple): | |||
| x = inputs[0] | |||
| self.seq_len = inputs[1] | |||
| else: | |||
| x = inputs | |||
| x = torch.Tensor(x).long() | |||
| if torch.cuda.is_available() and self.use_cuda: | |||
| x = x.cuda() | |||
| self.batch_size = x.size(0) | |||
| self.max_len = x.size(1) | |||
| y = network(x) | |||
| return y | |||
| def k_means_1d(x, k, max_iter=100): | |||
| """ | |||
| @@ -11,11 +11,12 @@ from fastNLP.core.action import RandomSampler, Batchifier | |||
| class BaseTester(Action): | |||
| """docstring for Tester""" | |||
| def __init__(self, test_args): | |||
| def __init__(self, test_args, action): | |||
| """ | |||
| :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||
| """ | |||
| super(BaseTester, self).__init__() | |||
| self.action = action | |||
| self.validate_in_training = test_args["validate_in_training"] | |||
| self.save_dev_data = None | |||
| self.save_output = test_args["save_output"] | |||
| @@ -38,18 +39,21 @@ class BaseTester(Action): | |||
| self.model = network | |||
| # turn on the testing mode; clean up the history | |||
| self.mode(network, test=True) | |||
| self.action.mode(network, test=True) | |||
| self.eval_history.clear() | |||
| self.batch_output.clear() | |||
| dev_data = self.prepare_input(self.pickle_path) | |||
| self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True)) | |||
| iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True)) | |||
| num_iter = len(dev_data) // self.batch_size | |||
| for step in range(num_iter): | |||
| batch_x, batch_y = self.make_batch(dev_data) | |||
| batch_x, batch_y = self.action.make_batch(iterator, dev_data) | |||
| prediction = self.action.data_forward(network, batch_x) | |||
| prediction = self.data_forward(network, batch_x) | |||
| eval_results = self.evaluate(prediction, batch_y) | |||
| if self.save_output: | |||
| @@ -64,53 +68,10 @@ class BaseTester(Action): | |||
| :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s). | |||
| """ | |||
| if self.save_dev_data is None: | |||
| data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb")) | |||
| data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) | |||
| self.save_dev_data = data_dev | |||
| return self.save_dev_data | |||
| def make_batch(self, data, output_length=True): | |||
| """ | |||
| 1. Perform batching from data and produce a batch of training data. | |||
| 2. Add padding. | |||
| :param data: list. Each entry is a sample, which is also a list of features and label(s). | |||
| E.g. | |||
| [ | |||
| [[word_11, word_12, word_13], [label_11. label_12]], # sample 1 | |||
| [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | |||
| ... | |||
| ] | |||
| :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||
| """ | |||
| indices = next(self.iterator) | |||
| batch = [data[idx] for idx in indices] | |||
| batch_x = [sample[0] for sample in batch] | |||
| batch_y = [sample[1] for sample in batch] | |||
| batch_x_pad = self.pad(batch_x) | |||
| batch_y_pad = self.pad(batch_y) | |||
| if output_length: | |||
| seq_len = [len(x) for x in batch_x] | |||
| return (batch_x_pad, seq_len), batch_y_pad | |||
| else: | |||
| return batch_x_pad, batch_y_pad | |||
| @staticmethod | |||
| def pad(batch, fill=0): | |||
| """ | |||
| Pad a batch of samples to maximum length. | |||
| :param batch: list of list | |||
| :param fill: word index to pad, default 0. | |||
| :return: a padded batch | |||
| """ | |||
| max_length = max([len(x) for x in batch]) | |||
| for idx, sample in enumerate(batch): | |||
| if len(sample) < max_length: | |||
| batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||
| return batch | |||
| def data_forward(self, network, data): | |||
| raise NotImplementedError | |||
| def evaluate(self, predict, truth): | |||
| raise NotImplementedError | |||
| @@ -118,14 +79,6 @@ class BaseTester(Action): | |||
| def metrics(self): | |||
| raise NotImplementedError | |||
| def mode(self, model, test=True): | |||
| """TODO: combine this function with Trainer ?? """ | |||
| if test: | |||
| model.eval() | |||
| else: | |||
| model.train() | |||
| self.eval_history.clear() | |||
| def show_matrices(self): | |||
| """ | |||
| This is called by Trainer to print evaluation on dev set. | |||
| @@ -139,43 +92,21 @@ class POSTester(BaseTester): | |||
| Tester for sequence labeling. | |||
| """ | |||
| def __init__(self, test_args): | |||
| def __init__(self, test_args, action): | |||
| """ | |||
| :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||
| """ | |||
| super(POSTester, self).__init__(test_args) | |||
| super(POSTester, self).__init__(test_args, action) | |||
| self.max_len = None | |||
| self.mask = None | |||
| self.batch_result = None | |||
| def data_forward(self, network, inputs): | |||
| """TODO: combine with Trainer | |||
| :param network: the PyTorch model | |||
| :param x: list of list, [batch_size, max_len] | |||
| :return y: [batch_size, num_classes] | |||
| """ | |||
| # unpack the returned value from make_batch | |||
| if isinstance(inputs, tuple): | |||
| x = inputs[0] | |||
| self.seq_len = inputs[1] | |||
| else: | |||
| x = inputs | |||
| x = torch.Tensor(x).long() | |||
| if torch.cuda.is_available() and self.use_cuda: | |||
| x = x.cuda() | |||
| self.batch_size = x.size(0) | |||
| self.max_len = x.size(1) | |||
| y = network(x) | |||
| return y | |||
| def evaluate(self, predict, truth): | |||
| truth = torch.Tensor(truth) | |||
| if torch.cuda.is_available() and self.use_cuda: | |||
| truth = truth.cuda() | |||
| loss = self.model.loss(predict, truth, self.seq_len) / self.batch_size | |||
| prediction = self.model.prediction(predict, self.seq_len) | |||
| loss = self.model.loss(predict, truth, self.action.seq_len) / self.batch_size | |||
| prediction = self.model.prediction(predict, self.action.seq_len) | |||
| results = torch.Tensor(prediction).view(-1,) | |||
| if torch.cuda.is_available() and self.use_cuda: | |||
| results = results.cuda() | |||
| @@ -18,17 +18,15 @@ class BaseTrainer(Action): | |||
| Trainer receives a model and data, and then performs training. | |||
| Subclasses must implement the following abstract methods: | |||
| - prepare_input | |||
| - mode | |||
| - define_optimizer | |||
| - data_forward | |||
| - grad_backward | |||
| - get_loss | |||
| """ | |||
| def __init__(self, train_args): | |||
| def __init__(self, train_args, action): | |||
| """ | |||
| :param train_args: dict of (key, value), or dict-like object. key is str. | |||
| :param action: an Action object that wrap most operations shared by Trainer, Tester, and Inference. | |||
| The base trainer requires the following keys: | |||
| - epochs: int, the number of epochs in training | |||
| @@ -37,6 +35,7 @@ class BaseTrainer(Action): | |||
| - pickle_path: str, the path to pickle files for pre-processing | |||
| """ | |||
| super(BaseTrainer, self).__init__() | |||
| self.action = action | |||
| self.n_epochs = train_args["epochs"] | |||
| self.batch_size = train_args["batch_size"] | |||
| self.pickle_path = train_args["pickle_path"] | |||
| @@ -72,14 +71,14 @@ class BaseTrainer(Action): | |||
| else: | |||
| self.model = network | |||
| data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) | |||
| data_train = self.prepare_input(self.pickle_path) | |||
| # define tester over dev data | |||
| # TODO: more flexible | |||
| valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||
| default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||
| "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||
| "use_cuda": self.use_cuda} | |||
| validator = POSTester(valid_args) | |||
| validator = POSTester(default_valid_args, self.action) | |||
| # main training epochs | |||
| iterations = len(data_train) // self.batch_size | |||
| @@ -88,14 +87,14 @@ class BaseTrainer(Action): | |||
| for epoch in range(1, self.n_epochs + 1): | |||
| # turn on network training mode; define optimizer; prepare batch iterator | |||
| self.mode(test=False) | |||
| self.iterator = iter(Batchifier(BucketSampler(data_train), self.batch_size, drop_last=True)) | |||
| self.action.mode(self.model, test=False) | |||
| iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) | |||
| # training iterations in one epoch | |||
| for step in range(iterations): | |||
| batch_x, batch_y = self.make_batch(data_train) | |||
| batch_x, batch_y = self.action.make_batch(iterator, data_train) | |||
| prediction = self.data_forward(network, batch_x) | |||
| prediction = self.action.data_forward(network, batch_x) | |||
| loss = self.get_loss(prediction, batch_y) | |||
| self.grad_backward(loss) | |||
| @@ -105,8 +104,6 @@ class BaseTrainer(Action): | |||
| print("[epoch {} step {}] train loss={:.2f}".format(epoch, step, loss.data)) | |||
| if self.validate: | |||
| if data_dev is None: | |||
| raise RuntimeError("No validation data provided.") | |||
| validator.test(network) | |||
| if self.save_best_dev and self.best_eval_result(validator): | |||
| @@ -118,19 +115,13 @@ class BaseTrainer(Action): | |||
| # finish training | |||
| def prepare_input(self, data_path): | |||
| data_train = _pickle.load(open(data_path + "data_train.pkl", "rb")) | |||
| data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) | |||
| data_test = _pickle.load(open(data_path + "data_test.pkl", "rb")) | |||
| embedding = _pickle.load(open(data_path + "embedding.pkl", "rb")) | |||
| return data_train, data_dev, data_test, embedding | |||
| def mode(self, test=False): | |||
| def prepare_input(self, pickle_path): | |||
| """ | |||
| Tell the network to be trained or not. | |||
| :param test: bool | |||
| This is reserved for task-specific processing. | |||
| :param data_path: | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| return _pickle.load(open(pickle_path + "/data_train.pkl", "rb")) | |||
| def define_optimizer(self): | |||
| """ | |||
| @@ -146,17 +137,6 @@ class BaseTrainer(Action): | |||
| """ | |||
| raise NotImplementedError | |||
| def data_forward(self, network, x): | |||
| """ | |||
| Forward pass of the data. | |||
| :param network: a model | |||
| :param x: input feature matrix and label vector | |||
| :return: output by the models | |||
| For PyTorch, just do "network(*x)" | |||
| """ | |||
| raise NotImplementedError | |||
| def grad_backward(self, loss): | |||
| """ | |||
| Compute gradient with link rules. | |||
| @@ -187,50 +167,6 @@ class BaseTrainer(Action): | |||
| """ | |||
| raise NotImplementedError | |||
| def make_batch(self, data, output_length=True): | |||
| """ | |||
| 1. Perform batching from data and produce a batch of training data. | |||
| 2. Add padding. | |||
| :param data: list. Each entry is a sample, which is also a list of features and label(s). | |||
| E.g. | |||
| [ | |||
| [[word_11, word_12, word_13], [label_11. label_12]], # sample 1 | |||
| [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | |||
| ... | |||
| ] | |||
| :return (batch_x, seq_len): tuple of two elements, if output_length is true. | |||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||
| seq_len: list. The length of the pre-padded sequence, if output_length is True. | |||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||
| return batch_x and batch_y, if output_length is False | |||
| """ | |||
| indices = next(self.iterator) | |||
| batch = [data[idx] for idx in indices] | |||
| batch_x = [sample[0] for sample in batch] | |||
| batch_y = [sample[1] for sample in batch] | |||
| batch_x_pad = self.pad(batch_x) | |||
| batch_y_pad = self.pad(batch_y) | |||
| if output_length: | |||
| seq_len = [len(x) for x in batch_x] | |||
| return (batch_x_pad, seq_len), batch_y_pad | |||
| else: | |||
| return batch_x_pad, batch_y_pad | |||
| @staticmethod | |||
| def pad(batch, fill=0): | |||
| """ | |||
| Pad a batch of samples to maximum length. | |||
| :param batch: list of list | |||
| :param fill: word index to pad, default 0. | |||
| :return: a padded batch | |||
| """ | |||
| max_length = max([len(x) for x in batch]) | |||
| for idx, sample in enumerate(batch): | |||
| if len(sample) < max_length: | |||
| batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||
| return batch | |||
| def best_eval_result(self, validator): | |||
| """ | |||
| :param validator: a Tester instance | |||
| @@ -287,48 +223,14 @@ class POSTrainer(BaseTrainer): | |||
| Trainer for Sequence Modeling | |||
| """ | |||
| def __init__(self, train_args): | |||
| super(POSTrainer, self).__init__(train_args) | |||
| def __init__(self, train_args, action): | |||
| super(POSTrainer, self).__init__(train_args, action) | |||
| self.vocab_size = train_args["vocab_size"] | |||
| self.num_classes = train_args["num_classes"] | |||
| self.max_len = None | |||
| self.mask = None | |||
| self.best_accuracy = 0.0 | |||
| def prepare_input(self, data_path): | |||
| data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||
| data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||
| return data_train, data_dev, 0, 1 | |||
| def data_forward(self, network, inputs): | |||
| """ | |||
| :param network: the PyTorch model | |||
| :param inputs: list of list, [batch_size, max_len], | |||
| or tuple of (batch_x, seq_len), batch_x == [batch_size, max_len] | |||
| :return y: [batch_size, max_len, tag_size] | |||
| """ | |||
| # unpack the returned value from make_batch | |||
| if isinstance(inputs, tuple): | |||
| x = inputs[0] | |||
| self.seq_len = inputs[1] | |||
| else: | |||
| x = inputs | |||
| x = torch.Tensor(x).long() | |||
| if torch.cuda.is_available() and self.use_cuda: | |||
| x = x.cuda() | |||
| self.batch_size = x.size(0) | |||
| self.max_len = x.size(1) | |||
| y = network(x) | |||
| return y | |||
| def mode(self, test=False): | |||
| if test: | |||
| self.model.eval() | |||
| else: | |||
| self.model.train() | |||
| def define_optimizer(self): | |||
| self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) | |||
| @@ -349,14 +251,13 @@ class POSTrainer(BaseTrainer): | |||
| truth = torch.Tensor(truth) | |||
| if torch.cuda.is_available() and self.use_cuda: | |||
| truth = truth.cuda() | |||
| assert truth.shape == (self.batch_size, self.max_len) | |||
| assert truth.shape == (self.batch_size, self.action.max_len) | |||
| if self.loss_func is None: | |||
| if hasattr(self.model, "loss"): | |||
| self.loss_func = self.model.loss | |||
| else: | |||
| self.define_loss() | |||
| loss = self.loss_func(predict, truth, self.seq_len) | |||
| # print("loss={:.2f}".format(loss.data)) | |||
| loss = self.loss_func(predict, truth, self.action.seq_len) | |||
| return loss | |||
| def best_eval_result(self, validator): | |||
| @@ -367,36 +268,6 @@ class POSTrainer(BaseTrainer): | |||
| else: | |||
| return False | |||
| def make_batch(self, data, output_length=True): | |||
| """ | |||
| 1. Perform batching from data and produce a batch of training data. | |||
| 2. Add padding. | |||
| :param data: list. Each entry is a sample, which is also a list of features and label(s). | |||
| E.g. | |||
| [ | |||
| [[word_11, word_12, word_13], [label_11. label_12]], # sample 1 | |||
| [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | |||
| ... | |||
| ] | |||
| :return (batch_x, seq_len): tuple of two elements, if output_length is true. | |||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||
| seq_len: list. The length of the pre-padded sequence, if output_length is True. | |||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||
| return batch_x and batch_y, if output_length is False | |||
| """ | |||
| indices = next(self.iterator) | |||
| batch = [data[idx] for idx in indices] | |||
| batch_x = [sample[0] for sample in batch] | |||
| batch_y = [sample[1] for sample in batch] | |||
| batch_x_pad = self.pad(batch_x) | |||
| batch_y_pad = self.pad(batch_y) | |||
| if output_length: | |||
| seq_len = [len(x) for x in batch_x] | |||
| return (batch_x_pad, seq_len), batch_y_pad | |||
| else: | |||
| return batch_x_pad, batch_y_pad | |||
| class LanguageModelTrainer(BaseTrainer): | |||
| """ | |||
| @@ -2,6 +2,7 @@ import sys | |||
| sys.path.append("..") | |||
| from fastNLP.core.action import SeqLabelAction | |||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.core.trainer import POSTrainer | |||
| from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | |||
| @@ -57,7 +58,7 @@ def infer(): | |||
| print("Inference finished!") | |||
| def train_test(): | |||
| def train_and_test(): | |||
| # Config Loader | |||
| train_args = ConfigSection() | |||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||
| @@ -67,12 +68,14 @@ def train_test(): | |||
| train_data = pos_loader.load_lines() | |||
| # Preprocessor | |||
| p = POSPreprocess(train_data, pickle_path) | |||
| p = POSPreprocess(train_data, pickle_path, train_dev_split=0.5) | |||
| train_args["vocab_size"] = p.vocab_size | |||
| train_args["num_classes"] = p.num_classes | |||
| action = SeqLabelAction(train_args) | |||
| # Trainer | |||
| trainer = POSTrainer(train_args) | |||
| trainer = POSTrainer(train_args, action) | |||
| # Model | |||
| model = SeqLabeling(train_args) | |||
| @@ -100,7 +103,7 @@ def train_test(): | |||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
| # Tester | |||
| tester = POSTester(test_args) | |||
| tester = POSTester(test_args, action) | |||
| # Start testing | |||
| tester.test(model) | |||
| @@ -111,5 +114,5 @@ def train_test(): | |||
| if __name__ == "__main__": | |||
| train_test() | |||
| # infer() | |||
| train_and_test() | |||