| @@ -1,3 +1,4 @@ | |||||
| class Action(object): | class Action(object): | ||||
| """ | """ | ||||
| base class for Trainer and Tester | base class for Trainer and Tester | ||||
| @@ -14,6 +14,7 @@ class Tester(Action): | |||||
| self.test_args = test_args | self.test_args = test_args | ||||
| self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | ||||
| self.mean_loss = None | self.mean_loss = None | ||||
| self.output = None | |||||
| def test(self, network, data): | def test(self, network, data): | ||||
| # transform into network input and label | # transform into network input and label | ||||
| @@ -22,6 +23,7 @@ class Tester(Action): | |||||
| # split into batches by self.batch_size | # split into batches by self.batch_size | ||||
| iterations, test_batch_generator = self.batchify(X, Y) | iterations, test_batch_generator = self.batchify(X, Y) | ||||
| batch_output = list() | |||||
| loss_history = list() | loss_history = list() | ||||
| # turn on the testing mode of the network | # turn on the testing mode of the network | ||||
| network.mode(test=True) | network.mode(test=True) | ||||
| @@ -31,6 +33,7 @@ class Tester(Action): | |||||
| # forward pass from tests input to predicted output | # forward pass from tests input to predicted output | ||||
| prediction = network.data_forward(batch_x) | prediction = network.data_forward(batch_x) | ||||
| batch_output.append(prediction) | |||||
| # get the loss | # get the loss | ||||
| loss = network.loss(batch_y, prediction) | loss = network.loss(batch_y, prediction) | ||||
| @@ -39,7 +42,16 @@ class Tester(Action): | |||||
| self.log(self.make_log(step, loss)) | self.log(self.make_log(step, loss)) | ||||
| self.mean_loss = np.mean(np.array(loss_history)) | self.mean_loss = np.mean(np.array(loss_history)) | ||||
| self.output = self.make_output(batch_output) | |||||
| @property | @property | ||||
| def loss(self): | def loss(self): | ||||
| return self.mean_loss | return self.mean_loss | ||||
| @property | |||||
| def result(self): | |||||
| return self.output | |||||
| def make_output(self, batch_output): | |||||
| # construct full prediction with batch outputs | |||||
| raise NotImplementedError | |||||
| @@ -1,4 +1,5 @@ | |||||
| from action.action import Action | from action.action import Action | ||||
| from action.tester import Tester | |||||
class Trainer(Action):
    """
    Trainer for common training logic of all models
    """

    def __init__(self, train_args):
        """
        :param train_args: namedtuple-like object of training settings;
            must expose an ``epochs`` attribute and whatever the inherited
            ``batchify`` reads from ``args_dict``.
        """
        super(Trainer, self).__init__()
        self.train_args = train_args
        # .items(), not the Python-2 .iteritems(): the rest of this code
        # is Python 3 (it advances generators with the __next__ protocol).
        self.args_dict = {name: value for name, value in self.train_args.__dict__.items()}
        self.n_epochs = self.train_args.epochs
        self.validate = True
        self.save_when_better = True

    def train(self, network, data, dev_data):
        """
        Run the generic training loop: forward pass, loss, backward pass
        for every batch, with per-epoch validation on ``dev_data``.

        :param network: model exposing prepare_input / data_forward /
            loss / grad_backward / mode
        :param data: raw training data, transformed by network.prepare_input
        :param dev_data: held-out data evaluated when ``self.validate`` is True
        """
        # transform into network input and label
        X, Y = network.prepare_input(data)
        # split into batches by self.batch_size
        iterations, train_batch_generator = self.batchify(X, Y)
        loss_history = list()
        # turn on the training mode of the network
        network.mode(test=False)

        # TODO(review): placeholder -- Tester.__init__ reads
        # test_args.__dict__, which a plain str does not provide usefully;
        # pass a real args object before enabling validation.
        test_args = "..."
        evaluator = Tester(test_args)
        best_loss = 1e10

        for epoch in range(self.n_epochs):
            for step in range(iterations):
                batch_x, batch_y = next(train_batch_generator)
                # forward pass from training input to predicted output
                prediction = network.data_forward(batch_x)
                # get the loss
                loss = network.loss(batch_y, prediction)
                network.grad_backward()
                loss_history.append(loss)
                self.log(self.make_log(epoch, step, loss))

            # evaluate over dev set once per epoch
            if self.validate:
                evaluator.test(network, dev_data)
                self.log(self.make_valid_log(epoch, evaluator.loss))
                if evaluator.loss < best_loss:
                    best_loss = evaluator.loss
                    if self.save_when_better:
                        self.save_model(network)
        # finish training

    def make_log(self, *args):
        """Build a per-step training log entry; subclass responsibility."""
        raise NotImplementedError

    def make_valid_log(self, *args):
        """Build a per-epoch validation log entry; subclass responsibility."""
        raise NotImplementedError

    def save_model(self, model):
        """Persist the model to storage; subclass responsibility."""
        raise NotImplementedError
| @@ -18,3 +18,30 @@ class BaseModel(object): | |||||
| def loss(self, pred, truth): | def loss(self, pred, truth): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
class Vocabulary(object):
    """
    A collection of lookup tables mapping words to embedding vectors.
    """

    def __init__(self):
        # All three tables are populated externally before lookup() is used.
        self.word_set = None    # set of known words, for membership tests
        self.word2idx = None    # word -> row index into emb_matrix
        self.emb_matrix = None  # indexable embedding matrix (row per word)

    def lookup(self, word):
        """
        Return the embedding vector for ``word``.

        :param word: word to look up
        :raises LookupError: if the word is not in the vocabulary
        """
        if word in self.word_set:
            return self.emb_matrix[self.word2idx[word]]
        # BUG FIX: the original *returned* the LookupError instance instead
        # of raising it, so callers silently received an exception object.
        raise LookupError("The key " + word + " does not exist.")
class Document(object):
    """
    A sequence of tokens; each token is a character carrying
    linguistic attributes.
    """

    def __init__(self):
        # Thin wrapper around a pandas.DataFrame; stays None until the
        # document content is loaded.
        self.dataframe = None