| @@ -15,15 +15,26 @@ class Action(object): | |||
| raise NotImplementedError | |||
| def log(self, args): | |||
| self.logger.log(args) | |||
| """ | |||
| Basic operations shared between Trainer and Tester. | |||
| """ | |||
| print("call logger.log") | |||
| def batchify(self, X, Y=None): | |||
| # a generator | |||
| raise NotImplementedError | |||
| """ | |||
| :param X: | |||
| :param Y: | |||
| :return iteration:int, the number of step in each epoch | |||
| generator:generator, to generate batch inputs | |||
| """ | |||
| data = X | |||
| if Y is not None: | |||
| data = [X, Y] | |||
| return 2, self._batch_generate(data) | |||
| def _batch_generate(self, data): | |||
| step = 10 | |||
| for i in range(2): | |||
| start = i * step | |||
| end = (i + 1) * step | |||
| yield data[0][start:end], data[1][start:end] | |||
| def make_log(self, *args): | |||
| raise NotImplementedError | |||
| return "log" | |||
| @@ -12,7 +12,7 @@ class Tester(Action): | |||
| """ | |||
| super(Tester, self).__init__() | |||
| self.test_args = test_args | |||
| self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | |||
| # self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()} | |||
| self.mean_loss = None | |||
| self.output = None | |||
| @@ -54,4 +54,4 @@ class Tester(Action): | |||
| def make_output(self, batch_output): | |||
| # construct full prediction with batch outputs | |||
| raise NotImplementedError | |||
| return np.concatenate((batch_output[0], batch_output[1]), axis=0) | |||
| @@ -1,5 +1,5 @@ | |||
| from action.action import Action | |||
| from action.tester import Tester | |||
| from .action import Action | |||
| from .tester import Tester | |||
| class Trainer(Action): | |||
| @@ -13,10 +13,10 @@ class Trainer(Action): | |||
| """ | |||
| super(Trainer, self).__init__() | |||
| self.train_args = train_args | |||
| self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()} | |||
| # self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()} | |||
| self.n_epochs = self.train_args.epochs | |||
| self.validate = True | |||
| self.save_when_better = True | |||
| self.validate = self.train_args.validate | |||
| self.save_when_better = self.train_args.save_when_better | |||
| def train(self, network, data, dev_data): | |||
| X, Y = network.prepare_input(data) | |||
| @@ -51,10 +51,10 @@ class Trainer(Action): | |||
| # finish training | |||
| def make_log(self, *args): | |||
| raise NotImplementedError | |||
| print("logged") | |||
| def make_valid_log(self, *args): | |||
| raise NotImplementedError | |||
| print("logged") | |||
| def save_model(self, model): | |||
| raise NotImplementedError | |||
| print("model saved") | |||
| @@ -1,3 +1,6 @@ | |||
| import numpy as np | |||
| class BaseModel(object): | |||
| """base model for all models""" | |||
| @@ -5,6 +8,10 @@ class BaseModel(object): | |||
| pass | |||
| def prepare_input(self, data): | |||
| """ | |||
| :param data: str, raw input vector(?) | |||
| :return (X, Y): tuple, input features and labels | |||
| """ | |||
| raise NotImplementedError | |||
| def mode(self, test=False): | |||
| @@ -20,6 +27,33 @@ class BaseModel(object): | |||
| raise NotImplementedError | |||
| class ToyModel(BaseModel): | |||
| """This is for code testing.""" | |||
| def __init__(self): | |||
| super(ToyModel, self).__init__() | |||
| self.test_mode = False | |||
| self.weight = np.random.rand(5, 1) | |||
| self.bias = np.random.rand() | |||
| self._loss = 0 | |||
| def prepare_input(self, data): | |||
| return data[:, :-1], data[:, -1] | |||
| def mode(self, test=False): | |||
| self.test_mode = test | |||
| def data_forward(self, x): | |||
| return np.matmul(x, self.weight) + self.bias | |||
| def grad_backward(self): | |||
| print("loss gradient backward") | |||
| def loss(self, pred, truth): | |||
| self._loss = np.mean(np.square(pred - truth)) | |||
| return self._loss | |||
| class Vocabulary(object): | |||
| """ | |||
| A collection of lookup tables. | |||
| @@ -0,0 +1,21 @@ | |||
| from collections import namedtuple | |||
| import numpy as np | |||
| from action.trainer import Trainer | |||
| from model.base_model import ToyModel | |||
| def test_trainer(): | |||
| Config = namedtuple("config", ["epochs", "validate", "save_when_better"]) | |||
| train_config = Config(epochs=5, validate=True, save_when_better=True) | |||
| trainer = Trainer(train_config) | |||
| net = ToyModel() | |||
| data = np.random.rand(20, 6) | |||
| dev_data = np.random.rand(20, 6) | |||
| trainer.train(net, data, dev_data) | |||
| if __name__ == "__main__": | |||
| test_trainer() | |||