| @@ -1,3 +1,4 @@ | |||
| class Action(object): | |||
| """ | |||
| base class for Trainer and Tester | |||
| @@ -14,6 +14,7 @@ class Tester(Action): | |||
| self.test_args = test_args | |||
| self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | |||
| self.mean_loss = None | |||
| self.output = None | |||
| def test(self, network, data): | |||
| # transform into network input and label | |||
| @@ -22,6 +23,7 @@ class Tester(Action): | |||
| # split into batches by self.batch_size | |||
| iterations, test_batch_generator = self.batchify(X, Y) | |||
| batch_output = list() | |||
| loss_history = list() | |||
| # turn on the testing mode of the network | |||
| network.mode(test=True) | |||
| @@ -31,6 +33,7 @@ class Tester(Action): | |||
| # forward pass from tests input to predicted output | |||
| prediction = network.data_forward(batch_x) | |||
| batch_output.append(prediction) | |||
| # get the loss | |||
| loss = network.loss(batch_y, prediction) | |||
| @@ -39,7 +42,16 @@ class Tester(Action): | |||
| self.log(self.make_log(step, loss)) | |||
| self.mean_loss = np.mean(np.array(loss_history)) | |||
| self.output = self.make_output(batch_output) | |||
@property
def loss(self):
    """Mean loss over the most recent `test` run (None before any run)."""
    return self.mean_loss
@property
def result(self):
    """Full prediction assembled by the most recent `test` run."""
    return self.output
def make_output(self, batch_output):
    """Construct the full prediction from the per-batch outputs.

    Subclasses must override: the base class does not know how batch
    outputs compose into a single result.
    """
    raise NotImplementedError
| @@ -1,4 +1,5 @@ | |||
| from action.action import Action | |||
| from action.tester import Tester | |||
| class Trainer(Action): | |||
| @@ -6,9 +7,54 @@ class Trainer(Action): | |||
| Trainer for common training logic of all models | |||
| """ | |||
def __init__(self, train_args):
    """Store training configuration.

    :param train_args: namespace-like object (e.g. argparse.Namespace)
        exposing ``__dict__`` and an ``epochs`` attribute.
    """
    super(Trainer, self).__init__()
    self.train_args = train_args
    # .items(), not the Python-2-only .iteritems(): the rest of this code
    # (generator __next__ in train) is Python 3.
    self.args_dict = {name: value for name, value in self.train_args.__dict__.items()}
    self.n_epochs = self.train_args.epochs
    # evaluate on dev data after each epoch by default
    self.validate = True
    # checkpoint whenever the dev loss improves
    self.save_when_better = True
def train(self, network, data, dev_data):
    """Run the training loop, validating on `dev_data` after each epoch.

    :param network: model exposing prepare_input / data_forward / loss /
        grad_backward / mode (see BaseModel).
    :param data: raw training data, converted via network.prepare_input.
    :param dev_data: held-out data passed to the Tester for validation.
    """
    X, Y = network.prepare_input(data)
    iterations, train_batch_generator = self.batchify(X, Y)
    loss_history = list()
    # switch the network into training mode
    network.mode(test=False)
    # TODO(review): placeholder — real Tester settings must be wired in.
    test_args = "..."
    evaluator = Tester(test_args)
    best_loss = 1e10
    for epoch in range(self.n_epochs):
        for step in range(iterations):
            batch_x, batch_y = next(train_batch_generator)
            prediction = network.data_forward(batch_x)
            loss = network.loss(batch_y, prediction)
            network.grad_backward()
            loss_history.append(loss)
            self.log(self.make_log(epoch, step, loss))
        # evaluate over dev set once per epoch
        # NOTE(review): source indentation was lost; per-epoch validation
        # assumed — confirm against the original file.
        if self.validate:
            evaluator.test(network, dev_data)
            self.log(self.make_valid_log(epoch, evaluator.loss))
            if evaluator.loss < best_loss:
                best_loss = evaluator.loss
                if self.save_when_better:
                    self.save_model(network)
    # finish training
def make_log(self, *args):
    """Format a per-step training log entry; subclasses must override."""
    raise NotImplementedError
def make_valid_log(self, *args):
    """Format a validation log entry; subclasses must override."""
    raise NotImplementedError
def save_model(self, model):
    """Persist a model checkpoint; subclasses must override."""
    raise NotImplementedError
| @@ -18,3 +18,30 @@ class BaseModel(object): | |||
def loss(self, pred, truth):
    """Compute the training loss between predictions and ground truth.

    Abstract: concrete models must override.
    """
    raise NotImplementedError
class Vocabulary(object):
    """A collection of lookup tables mapping words to embedding vectors."""

    def __init__(self):
        # all three tables are populated externally before lookup() is used
        self.word_set = None      # set of known words (membership test)
        self.word2idx = None      # word -> row index into emb_matrix
        self.emb_matrix = None    # indexable container of embedding rows

    def lookup(self, word):
        """Return the embedding row for `word`.

        :raises LookupError: if `word` is not in the vocabulary.
        """
        if word in self.word_set:
            return self.emb_matrix[self.word2idx[word]]
        # BUG FIX: the original *returned* the exception object instead of
        # raising it, so callers silently received a LookupError instance.
        raise LookupError("The key " + word + " does not exist.")
class Document(object):
    """A sequence of tokens, each token a character carrying linguistic
    attributes.
    """

    def __init__(self):
        # wraps a pandas.DataFrame; populated elsewhere, starts empty
        self.dataframe = None